# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Generator, Iterable, Tuple, Union
from torchdyn.core.problems import MultipleShootingProblem, ODEProblem, SDEProblem
from torchdyn.numerics import odeint
from torchdyn.core.defunc import SDEFunc
from torchdyn.core.utils import standardize_vf_call_signature
import pytorch_lightning as pl
import torch
from torch import Tensor
import torch.nn as nn
import torchsde
import warnings


class NeuralODE(ODEProblem, pl.LightningModule):
def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn.Module]='tsit5', order:int=1,
atol:float=1e-3, rtol:float=1e-3, sensitivity='autograd', solver_adjoint:Union[str, nn.Module, None] = None,
                 atol_adjoint:float=1e-4, rtol_adjoint:float=1e-4, interpolator:Union[str, Callable, None]=None,
                 integral_loss:Union[Callable, None]=None, seminorm:bool=False, return_t_eval:bool=True,
                 optimizable_params:Union[Iterable, Generator]=()):
"""Generic Neural Ordinary Differential Equation.
Args:
            vector_field ([Callable]): the vector field, called either as `vector_field(t, x)` or as `vector_field(x)`.
                                       In the latter case, the Callable is automatically wrapped for a consistent call signature.
            solver (Union[str, nn.Module]): ODE solver, passed as a string identifier or as an `nn.Module`. Defaults to 'tsit5'.
order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-3.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-3.
sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Absolute tolerance of the adjoint solver. Defaults to 1e-4.
            rtol_adjoint (float, optional): Relative tolerance of the adjoint solver. Defaults to 1e-4.
            interpolator (Union[str, Callable, None], optional): interpolator used by the interpolated adjoint. Defaults to None.
            integral_loss (Union[Callable, None], optional): integral loss to optimize along the trajectory. Defaults to None.
seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`.
            optimizable_params (Union[Iterable, Generator]): parameters to calculate sensitivities for. Defaults to ().
        Notes:
            In `torchdyn` style, forward calls to a Neural ODE return both a tensor `t_eval` of time points at which the solution
            is evaluated and the solution itself. Setting `return_t_eval` to False returns the solution only; calling `trajectory`
            likewise returns only the solution.
            The Neural ODE class automates certain delicate steps that must be carried out depending on the solver and model used;
            the `_prep_integration` method performs these steps. Neural ODEs wrap `ODEProblem`.
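        Example:
            A minimal usage sketch; the two-layer vector field below is illustrative, not part of the API::

                import torch
                import torch.nn as nn
                from torchdyn.core import NeuralODE

                f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
                model = NeuralODE(f, solver='tsit5')
                x0, t_span = torch.randn(64, 2), torch.linspace(0, 1, 10)
                t_eval, sol = model(x0, t_span)  # sol has shape (len(t_span), 64, 2)
                # with `return_t_eval=False`, forward returns `sol` only, so NeuralODEs can be chained in `nn.Sequential`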
"""
super().__init__(vector_field=standardize_vf_call_signature(vector_field, order, defunc_wrap=True), order=order, sensitivity=sensitivity,
solver=solver, atol=atol, rtol=rtol, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint, rtol_adjoint=rtol_adjoint,
seminorm=seminorm, interpolator=interpolator, integral_loss=integral_loss, optimizable_params=optimizable_params)
self._control, self.controlled, self.t_span = None, False, None # data-control conditioning
self.return_t_eval = return_t_eval
if integral_loss is not None: self.vf.integral_loss = integral_loss
self.vf.sensitivity = sensitivity

    def _prep_integration(self, x:Tensor, t_span:Tensor) -> Tuple[Tensor, Tensor]:
        "Performs generic checks before integration. Assigns data-control inputs and augments the state for CNFs."
        # assign a default value to `t_span` for `forward` calls that do not explicitly pass an integration interval
if t_span is None and self.t_span is None: t_span = torch.linspace(0, 1, 2)
elif t_span is None: t_span = self.t_span
# loss dimension detection routine; for CNF div propagation and integral losses w/ autograd
excess_dims = 0
        if self.integral_loss is not None and self.sensitivity == 'autograd':
excess_dims += 1
        # handle auxiliary operations required for some Jacobian-trace CNF estimators, e.g. Hutchinson's,
        # as well as data-control assignment to DataControl modules
for _, module in self.vf.named_modules():
if hasattr(module, 'trace_estimator'):
if module.noise_dist is not None: module.noise = module.noise_dist.sample((x.shape[0],))
excess_dims += 1
            # data-control assignment routine, performed once at the beginning of odeint since the control is fixed to the initial condition
if hasattr(module, '_control'):
self.controlled = True
module._control = x[:, excess_dims:].detach()
return x, t_span

    def forward(self, x:Union[Tensor, Dict], t_span:Tensor=None, save_at:Iterable=(), args=None):
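        "Integrates the vector field from initial condition `x` over `t_span`; returns `(t_eval, sol)`, or `sol` only if `return_t_eval` is False."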
x, t_span = self._prep_integration(x, t_span)
        t_eval, sol = super().forward(x, t_span, save_at, args if args is not None else {})
if self.return_t_eval: return t_eval, sol
else: return sol

    def trajectory(self, x:torch.Tensor, t_span:Tensor):
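        "Returns the solution only, evaluated at the time points in `t_span` (no `t_eval`)."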
x, t_span = self._prep_integration(x, t_span)
_, sol = odeint(self.vf, x, t_span, solver=self.solver, atol=self.atol, rtol=self.rtol)
return sol

    def __repr__(self):
        npar = sum(p.numel() for p in self.vf.parameters())
return f"Neural ODE:\n\t- order: {self.order}\
\n\t- solver: {self.solver}\n\t- adjoint solver: {self.solver_adjoint}\
\n\t- tolerances: relative {self.rtol} absolute {self.atol}\
\n\t- adjoint tolerances: relative {self.rtol_adjoint} absolute {self.atol_adjoint}\
\n\t- num_parameters: {npar}\
\n\t- NFE: {self.vf.nfe}"


class NeuralSDE(SDEProblem, pl.LightningModule):
    def __init__(self, drift_func, diffusion_func, noise_type='diagonal', sde_type='ito', order=1,
                 sensitivity='autograd', s_span=torch.linspace(0, 1, 2), solver='srk',
                 atol=1e-4, rtol=1e-4, ds=1e-3, intloss=None):
"""Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class.
Args:
            drift_func (Callable): drift function.
            diffusion_func (Callable): diffusion function.
            noise_type (str, optional): type of noise, following the `torchsde` convention. Defaults to 'diagonal'.
            sde_type (str, optional): SDE interpretation, 'ito' or 'stratonovich'. Defaults to 'ito'.
            order (int, optional): order of the SDE. Defaults to 1.
            sensitivity (str, optional): sensitivity method ['autograd', 'adjoint']. Defaults to 'autograd'.
            s_span (Tensor, optional): integration interval. Defaults to torch.linspace(0, 1, 2).
            solver (str, optional): SDE solver, following the `torchsde` convention. Defaults to 'srk'.
            atol (float, optional): absolute tolerance of the solver. Defaults to 1e-4.
            rtol (float, optional): relative tolerance of the solver. Defaults to 1e-4.
            ds (float, optional): solver step size, passed to `torchsde.sdeint` as `dt`. Defaults to 1e-3.
            intloss (Callable, optional): integral loss. Defaults to None.
        Raises:
            NotImplementedError: higher-order Neural SDEs are not yet implemented; raised when `order` is set to a value greater than 1.
Notes:
            The current implementation is rougher around the edges than `NeuralODE`, and is not guaranteed to have the same features.
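        Example:
            A minimal usage sketch; the drift and diffusion networks below are illustrative, not part of the API::

                import torch
                import torch.nn as nn
                from torchdyn.core import NeuralSDE

                drift = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
                diffusion = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
                model = NeuralSDE(drift, diffusion, noise_type='diagonal', sde_type='ito')
                x0 = torch.randn(64, 2)
                xT = model(x0)  # terminal state after integrating over `s_span`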
"""
super().__init__(func=SDEFunc(f=drift_func, g=diffusion_func, order=order), order=order, sensitivity=sensitivity, s_span=s_span, solver=solver,
atol=atol, rtol=rtol)
        if order != 1: raise NotImplementedError("Higher-order Neural SDEs are not yet implemented.")
self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type
self.adaptive = False
self.intloss = intloss
        self._control, self.controlled = None, False  # data-control conditioning
self.ds = ds

    def _prep_sdeint(self, x:torch.Tensor):
self.s_span = self.s_span.to(x)
        # data-control assignment routine, performed once at the beginning of sdeint since the control is fixed to the initial condition
excess_dims = 0
for _, module in self.defunc.named_modules():
if hasattr(module, '_control'):
self.controlled = True
module._control = x[:, excess_dims:].detach()
return x

    def forward(self, x:torch.Tensor):
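        "Integrates the SDE from initial condition `x` over `self.s_span`, dispatching on the sensitivity method, and returns the terminal state."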
x = self._prep_sdeint(x)
switcher = {
'autograd': self._autograd,
'adjoint': self._adjoint,
}
        sdeint = switcher.get(self.sensitivity)
        if sdeint is None: raise ValueError(f"Sensitivity method '{self.sensitivity}' is not supported by NeuralSDE; use 'autograd' or 'adjoint'.")
        out = sdeint(x)
return out

    def trajectory(self, x:torch.Tensor, s_span:torch.Tensor):
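        "Returns the SDE solution evaluated at the points in `s_span`."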
x = self._prep_sdeint(x)
sol = torchsde.sdeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol,
method=self.solver, dt=self.ds)
return sol

    def backward_trajectory(self, x:torch.Tensor, s_span:torch.Tensor):
raise NotImplementedError

    def _autograd(self, x):
self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
return torchsde.sdeint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]

    def _adjoint(self, x):
out = torchsde.sdeint_adjoint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]
return out


class MultipleShootingLayer(MultipleShootingProblem, pl.LightningModule):
def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd',
maxiter:int=4, fine_steps:int=4, solver_adjoint:Union[str, nn.Module, None] = None, atol_adjoint:float=1e-6,
rtol_adjoint:float=1e-6, seminorm:bool=False, integral_loss:Union[Callable, None]=None):
"""Multiple Shooting Layer as defined in https://arxiv.org/abs/2106.03885.
Uses parallel-in-time ODE solvers to solve an ODE parametrized by neural network `vector_field`.
Args:
            vector_field ([Callable]): the vector field, called either as `vector_field(t, x)` or as `vector_field(x)`.
                                       In the latter case, the Callable is automatically wrapped for a consistent call signature.
            solver (str): parallel-in-time solver, one of ['zero', 'direct'].
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            maxiter (int): number of iterations of the root-finding routine used to solve the ODE in parallel. Defaults to 4.
            fine_steps (int): number of fine-solver steps to perform in each subinterval of the parallel solution. Defaults to 4.
            solver_adjoint (Union[str, nn.Module, None], optional): standard sequential ODE solver for the adjoint system. Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            integral_loss (Union[Callable, None], optional): currently not implemented. Defaults to None.
seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
Notes:
The number of shooting parameters (first dimension in `B0`) is implicitly defined by passing `t_span` during forward calls.
For example, a `t_span=torch.linspace(0, 1, 10)` will define 9 intervals and 10 shooting parameters.
            For the moment, this class is only a thin wrapper around `MultipleShootingProblem`. Convenience routines for special
            initializations of the shooting parameters `B0`, as well as the usual convenience checks for integral losses, will live at this level.
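        Example:
            A minimal usage sketch; the vector field below is illustrative, not part of the API, and the forward call is assumed
            to follow the `(t_eval, sol)` convention of `NeuralODE`::

                import torch
                import torch.nn as nn
                from torchdyn.core import MultipleShootingLayer

                f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
                model = MultipleShootingLayer(f, solver='zero')
                x0 = torch.randn(64, 2)
                t_span = torch.linspace(0, 1, 10)  # 9 intervals, hence 10 shooting parameters
                t_eval, sol = model(x0, t_span)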
"""
super().__init__(vector_field, solver, sensitivity, maxiter, fine_steps, solver_adjoint, atol_adjoint,
rtol_adjoint, seminorm, integral_loss)