# Source code for torchdyn.core.defunc
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor, cat
class DEFuncBase(nn.Module):
    def __init__(self, vector_field: Callable, has_time_arg: bool = True):
        """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields.

        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function.
            has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its
                `__call__` or `forward` method. When True it is called as `vector_field(t, x, args=args)`;
                when False it is called as `vector_field(x)`. Defaults to True.
        """
        super().__init__()
        # `nfe` counts calls to `forward` (number of function evaluations).
        self.nfe, self.vf, self.has_time_arg = 0., vector_field, has_time_arg

    def forward(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> Tensor:
        """Evaluate the wrapped vector field with a uniform `(t, x, args)` signature.

        Args:
            t (Tensor): current time; dropped when `has_time_arg` is False.
            x (Tensor): current state.
            args (Dict, optional): extra arguments forwarded as `args=` when `has_time_arg` is True,
                dropped otherwise. Defaults to a new empty dict on every call.

        Returns:
            Tensor: whatever the wrapped vector field returns.
        """
        self.nfe += 1
        if not self.has_time_arg:
            return self.vf(x)
        # A fresh dict per call: a shared `{}` default would carry any mutation made by
        # the vector field over into every later call.
        return self.vf(t, x, args={} if args is None else args)
class DEFunc(nn.Module):
    def __init__(self, vector_field: Callable, order: int = 1):
        """Special vector field wrapper for Neural ODEs.

        Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and
        forward-propagated integral losses.

        Args:
            vector_field (Callable): module defining the dynamics / vector field / `dxdt` / forcing
                function. `forward` traverses its submodules, so it must be an `nn.Module`.
            order (int, optional): order of the differential equation. Defaults to 1.

        Notes:
            Currently handles the following:
            (1) assigns the time tensor to each submodule that has a `t` attribute (e.g. `GalLinear`).
            (2) when `integral_loss` is set and `sensitivity == 'autograd'`, the loss is propagated in
                the first column of `x`: the state is split into `x[:, 0]` and `x[:, 1:]`, and only
                `x[:, 1:]` is fed to the vector field.
            (3) for higher-order dynamics (`order > 1`), recursively builds the derivative of every
                order block of the state (see `higher_order_forward`).
        """
        super().__init__()
        self.vf, self.nfe = vector_field, 0.
        # Nothing in this class assigns `integral_loss` / `sensitivity`; presumably the owning
        # model sets them after construction. Both must be set for the integral-loss branch to run.
        self.order, self.integral_loss, self.sensitivity = order, None, None

    def forward(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> Tensor:
        """Evaluate the (possibly higher-order, possibly loss-augmented) dynamics at `(t, x)`.

        Args:
            t (Tensor): current time / depth; also assigned to every submodule of the vector field
                that has a `t` attribute.
            x (Tensor): current state. In integral-loss mode, column 0 carries the loss.
            args (Dict, optional): forwarded as `args=` to the vector field on the regular
                first-order path. Defaults to a new empty dict on every call.

        Returns:
            Tensor: time derivative of the state. In integral-loss mode, the output of
            `integral_loss` (as a column if it was 1-D) is prepended.
        """
        # Fresh dict per call: a shared `{}` default would leak mutations across calls.
        args = {} if args is None else args
        self.nfe += 1
        # set `t` depth-variable to DepthCat modules
        for module in self.vf.modules():
            if hasattr(module, 't'):
                module.t = t

        # autograd training with the integral loss propagated in x[:, 0]
        if self.integral_loss is not None and self.sensitivity == 'autograd':
            x_dyn = x[:, 1:]
            dlds = self.integral_loss(t, x_dyn)
            if dlds.ndim == 1:
                dlds = dlds[:, None]
            if self.order > 1:
                x_dyn = self.higher_order_forward(t, x_dyn, args)
            else:
                # NOTE(review): `args` is not forwarded on this path (kept as in the original
                # call signature); confirm whether the vector field should receive it.
                x_dyn = self.vf(t, x_dyn)
            return cat([dlds, x_dyn], 1).to(x_dyn)

        # regular forward
        if self.order > 1:
            return self.higher_order_forward(t, x)
        return self.vf(t, x, args=args)

    def higher_order_forward(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> Tensor:
        """Compute the derivative of a state stacked as `order` equal-width column blocks.

        Column block `i` of the output is column block `i + 1` of `x`, for `i < order - 1`.
        The last block is the vector field evaluated on the full state `x`.

        Args:
            t (Tensor): current time.
            x (Tensor): state of shape `(batch, order * size_order)`. This is presumed to be
                [x, x', ..., x^(order-1)]; if `x.size(1)` is not a multiple of `order`, the
                trailing columns are ignored (TODO confirm callers never do this).
            args (Dict, optional): currently unused. The vector field is called as
                `vf(t, x)` (NOTE(review): kept as in the original).

        Returns:
            Tensor: concatenation of the shifted blocks and the vector-field output, cast to `x`.
        """
        size_order = x.size(1) // self.order
        # The derivative of each lower-order block is simply the next block of the state.
        x_new = [x[:, size_order * i: size_order * (i + 1)] for i in range(1, self.order)]
        # The highest-order derivative comes from the vector field, evaluated on the full state.
        x_new.append(self.vf(t, x))
        return cat(x_new, dim=1).to(x)
class SDEFunc(nn.Module):
    def __init__(self, f: Callable, g: Callable, order: int = 1):
        """Special vector field wrapper for Neural SDEs.

        Args:
            f (Callable): module defining the drift; called as `f(x, args)`. Its submodules are
                traversed to assign `t`, so it must be an `nn.Module`.
            g (Callable): module defining the diffusion term; called as `g(x, args)`. Same
                requirement as `f`.
            order (int, optional): order of the differential equation. Defaults to 1.
                Stored only; nothing in this class reads it.
        """
        super().__init__()
        # NOTE(review): `intloss` / `sensitivity` are stored but never read in this class;
        # `intloss` is named differently from `DEFunc.integral_loss`.
        self.order, self.intloss, self.sensitivity = order, None, None
        self.f_func, self.g_func = f, g
        # Counts drift evaluations only: `f` increments it, `g` does not.
        self.nfe = 0

    @staticmethod
    def _set_time(module: nn.Module, t: Tensor) -> None:
        """Assign `t` to `module` and every submodule of it that has a `t` attribute."""
        for submodule in module.modules():
            if hasattr(submodule, 't'):
                submodule.t = t

    def forward(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> None:
        """Placeholder that does nothing and returns None.

        The dynamics are exposed through `f` (drift) and `g` (diffusion); presumably the SDE
        solver calls those directly (TODO confirm).
        """
        return None

    def f(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> Tensor:
        """Evaluate the drift at `(t, x)`; increments `nfe`.

        Args:
            t (Tensor): current time, assigned to drift submodules that have a `t` attribute.
            x (Tensor): current state.
            args (Dict, optional): passed positionally to the drift. Defaults to a new empty dict.
        """
        self.nfe += 1
        self._set_time(self.f_func, t)
        # Fresh dict per call so mutations by the drift cannot leak into later calls.
        return self.f_func(x, {} if args is None else args)

    def g(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> Tensor:
        """Evaluate the diffusion at `(t, x)`; does not touch `nfe`.

        Args:
            t (Tensor): current time, assigned to diffusion submodules that have a `t` attribute.
            x (Tensor): current state.
            args (Dict, optional): passed positionally to the diffusion. Defaults to a new empty dict.
        """
        self._set_time(self.g_func, t)
        return self.g_func(x, {} if args is None else args)