# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Union, List, Iterable, Generator
from torchdyn.core.problems import MultipleShootingProblem, ODEProblem, SDEProblem
from torchdyn.numerics import odeint
from torchdyn.core.defunc import DEFunc, DEFuncBase, SDEFunc
from torchdyn.core.utils import standardize_vf_call_signature
import pytorch_lightning as pl
import torch
from torch import Tensor
import torch.nn as nn
import torchsde
import warnings
class NeuralODE(ODEProblem, pl.LightningModule):
    def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn.Module]='tsit5', order:int=1,
                 atol:float=1e-3, rtol:float=1e-3, sensitivity='autograd', solver_adjoint:Union[str, nn.Module, None]=None,
                 atol_adjoint:float=1e-4, rtol_adjoint:float=1e-4, interpolator:Union[str, Callable, None]=None,
                 integral_loss:Union[Callable, None]=None, seminorm:bool=False, return_t_eval:bool=True,
                 optimizable_params:Union[Iterable, Generator]=()):
        """Generic Neural Ordinary Differential Equation.

        Args:
            vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`.
                In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): ODE solver, passed by name or as an `nn.Module`. Defaults to 'tsit5'.
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-3.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-3.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Absolute tolerance of the adjoint solver. Defaults to 1e-4.
            rtol_adjoint (float, optional): Relative tolerance of the adjoint solver. Defaults to 1e-4.
            interpolator (Union[str, Callable, None], optional): Interpolation scheme used by interpolated adjoints. Defaults to None.
            integral_loss (Union[Callable, None], optional): Integral loss propagated alongside the state. Defaults to None.
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
            return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`. Defaults to True.
            optimizable_params (Union[Iterable, Generator]): parameters to calculate sensitivities for. Defaults to ().

        Notes:
            In `torchdyn`-style, forward calls to a Neural ODE return both a tensor `t_eval` of time points at which the solution is evaluated
            as well as the solution itself. This behavior can be controlled by setting `return_t_eval` to False. Calling `trajectory` also returns
            the solution only.

            The Neural ODE class automates certain delicate steps that must be done depending on the solver and model used.
            The `_prep_integration` method carries out such steps. Neural ODEs wrap `ODEProblem`.
        """
        super().__init__(vector_field=standardize_vf_call_signature(vector_field, order, defunc_wrap=True), order=order, sensitivity=sensitivity,
                         solver=solver, atol=atol, rtol=rtol, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint, rtol_adjoint=rtol_adjoint,
                         seminorm=seminorm, interpolator=interpolator, integral_loss=integral_loss, optimizable_params=optimizable_params)
        self._control, self.controlled, self.t_span = None, False, None  # data-control conditioning
        self.return_t_eval = return_t_eval
        if integral_loss is not None:
            self.vf.integral_loss = integral_loss
        self.vf.sensitivity = sensitivity

    def _prep_integration(self, x: Tensor, t_span: Tensor) -> Tensor:
        """Performs generic checks before integration. Assigns data control inputs and augments state for CNFs."""
        # assign a basic value to `t_span` for `forward` calls that do not explicitly pass an integration interval
        if t_span is None and self.t_span is None:
            t_span = torch.linspace(0, 1, 2)
        elif t_span is None:
            t_span = self.t_span

        # loss dimension detection routine; for CNF div propagation and integral losses w/ autograd
        excess_dims = 0
        if self.integral_loss is not None and self.sensitivity == 'autograd':
            excess_dims += 1

        # handle aux. operations required for some jacobian trace CNF estimators e.g. Hutchinson's
        # as well as data-control set to DataControl module
        for _, module in self.vf.named_modules():
            if hasattr(module, 'trace_estimator'):
                if module.noise_dist is not None:
                    module.noise = module.noise_dist.sample((x.shape[0],))
                excess_dims += 1

            # data-control set routine. Is performed once at the beginning of odeint since the control is fixed to IC
            if hasattr(module, '_control'):
                self.controlled = True
                module._control = x[:, excess_dims:].detach()
        return x, t_span

    def forward(self, x: Tensor, t_span: Tensor = None, save_at: Iterable = (), args=None):
        """Integrates the Neural ODE from initial condition `x` over `t_span`.

        Returns `(t_eval, sol)`, or only `sol` when `return_t_eval` is False.
        """
        # `args=None` sentinel avoids the mutable-default-argument pitfall of `args={}`
        if args is None:
            args = {}
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol = super().forward(x, t_span, save_at, args)
        if self.return_t_eval:
            return t_eval, sol
        return sol

    def trajectory(self, x: torch.Tensor, t_span: Tensor):
        """Returns only the ODE solution evaluated at `t_span` (no `t_eval`), always via `autograd`-free `odeint`."""
        x, t_span = self._prep_integration(x, t_span)
        _, sol = odeint(self.vf, x, t_span, solver=self.solver, atol=self.atol, rtol=self.rtol)
        return sol

    def __repr__(self):
        # generator expression avoids materializing an intermediate list; adjacent f-strings
        # avoid the stray run-on whitespace the old backslash-continued literal embedded in the output
        npar = sum(p.numel() for p in self.vf.parameters())
        return (f"Neural ODE:\n\t- order: {self.order}"
                f"\n\t- solver: {self.solver}\n\t- adjoint solver: {self.solver_adjoint}"
                f"\n\t- tolerances: relative {self.rtol} absolute {self.atol}"
                f"\n\t- adjoint tolerances: relative {self.rtol_adjoint} absolute {self.atol_adjoint}"
                f"\n\t- num_parameters: {npar}"
                f"\n\t- NFE: {self.vf.nfe}")
class NeuralSDE(SDEProblem, pl.LightningModule):
    def __init__(self, drift_func, diffusion_func, noise_type='diagonal', sde_type='ito', order=1,
                 sensitivity='autograd', s_span=None, solver='srk',
                 atol=1e-4, rtol=1e-4, ds=1e-3, intloss=None):
        """Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class.

        Args:
            drift_func ([type]): drift function
            diffusion_func ([type]): diffusion function
            noise_type (str, optional): Defaults to 'diagonal'.
            sde_type (str, optional): Defaults to 'ito'.
            order (int, optional): Defaults to 1.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint']. Defaults to 'autograd'.
            s_span ([type], optional): Integration interval. Defaults to torch.linspace(0, 1, 2) when None.
            solver (str, optional): Defaults to 'srk'.
            atol ([type], optional): Defaults to 1e-4.
            rtol ([type], optional): Defaults to 1e-4.
            ds ([type], optional): Fixed solver step size. Defaults to 1e-3.
            intloss ([type], optional): Defaults to None.

        Raises:
            NotImplementedError: higher-order Neural SDEs are not yet implemented, raised by setting `order` to >1.

        Notes:
            The current implementation is rougher around the edges compared to `NeuralODE`, and is not guaranteed to have the same features.
        """
        # `None` sentinel instead of a tensor default: avoids sharing one mutable
        # tensor object across all instances created with the default `s_span`
        if s_span is None:
            s_span = torch.linspace(0, 1, 2)
        super().__init__(func=SDEFunc(f=drift_func, g=diffusion_func, order=order), order=order, sensitivity=sensitivity, s_span=s_span, solver=solver,
                         atol=atol, rtol=rtol)
        if order != 1:
            raise NotImplementedError
        self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type
        self.adaptive = False
        self.intloss = intloss
        self._control, self.controlled = None, False  # data-control
        self.ds = ds

    def _prep_sdeint(self, x: torch.Tensor):
        """Moves `s_span` onto `x`'s device/dtype and assigns data-control inputs."""
        self.s_span = self.s_span.to(x)
        # data-control set routine. Is performed once at the beginning of sdeint since the control is fixed to IC
        excess_dims = 0
        for _, module in self.defunc.named_modules():
            if hasattr(module, '_control'):
                self.controlled = True
                module._control = x[:, excess_dims:].detach()
        return x

    def forward(self, x: torch.Tensor):
        """Integrates the Neural SDE from `x` over `self.s_span`, dispatching on the sensitivity method."""
        x = self._prep_sdeint(x)
        dispatch = {
            'autograd': self._autograd,
            'adjoint': self._adjoint,
        }
        sdeint_fn = dispatch.get(self.sensitivity)
        if sdeint_fn is None:
            # previously an unknown sensitivity crashed with an opaque `'NoneType' is not callable`
            raise ValueError(f"Invalid sensitivity method: {self.sensitivity}. Expected one of {list(dispatch)}")
        return sdeint_fn(x)

    def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        """Returns the full SDE solution evaluated at the points of `s_span`."""
        x = self._prep_sdeint(x)
        sol = torchsde.sdeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol,
                              method=self.solver, dt=self.ds)
        return sol

    def backward_trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        raise NotImplementedError

    def _autograd(self, x):
        # plain autograd through the solver; returns only the terminal state
        self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
        return torchsde.sdeint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                               adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]

    def _adjoint(self, x):
        # adjoint sensitivity via torchsde; returns only the terminal state
        out = torchsde.sdeint_adjoint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                                      adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]
        return out
class MultipleShootingLayer(MultipleShootingProblem, pl.LightningModule):
    def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd',
                 maxiter:int=4, fine_steps:int=4, solver_adjoint:Union[str, nn.Module, None]=None, atol_adjoint:float=1e-6,
                 rtol_adjoint:float=1e-6, seminorm:bool=False, integral_loss:Union[Callable, None]=None):
        """Multiple Shooting Layer as defined in https://arxiv.org/abs/2106.03885.

        Uses parallel-in-time ODE solvers to solve an ODE parametrized by neural network `vector_field`.

        Args:
            vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`.
                In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): parallel-in-time solver, ['zero', 'direct']
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            maxiter (int): number of iterations of the root finding routine defined to parallel solve the ODE. Defaults to 4.
            fine_steps (int): number of fine-solver steps to perform in each subinterval of the parallel solution. Defaults to 4.
            solver_adjoint (Union[str, nn.Module, None], optional): Standard sequential ODE solver for the adjoint system. Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
            integral_loss (Union[Callable, None], optional): Currently not implemented. Defaults to None.

        Notes:
            The number of shooting parameters (first dimension in `B0`) is implicitly defined by passing `t_span` during forward calls.
            For example, a `t_span=torch.linspace(0, 1, 10)` will define 9 intervals and 10 shooting parameters.

            For the moment only a thin wrapper around `MultipleShootingProblem`. At this level will be convenience routines for special
            initializations of shooting parameters `B0`, as well as usual convenience checks for integral losses.
        """
        super().__init__(vector_field, solver, sensitivity, maxiter, fine_steps, solver_adjoint, atol_adjoint,
                         rtol_adjoint, seminorm, integral_loss)