# Source code for torchdyn.nn.node_layers
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Augmenter(nn.Module):
    """Augmentation class. Can handle several types of augmentation strategies for Neural DEs.

    The input is concatenated with an augmentation tensor along dimension `augment_idx`.

    :param augment_idx: index of the dimension along which the augmentation is concatenated
    :type augment_idx: int
    :param augment_dims: number of zero-initialized dimensions to add when `augment_func` is None
        (ignored when `augment_func` is given)
    :type augment_dims: int
    :param augment_func: nn.Module applied to the input data; its output is concatenated with the
        input along `augment_idx`. E.g. for a 2-D input of shape (batch, 2) with `augment_idx=1`,
        `augment_func=nn.Linear(2, 5)` adds 5 dimensions, producing shape (batch, 7).
    :type augment_func: nn.Module
    :param order: whether to augment before the data [augmentation, x] or after it
        [x, augmentation] along dimension `augment_idx`. Options: ('first', 'last'); any value
        other than 'first' behaves like 'last'.
    :type order: str
    """

    def __init__(self, augment_idx: int = 1, augment_dims: int = 5, augment_func=None, order='first'):
        super().__init__()
        # Assigning an nn.Module here registers it as a submodule via nn.Module.__setattr__.
        self.augment_dims, self.augment_idx, self.augment_func = augment_dims, augment_idx, augment_func
        self.order = order

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` concatenated with its augmentation along `self.augment_idx`.

        :param x: input tensor
        :return: augmented tensor with the same dtype/device as `x`
        """
        # `is None` rather than truthiness: modules defining __len__ (e.g. an empty
        # nn.Sequential) would otherwise be mistaken for "no augmentation function".
        if self.augment_func is None:
            aug_shape = list(x.shape)
            aug_shape[self.augment_idx] = self.augment_dims
            # Allocate directly with x's dtype/device instead of building on CPU and copying.
            aug = x.new_zeros(aug_shape)
        else:
            aug = self.augment_func(x).to(x)
        # Anything other than 'first' is treated as 'last' (preserves existing behaviour).
        parts = [aug, x] if self.order == 'first' else [x, aug]
        return torch.cat(parts, self.augment_idx)
class DepthCat(nn.Module):
    """Depth variable `t` concatenation module. Allows for easy concatenation of `t` at each call
    of the numerical solver, at specified layers of the DEFunc.

    The current value of `t` is read from the attribute `self.t` (initialized to None), which
    must be assigned before `forward` is called.

    :param idx_cat: index of the data dimension to concatenate `t` to.
    :type idx_cat: int
    """

    def __init__(self, idx_cat=1):
        super().__init__()
        self.idx_cat, self.t = idx_cat, None

    def forward(self, x):
        """Return `x` with one extra slice along `idx_cat` filled with the value of `self.t`.

        :param x: input tensor
        :return: tensor of the same dtype/device as `x`, one larger along `idx_cat`
        :raises TypeError: if `self.t` has not been assigned
        """
        if self.t is None:
            # Same exception type as the former implicit `None * Tensor` failure, clearer message.
            raise TypeError(
                "DepthCat.t is None; assign the current depth value to DepthCat.t before calling forward."
            )
        t_shape = list(x.shape)
        t_shape[self.idx_cat] = 1
        # Multiplying (rather than filling) keeps `t` in the autograd graph when it is a tensor.
        t = self.t * x.new_ones(t_shape)
        # Final cast guards against dtype promotion from a `t` tensor of a different dtype.
        return torch.cat([x, t], self.idx_cat).to(x)
class DataControl(nn.Module):
    """Data-control module. Allows for data-control inputs at arbitrary points of the DEFunc.

    The control tensor is read from the attribute `self._control` (initialized to None), which
    must be assigned before `forward` is called; it is concatenated after `x` along `idx_cat`.

    :param idx_cat: index of the dimension along which the control is concatenated
        (defaults to 1, the dimension previously hard-coded).
    :type idx_cat: int
    """

    def __init__(self, idx_cat=1):
        super().__init__()
        self.idx_cat = idx_cat
        self._control = None

    def forward(self, x):
        """Return `[x, control]` concatenated along `self.idx_cat`, cast to `x`'s dtype/device.

        :param x: input tensor
        :raises TypeError: if `self._control` has not been assigned
        """
        if self._control is None:
            # Same exception type torch.cat raised for a None element, with a clearer message.
            raise TypeError(
                "DataControl._control is None; assign the control tensor before calling forward."
            )
        # Cast the control to x's dtype/device up front so the concatenation does not fail on
        # a device mismatch; the result therefore already matches x.
        return torch.cat([x, self._control.to(x)], self.idx_cat)