Shortcuts

Source code for torchdyn.core.neuralde

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union, Iterable, Generator, Dict, Tuple
import warnings

import pytorch_lightning as pl
import torch
from torch import Tensor
import torch.nn as nn
import torchsde

from torchdyn.core.problems import MultipleShootingProblem, ODEProblem, SDEProblem
from torchdyn.numerics import odeint
from torchdyn.core.defunc import SDEFunc
from torchdyn.core.utils import standardize_vf_call_signature


class NeuralODE(ODEProblem, pl.LightningModule):
    def __init__(self, vector_field: Union[Callable, nn.Module], solver: Union[str, nn.Module] = 'tsit5', order: int = 1,
                 atol: float = 1e-3, rtol: float = 1e-3, sensitivity='autograd',
                 solver_adjoint: Union[str, nn.Module, None] = None,
                 atol_adjoint: float = 1e-4, rtol_adjoint: float = 1e-4,
                 interpolator: Union[str, Callable, None] = None,
                 integral_loss: Union[Callable, None] = None, seminorm: bool = False, return_t_eval: bool = True,
                 optimizable_params: Union[Iterable, Generator] = ()):
        """Generic Neural Ordinary Differential Equation.

        Args:
            vector_field ([Callable]): the vector field, called with `vector_field(t, x)` or `vector_field(x)`.
                                       In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): ODE solver. Defaults to 'tsit5'.
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-3.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-3.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-4.
            rtol_adjoint (float, optional): Defaults to 1e-4.
            interpolator (Union[str, Callable, None], optional): Forwarded unchanged to `ODEProblem`. Defaults to None.
            integral_loss (Union[Callable, None], optional): Defaults to None.
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
            return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`.
            optimizable_params (Union[Iterable, Generator]): parameters to calculate sensitivities for. Defaults to ().

        Notes:
            In `torchdyn`-style, forward calls to a Neural ODE return both a tensor `t_eval` of time points at which the
            solution is evaluated as well as the solution itself. This behavior can be controlled by setting
            `return_t_eval` to False. Calling `trajectory` also returns the solution only.

            The Neural ODE class automates certain delicate steps that must be done depending on the solver and model used.
            The `_prep_integration` method carries out such steps. Neural ODEs wrap `ODEProblem`.
        """
        super().__init__(vector_field=standardize_vf_call_signature(vector_field, order, defunc_wrap=True),
                         order=order, sensitivity=sensitivity, solver=solver, atol=atol, rtol=rtol,
                         solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint, rtol_adjoint=rtol_adjoint,
                         seminorm=seminorm, interpolator=interpolator, integral_loss=integral_loss,
                         optimizable_params=optimizable_params)
        self._control, self.controlled, self.t_span = None, False, None  # data-control conditioning
        self.return_t_eval = return_t_eval
        if integral_loss is not None:
            self.vf.integral_loss = integral_loss
        self.vf.sensitivity = sensitivity

    def _prep_integration(self, x: Tensor, t_span: Union[Tensor, None]) -> Tuple[Tensor, Tensor]:
        """Performs generic checks before integration and assigns data-control inputs.

        Args:
            x (Tensor): initial condition; it is returned unchanged.
            t_span (Union[Tensor, None]): integration interval. When None, `self.t_span` is used if set,
                                          otherwise `torch.linspace(0, 1, 2)`.

        Returns:
            Tuple[Tensor, Tensor]: `x` and the resolved `t_span`.
        """
        # assign a basic value to `t_span` for `forward` calls that do not explicitly pass an integration interval
        if t_span is None:
            t_span = torch.linspace(0, 1, 2) if self.t_span is None else self.t_span

        # loss dimension detection routine; for CNF div propagation and integral losses w/ autograd
        excess_dims = 0
        if self.integral_loss is not None and self.sensitivity == 'autograd':
            excess_dims += 1

        # handle aux. operations required for some jacobian trace CNF estimators e.g Hutchinson's
        # as well as data-control set to DataControl module
        for _, module in self.vf.named_modules():
            if hasattr(module, 'trace_estimator'):
                if module.noise_dist is not None:
                    # one noise sample per element along the first dimension of `x`
                    module.noise = module.noise_dist.sample((x.shape[0],))
                excess_dims += 1

            # data-control set routine. Is performed once at the beginning of odeint since the control is fixed to IC
            if hasattr(module, '_control'):
                self.controlled = True
                # the control skips the leading `excess_dims` columns counted so far
                module._control = x[:, excess_dims:].detach()
        return x, t_span

    def forward(self, x: Union[Tensor, Dict], t_span: Tensor = None, save_at: Iterable = (), args=None):
        """Integrates the ODE from initial condition `x`.

        Args:
            x (Union[Tensor, Dict]): initial condition.
            t_span (Tensor, optional): integration interval, resolved by `_prep_integration` when None.
            save_at (Iterable, optional): forwarded to `ODEProblem.forward`. Defaults to ().
            args (optional): forwarded to `ODEProblem.forward`. Defaults to None, which is passed on as
                             a fresh empty dict.

        Returns:
            `(t_eval, sol)` when `self.return_t_eval` is True, otherwise `sol` only.
        """
        # a fresh dict per call: a `{}` default would be a single object shared across every call
        if args is None:
            args = {}
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol = super().forward(x, t_span, save_at, args)
        if self.return_t_eval:
            return t_eval, sol
        return sol

    def trajectory(self, x: torch.Tensor, t_span: Tensor):
        """Solves the ODE with `odeint` over `t_span` and returns the solution only (no `t_eval`)."""
        x, t_span = self._prep_integration(x, t_span)
        _, sol = odeint(self.vf, x, t_span, solver=self.solver, atol=self.atol, rtol=self.rtol)
        return sol

    def __repr__(self):
        npar = sum(p.numel() for p in self.vf.parameters())
        # adjacent f-strings instead of backslash continuations, which would embed the
        # source indentation of each continued line in the output
        return (f"Neural ODE:\n\t- order: {self.order}"
                f"\n\t- solver: {self.solver}\n\t- adjoint solver: {self.solver_adjoint}"
                f"\n\t- tolerances: relative {self.rtol} absolute {self.atol}"
                f"\n\t- adjoint tolerances: relative {self.rtol_adjoint} absolute {self.atol_adjoint}"
                f"\n\t- num_parameters: {npar}"
                f"\n\t- NFE: {self.vf.nfe}")
class NeuralSDE(SDEProblem, pl.LightningModule):
    def __init__(self, drift_func, diffusion_func, noise_type='diagonal', sde_type='ito', order=1,
                 sensitivity='autograd', s_span=torch.linspace(0, 1, 2), solver='srk',
                 atol=1e-4, rtol=1e-4, ds=1e-3, intloss=None):
        """Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class.

        Args:
            drift_func ([type]): drift function
            diffusion_func ([type]): diffusion function
            noise_type (str, optional): Defaults to 'diagonal'.
            sde_type (str, optional): Defaults to 'ito'.
            order (int, optional): Defaults to 1.
            sensitivity (str, optional): One of ['autograd', 'adjoint']. Defaults to 'autograd'.
            s_span ([type], optional): Defaults to torch.linspace(0, 1, 2).
            solver (str, optional): Defaults to 'srk'.
            atol ([type], optional): Defaults to 1e-4.
            rtol ([type], optional): Defaults to 1e-4.
            ds ([type], optional): Step size passed to `torchsde` as `dt`. Defaults to 1e-3.
            intloss ([type], optional): Defaults to None.

        Raises:
            NotImplementedError: higher-order Neural SDEs are not yet implemented, raised by setting `order` to >1.

        Notes:
            The current implementation is rougher around the edges compared to `NeuralODE`, and is not guaranteed
            to have the same features.
        """
        # reject unsupported orders before building `SDEFunc` and the parent problem
        if order != 1:
            raise NotImplementedError("higher-order Neural SDEs are not yet implemented")
        super().__init__(func=SDEFunc(f=drift_func, g=diffusion_func, order=order), order=order,
                         sensitivity=sensitivity, s_span=s_span, solver=solver, atol=atol, rtol=rtol)
        self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type
        self.adaptive = False
        self.intloss = intloss
        self._control, self.controlled = None, False  # data-control
        self.ds = ds

    def _prep_sdeint(self, x: torch.Tensor):
        """Moves `self.s_span` to match `x` and assigns data-control inputs; returns `x` unchanged."""
        self.s_span = self.s_span.to(x)

        # data-control set routine. Is performed once at the beginning of sdeint since the control is fixed to IC
        excess_dims = 0
        for _, module in self.defunc.named_modules():
            if hasattr(module, '_control'):
                self.controlled = True
                module._control = x[:, excess_dims:].detach()
        return x

    def forward(self, x: torch.Tensor):
        """Integrates the SDE over `self.s_span` and returns the last entry of the solution.

        Raises:
            ValueError: if `self.sensitivity` is neither 'autograd' nor 'adjoint'.
        """
        x = self._prep_sdeint(x)
        switcher = {
            'autograd': self._autograd,
            'adjoint': self._adjoint,
        }
        sdeint = switcher.get(self.sensitivity)
        # fail with a descriptive error rather than "'NoneType' object is not callable"
        if sdeint is None:
            raise ValueError(f"Unknown sensitivity {self.sensitivity!r}; expected one of {sorted(switcher)}")
        return sdeint(x)

    def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        """Returns the full `torchsde.sdeint` solution evaluated on the given `s_span`."""
        x = self._prep_sdeint(x)
        sol = torchsde.sdeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol,
                              method=self.solver, dt=self.ds)
        return sol

    def backward_trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        raise NotImplementedError

    def _autograd(self, x):
        """Solves with `torchsde.sdeint` over `self.s_span`; returns the last entry of the solution."""
        self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
        return torchsde.sdeint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                               adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]

    def _adjoint(self, x):
        """Solves with `torchsde.sdeint_adjoint` over `self.s_span`; returns the last entry of the solution."""
        out = torchsde.sdeint_adjoint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                                      adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]
        return out
class MultipleShootingLayer(MultipleShootingProblem, pl.LightningModule):
    def __init__(self, vector_field: Callable, solver: str, sensitivity: str = 'autograd',
                 maxiter: int = 4, fine_steps: int = 4, solver_adjoint: Union[str, nn.Module, None] = None,
                 atol_adjoint: float = 1e-6, rtol_adjoint: float = 1e-6, seminorm: bool = False,
                 integral_loss: Union[Callable, None] = None):
        """Multiple Shooting Layer as defined in https://arxiv.org/abs/2106.03885.

        Uses parallel-in-time ODE solvers to solve an ODE parametrized by neural network `vector_field`.

        Args:
            vector_field ([Callable]): the vector field, called with `vector_field(t, x)` or `vector_field(x)`.
                                       In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): parallel-in-time solver, ['zero', 'direct']
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            maxiter (int): number of iterations of the root finding routine defined to parallel solve the ODE. Defaults to 4.
            fine_steps (int): number of fine-solver steps to perform in each subinterval of the parallel solution. Defaults to 4.
            solver_adjoint (Union[str, nn.Module, None], optional): Standard sequential ODE solver for the adjoint system.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            integral_loss (Union[Callable, None], optional): Currently not implemented
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.

        Notes:
            The number of shooting parameters (first dimension in `B0`) is implicitly defined by passing `t_span` during forward calls.
            For example, a `t_span=torch.linspace(0, 1, 10)` will define 9 intervals and 10 shooting parameters.

            For the moment only a thin wrapper around `MultipleShootingProblem`. At this level will be convenience routines for special
            initializations of shooting parameters `B0`, as well as usual convenience checks for integral losses.
        """
        # arguments are forwarded positionally, in the same order as this signature
        super().__init__(vector_field, solver, sensitivity, maxiter, fine_steps, solver_adjoint,
                         atol_adjoint, rtol_adjoint, seminorm, integral_loss)

© Copyright 2020, Stefano Massaroli & Michael Poli. Revision 7b05e463.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.