import torch
from torch import Tensor
import torch.nn as nn
from typing import Callable, Generator, Iterable, Union
from torchdyn.numerics.sensitivity import _gather_odefunc_adjoint, _gather_odefunc_interp_adjoint
from torchdyn.numerics.odeint import odeint, odeint_mshooting
from torchdyn.numerics.solvers.ode import str_to_solver, str_to_ms_solver
from torchdyn.core.utils import standardize_vf_call_signature
class ODEProblem(nn.Module):
    """Couples a vector field with an ODE solver and a sensitivity algorithm."""

    def __init__(self, vector_field: Union[Callable, nn.Module], solver: Union[str, nn.Module],
                 interpolator: Union[str, Callable, None] = None, order: int = 1,
                 atol: float = 1e-4, rtol: float = 1e-4, sensitivity: str = 'autograd',
                 solver_adjoint: Union[str, nn.Module, None] = None, atol_adjoint: float = 1e-6,
                 rtol_adjoint: float = 1e-6, seminorm: bool = False,
                 integral_loss: Union[Callable, None] = None,
                 optimizable_params: Union[Iterable, Generator] = ()):
        """An ODE Problem coupling a given vector field with solver and sensitivity algorithm to compute
        gradients w.r.t. different quantities.

        Args:
            vector_field (Union[Callable, nn.Module]): the vector field, called with `vector_field(t, x)`
                or `vector_field(x)`. In the second case, the Callable is automatically wrapped for consistency.
            solver (Union[str, nn.Module]): ODE solver. A string is converted into a solver instance with
                `str_to_solver`; any other value is used as-is.
            interpolator (Union[str, Callable, None], optional): interpolator forwarded to the integration
                routines. Defaults to None.
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint'].
                Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. A string is
                converted with `str_to_solver`, a solver module is used as-is and None reuses `solver`.
                Defaults to None.
            atol_adjoint (float, optional): Absolute tolerance of the adjoint solver. Defaults to 1e-6.
            rtol_adjoint (float, optional): Relative tolerance of the adjoint solver. Defaults to 1e-6.
            seminorm (bool, optional): Indicates whether a seminorm should be used for error estimation
                during adjoint backsolves. Stored as `self.seminorm`. Defaults to False.
            integral_loss (Union[Callable, None]): Integral loss to optimize for. Defaults to None.
            optimizable_params (Union[Iterable, Generator]): parameters to calculate sensitivities for; only
                used when the vector field has no parameters of its own. Defaults to ().

        Notes:
            Integral losses can be passed as generic function or `nn.Modules`.
        """
        super().__init__()
        # instantiate solvers at initialization; solver modules passed directly are kept as-is
        if isinstance(solver, str):
            solver = str_to_solver(solver)
        if solver_adjoint is None:
            # reuse the forward solver for adjoint backsolves
            solver_adjoint = solver
        elif isinstance(solver_adjoint, str):
            # only strings need conversion: the signature also accepts ready-made solver modules
            solver_adjoint = str_to_solver(solver_adjoint)

        self.solver, self.interpolator, self.atol, self.rtol = solver, interpolator, atol, rtol
        self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint = solver_adjoint, atol_adjoint, rtol_adjoint
        self.sensitivity, self.integral_loss = sensitivity, integral_loss
        self.seminorm = seminorm

        # wrap vector field if `t, x` is not the call signature
        vector_field = standardize_vf_call_signature(vector_field)
        self.vf, self.order, self.sensalg = vector_field, order, sensitivity

        optimizable_params = tuple(optimizable_params)
        vf_own_params = tuple(self.vf.parameters())
        if vf_own_params:
            self.vf_params = torch.cat([p.contiguous().flatten() for p in vf_own_params])
        elif optimizable_params:
            # use `optimizable_params` if f itself does not have a .parameters() iterable
            # TODO: advanced logic to retain naming in case `state_dicts()` are passed
            for k, p in enumerate(optimizable_params):
                self.vf.register_parameter(f'optimizable_parameter_{k}', p)
            self.vf_params = torch.cat([p.contiguous().flatten() for p in optimizable_params])
        else:
            print("Your vector field does not have `nn.Parameters` to optimize.")
            # register a placeholder so the sensitivity machinery always has a parameter vector
            dummy_parameter = nn.Parameter(torch.zeros(1))
            self.vf.register_parameter('dummy_parameter', dummy_parameter)
            self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])

    def _autograd_func(self):
        """Create the autograd function used for the backward pass.

        Refreshes `self.vf_params` from the current vector-field parameters, then builds the adjoint or
        interpolated-adjoint `Function` and returns its `.apply`, aliased so it can be called directly.

        Raises:
            ValueError: if `self.sensalg` is neither 'adjoint' nor 'interpolated_adjoint'.
        """
        self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
        if self.sensalg == 'adjoint':
            gather = _gather_odefunc_adjoint
        elif self.sensalg == 'interpolated_adjoint':
            gather = _gather_odefunc_interp_adjoint
        else:
            # fail loudly here instead of returning None, which only surfaced later in `odeint`
            # as an opaque "'NoneType' object is not callable"
            raise ValueError(f"Unsupported sensitivity algorithm {self.sensalg!r}: "
                             "expected 'adjoint' or 'interpolated_adjoint'")
        # alias .apply as direct call to preserve consistency of call signature
        return gather(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
                      self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
                      problem_type='standard').apply

    def odeint(self, x: Tensor, t_span: Tensor, save_at: Tensor = (), args: Union[dict, None] = None):
        """Integrate the problem from `x` over `t_span`.

        With 'autograd' sensitivity the torchdyn `odeint` routine is called directly; otherwise the
        adjoint autograd function built by `_autograd_func` is used.

        Returns:
            Tuple(`t_eval`, `solution`)
        """
        if args is None:
            args = {}  # fresh dict per call instead of a shared mutable default
        if self.sensalg == 'autograd':
            return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator,
                          save_at=save_at, args=args)
        return self._autograd_func()(self.vf_params, x, t_span, save_at, args)

    def forward(self, x: Tensor, t_span: Tensor, save_at: Tensor = (), args: Union[dict, None] = None):
        "For safety redirects to intended method `odeint`"
        return self.odeint(x, t_span, save_at, args)
class MultipleShootingProblem(ODEProblem):
    """`ODEProblem` variant integrated with parallel-in-time (multiple shooting) solvers."""

    def __init__(self, vector_field: Callable, solver: str, sensitivity: str = 'autograd',
                 maxiter: int = 4, fine_steps: int = 4, solver_adjoint: Union[str, nn.Module, None] = None,
                 atol_adjoint: float = 1e-6, rtol_adjoint: float = 1e-6, seminorm: bool = False,
                 integral_loss: Union[Callable, None] = None):
        """An ODE problem solved with parallel-in-time methods.

        Args:
            vector_field (Callable): the vector field, called with `vector_field(t, x)` or `vector_field(x)`.
                In the second case, the Callable is automatically wrapped for consistency.
            solver (str): parallel-in-time solver, stored as `self.parallel_solver`.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint'].
                Defaults to 'autograd'.
            maxiter (int, optional): forwarded as `maxiter` to the multiple-shooting routines. Defaults to 4.
            fine_steps (int, optional): forwarded as `fine_steps` to the multiple-shooting routines.
                Defaults to 4.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. The parent
                class is initialized with `solver=None`, so leaving this as None leaves the adjoint solver
                None as well. Defaults to None.
            atol_adjoint (float, optional): Absolute tolerance of the adjoint solver. Defaults to 1e-6.
            rtol_adjoint (float, optional): Relative tolerance of the adjoint solver. Defaults to 1e-6.
            seminorm (bool, optional): forwarded to `ODEProblem`. Defaults to False.
            integral_loss (Union[Callable, None], optional): Integral loss to optimize for. Defaults to None.
        """
        super().__init__(vector_field=vector_field, solver=None, interpolator=None, order=1,
                         sensitivity=sensitivity, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint,
                         rtol_adjoint=rtol_adjoint, seminorm=seminorm, integral_loss=integral_loss)
        # the standard `solver` slot stays None; the parallel-in-time solver is kept separately
        self.parallel_solver = solver
        self.fine_steps, self.maxiter = fine_steps, maxiter

    def _autograd_func(self):
        """Create the autograd function used for the backward pass.

        Refreshes `self.vf_params` from the current vector-field parameters, then builds the adjoint or
        interpolated-adjoint `Function` for the 'multiple_shooting' problem type and returns its `.apply`.

        Raises:
            ValueError: if `self.sensalg` is neither 'adjoint' nor 'interpolated_adjoint'.
        """
        self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
        if self.sensalg == 'adjoint':
            gather = _gather_odefunc_adjoint
        elif self.sensalg == 'interpolated_adjoint':
            gather = _gather_odefunc_interp_adjoint
        else:
            raise ValueError(f"Unsupported sensitivity algorithm {self.sensalg!r}: "
                             "expected 'adjoint' or 'interpolated_adjoint'")
        # `self.solver` is always None here (see `__init__`): the forward pass must use the
        # parallel-in-time solver, exactly like the autograd path in `odeint`.
        # `maxiter` / `fine_steps` are passed by keyword so they cannot be swapped positionally.
        # The tolerances (0, 0) and interpolator (None) are unused by multiple shooting.
        return gather(self.vf, self.vf_params, self.parallel_solver, 0, 0, None,
                      self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
                      'multiple_shooting', maxiter=self.maxiter, fine_steps=self.fine_steps).apply

    def odeint(self, x: Tensor, t_span: Tensor, B0: Union[Tensor, None] = None):
        """Integrate the problem from `x` over `t_span` with the parallel-in-time solver.

        `B0` is forwarded unchanged to `odeint_mshooting` (autograd) or to the adjoint function.

        Returns:
            Tuple(`t_eval`, `solution`)
        """
        if self.sensalg == 'autograd':
            return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter)
        return self._autograd_func()(self.vf_params, x, t_span, B0)

    def forward(self, x: Tensor, t_span: Tensor, B0: Union[Tensor, None] = None):
        "For safety redirects to intended method `odeint`"
        return self.odeint(x, t_span, B0)
class SDEProblem(nn.Module):
    """Placeholder for the stochastic (SDE) counterpart of `ODEProblem`.

    Not available yet: construction initializes the `nn.Module` base and then
    always raises `NotImplementedError`.
    """

    def __init__(self):
        """Reserved constructor for SDE problems; currently always raises."""
        super().__init__()
        pending_msg = "Hopefully soon..."
        raise NotImplementedError(pending_msg)