
Source code for torchdyn.core.problems

import torch
from torch import Tensor
import torch.nn as nn
from typing import Callable, Generator, Iterable, Union

from torchdyn.numerics.sensitivity import _gather_odefunc_adjoint, _gather_odefunc_interp_adjoint
from torchdyn.numerics.odeint import odeint, odeint_mshooting
from torchdyn.numerics.solvers.ode import str_to_solver, str_to_ms_solver
from torchdyn.core.utils import standardize_vf_call_signature


class ODEProblem(nn.Module):
    def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn.Module], interpolator:Union[str, Callable, None]=None,
                 order:int=1, atol:float=1e-4, rtol:float=1e-4, sensitivity:str='autograd',
                 solver_adjoint:Union[str, nn.Module, None]=None, atol_adjoint:float=1e-6, rtol_adjoint:float=1e-6,
                 seminorm:bool=False, integral_loss:Union[Callable, None]=None, optimizable_params:Union[Iterable, Generator]=()):
        """An ODE Problem coupling a given vector field with a solver and a sensitivity algorithm to compute gradients w.r.t. different quantities.

        Args:
            vector_field (Union[Callable, nn.Module]): the vector field, called with `vector_field(t, x)` or `vector_field(x)`. In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): ODE solver, passed either by name or as an `nn.Module`.
            interpolator (Union[str, Callable, None], optional): Defaults to None.
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            seminorm (bool, optional): Indicates whether a seminorm should be used for error estimation during adjoint backsolves. Defaults to False.
            integral_loss (Union[Callable, None]): Integral loss to optimize for. Defaults to None.
            optimizable_params (Union[Iterable, Generator]): parameters to calculate sensitivities for. Defaults to ().

        Notes:
            Integral losses can be passed as generic functions or `nn.Module`s.
        """
        super().__init__()
        # instantiate solver at initialization
        if type(solver) == str:
            solver = str_to_solver(solver)
        if solver_adjoint is None:
            solver_adjoint = solver
        else:
            solver_adjoint = str_to_solver(solver_adjoint)

        self.solver, self.interpolator, self.atol, self.rtol = solver, interpolator, atol, rtol
        self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint = solver_adjoint, atol_adjoint, rtol_adjoint
        self.sensitivity, self.integral_loss = sensitivity, integral_loss

        # wrap vector field if `t, x` is not the call signature
        vector_field = standardize_vf_call_signature(vector_field)

        self.vf, self.order, self.sensalg = vector_field, order, sensitivity
        optimizable_params = tuple(optimizable_params)

        if len(tuple(self.vf.parameters())) > 0:
            self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
        elif len(optimizable_params) > 0:
            # use `optimizable_params` if the vector field itself does not have a `.parameters()` iterable
            # TODO: advanced logic to retain naming in case `state_dicts()` are passed
            for k, p in enumerate(optimizable_params):
                self.vf.register_parameter(f'optimizable_parameter_{k}', p)
            self.vf_params = torch.cat([p.contiguous().flatten() for p in optimizable_params])
        else:
            print("Your vector field does not have `nn.Parameters` to optimize.")
            dummy_parameter = nn.Parameter(torch.zeros(1))
            self.vf.register_parameter('dummy_parameter', dummy_parameter)
            self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
[docs] def _autograd_func(self): "create autograd functions for backward pass" self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, problem_type='standard').apply elif self.sensalg == 'interpolated_adjoint': return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, problem_type='standard').apply

    def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
        "Returns Tuple(`t_eval`, `solution`)"
        if self.sensalg == 'autograd':
            return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator,
                          save_at=save_at, args=args)
        else:
            return self._autograd_func()(self.vf_params, x, t_span, save_at, args)

    def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
        "For safety redirects to intended method `odeint`"
        return self.odeint(x, t_span, save_at, args)
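
A minimal usage sketch for `ODEProblem` (illustrative only, not part of the module source). It assumes `'dopri5'` is one of the solver names accepted by `str_to_solver`; the vector field below is called as `f(x)`, so it is wrapped to the `(t, x)` signature by `standardize_vf_call_signature`, as described in the docstring above::

    import torch
    import torch.nn as nn
    from torchdyn.core.problems import ODEProblem

    # vector field called as f(x); ODEProblem wraps it to f(t, x) internally
    f = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))
    prob = ODEProblem(f, solver='dopri5', sensitivity='adjoint', atol=1e-4, rtol=1e-4)

    x0 = torch.randn(32, 2)                 # batch of initial conditions
    t_span = torch.linspace(0., 1., 10)

    t_eval, sol = prob.odeint(x0, t_span)   # equivalent to prob(x0, t_span)
    loss = sol[-1].pow(2).mean()            # loss on the terminal state (solution is stacked along t_span)
    loss.backward()                         # populates gradients of f's parameters

With `sensitivity='adjoint'` the backward pass is routed through `_gather_odefunc_adjoint` rather than differentiating through the solver steps; with `sensitivity='autograd'` the same call would use plain `odeint` and reverse-mode autograd.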

class MultipleShootingProblem(ODEProblem):
    def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd', maxiter:int=4, fine_steps:int=4,
                 solver_adjoint:Union[str, nn.Module, None]=None, atol_adjoint:float=1e-6, rtol_adjoint:float=1e-6,
                 seminorm:bool=False, integral_loss:Union[Callable, None]=None):
        """An ODE problem solved with parallel-in-time methods.

        Args:
            vector_field (Callable): the vector field, called with `vector_field(t, x)` or `vector_field(x)`. In the second case, the Callable is automatically wrapped for consistency
            solver (str): parallel-in-time solver.
            sensitivity (str, optional): Sensitivity method. Defaults to 'autograd'.
            maxiter (int, optional): Defaults to 4.
            fine_steps (int, optional): Defaults to 4.
            solver_adjoint (Union[str, nn.Module, None], optional): Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            seminorm (bool, optional): Defaults to False.
            integral_loss (Union[Callable, None], optional): Defaults to None.
        """
        super().__init__(vector_field=vector_field, solver=None, interpolator=None, order=1,
                         sensitivity=sensitivity, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint,
                         rtol_adjoint=rtol_adjoint, seminorm=seminorm, integral_loss=integral_loss)
        self.parallel_solver = solver
        self.fine_steps, self.maxiter = fine_steps, maxiter
[docs] def _autograd_func(self): "create autograd functions for backward pass" self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, 'multiple_shooting', self.fine_steps, self.maxiter).apply elif self.sensalg == 'interpolated_adjoint': return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, 'multiple_shooting', self.fine_steps, self.maxiter).apply

    def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
        "Returns Tuple(`t_eval`, `solution`)"
        if self.sensalg == 'autograd':
            return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter)
        else:
            return self._autograd_func()(self.vf_params, x, t_span, B0)

    def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
        "For safety redirects to intended method `odeint`"
        return self.odeint(x, t_span, B0)
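
A corresponding sketch for `MultipleShootingProblem` (again illustrative, not from the library documentation). The solver name `'mszero'` and the reliance on the default `B0=None` initial guess for the shooting parameters are assumptions here, not guarantees of this module::

    import torch
    import torch.nn as nn
    from torchdyn.core.problems import MultipleShootingProblem

    f = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
    # 'mszero' is assumed to be one of the names accepted by `str_to_ms_solver`
    prob = MultipleShootingProblem(f, solver='mszero', sensitivity='autograd',
                                   maxiter=4, fine_steps=4)

    x0 = torch.randn(16, 2)
    t_span = torch.linspace(0., 1., 8)

    # B0 (initial guess for the shooting parameters) is left at its default here
    t_eval, sol = prob.odeint(x0, t_span)

With `sensitivity='autograd'` the call goes through `odeint_mshooting`; the adjoint variants instead build the autograd function with `problem_type='multiple_shooting'`, as shown in `_autograd_func` above.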

class SDEProblem(nn.Module):
    def __init__(self):
        "Extension of `ODEProblem` to SDE"
        super().__init__()
        raise NotImplementedError("Hopefully soon...")
