# Source code for torchdyn.nn.galerkin

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import math


class GaussianRBF(nn.Module):
    r"""Eigenbasis expansion using Gaussian radial basis functions.

    :math:`\phi(r) = e^{-(\epsilon r)^2}` with :math:`r := \| x - x_0 \|_2`.

    :param deg: degree of the eigenbasis expansion
    :type deg: int
    :param adaptive: whether to adjust `centers` and `eps_scales` during training.
    :type adaptive: bool
    :param eps_scales: scaling in the rbf formula (:math:`\epsilon`)
    :type eps_scales: int
    :param centers: centers of the radial basis functions (one per degree). Same center
        across all degrees. :math:`x_0` in the radius formulas
    :type centers: int
    """

    def __init__(self, deg, adaptive=False, eps_scales=2, centers=0):
        super().__init__()
        self.deg, self.n_eig = deg, 1
        if adaptive:
            # One learnable entry per basis element. GalLayer evaluates the basis on
            # `deg` points (linspace(0, deg, deg)), so `deg` entries are needed to broadcast.
            self.centers = nn.Parameter(centers * torch.ones(deg))
            self.eps_scales = nn.Parameter(eps_scales * torch.ones(deg))
        else:
            # Honour the caller's values (they used to be silently replaced by 0 and 2).
            self.centers = centers
            self.eps_scales = eps_scales

    def forward(self, n_range, s):
        """Evaluate the Gaussian basis.

        :param n_range: per-degree indices the basis is evaluated on
        :param s: (dilated/shifted) depth variable
        :return: single-element list holding the basis values
        """
        n_range_scaled = (n_range - self.centers) / self.eps_scales
        r = torch.norm(s - self.centers, p=2)
        return [torch.exp(-(r * n_range_scaled) ** 2)]
class VanillaRBF(nn.Module):
    r"""Eigenbasis expansion using vanilla (linear) radial basis functions.

    :param deg: degree of the eigenbasis expansion
    :type deg: int
    :param adaptive: whether to adjust `centers` and `eps_scales` during training.
    :type adaptive: bool
    :param eps_scales: scaling in the rbf formula (:math:`\epsilon`)
    :type eps_scales: int
    :param centers: centers of the radial basis functions (one per degree). Same center
        across all degrees. :math:`x_0` in the radius formulas
    :type centers: int
    """

    def __init__(self, deg, adaptive=False, eps_scales=2, centers=0):
        super().__init__()
        self.deg, self.n_eig = deg, 1
        if adaptive:
            # One learnable entry per basis element. GalLayer evaluates the basis on
            # `deg` points (linspace(0, deg, deg)), so `deg` entries are needed to broadcast.
            self.centers = nn.Parameter(centers * torch.ones(deg))
            self.eps_scales = nn.Parameter(eps_scales * torch.ones(deg))
        else:
            # Honour the caller's values (they used to be silently replaced by 0 and 2).
            self.centers = centers
            self.eps_scales = eps_scales

    def forward(self, n_range, s):
        """Evaluate the linear basis ``r * n_range / eps_scales``.

        :param n_range: per-degree indices the basis is evaluated on
        :param s: (dilated/shifted) depth variable
        :return: single-element list holding the basis values
        """
        n_range_scaled = n_range / self.eps_scales
        r = torch.norm(s - self.centers, p=2)
        return [r * n_range_scaled]
class MultiquadRBF(nn.Module):
    r"""Eigenbasis expansion using multiquadric radial basis functions.

    :param deg: degree of the eigenbasis expansion
    :type deg: int
    :param adaptive: whether to adjust `centers` and `eps_scales` during training.
    :type adaptive: bool
    :param eps_scales: scaling in the rbf formula (:math:`\epsilon`)
    :type eps_scales: int
    :param centers: centers of the radial basis functions (one per degree). Same center
        across all degrees. :math:`x_0` in the radius formulas
    :type centers: int
    """

    def __init__(self, deg, adaptive=False, eps_scales=2, centers=0):
        super().__init__()
        self.deg, self.n_eig = deg, 1
        if adaptive:
            # One learnable entry per basis element. GalLayer evaluates the basis on
            # `deg` points (linspace(0, deg, deg)), so `deg` entries are needed to broadcast.
            self.centers = nn.Parameter(centers * torch.ones(deg))
            self.eps_scales = nn.Parameter(eps_scales * torch.ones(deg))
        else:
            # Honour the caller's values (they used to be silently replaced by 0 and 2).
            self.centers = centers
            self.eps_scales = eps_scales

    def forward(self, n_range, s):
        """Evaluate ``1 + sqrt(1 + (r * n_range / eps_scales)**2)``.

        :param n_range: per-degree indices the basis is evaluated on
        :param s: (dilated/shifted) depth variable
        :return: single-element list holding the basis values
        """
        n_range_scaled = n_range / self.eps_scales
        r = torch.norm(s - self.centers, p=2)
        return [1 + torch.sqrt(1 + (r * n_range_scaled) ** 2)]
class Fourier(nn.Module):
    """Eigenbasis expansion using Fourier (cosine/sine) functions.

    :param deg: degree of the eigenbasis expansion
    :type deg: int
    :param adaptive: does nothing (for now)
    :type adaptive: bool
    """

    def __init__(self, deg, adaptive=False):
        super().__init__()
        self.deg = deg
        # Two eigenfunctions per degree: a cosine and a sine component.
        self.n_eig = 2

    def forward(self, n_range, s):
        """Return ``[cos(s * n_range), sin(s * n_range)]``."""
        angle = s * n_range
        return [angle.cos(), angle.sin()]
class Polynomial(nn.Module):
    """Eigenbasis expansion using monomials ``s ** n``.

    :param deg: degree of the eigenbasis expansion
    :type deg: int
    :param adaptive: does nothing (for now)
    :type adaptive: bool
    """

    def __init__(self, deg, adaptive=False):
        super().__init__()
        self.deg = deg
        self.n_eig = 1

    def forward(self, n_range, s):
        """Return a one-element list with ``s`` raised to each entry of `n_range`."""
        return [torch.pow(s, n_range)]
class Chebychev(nn.Module):
    """Eigenbasis expansion using Chebyshev polynomials of the first kind.

    :param deg: degree of the eigenbasis expansion
    :type deg: int
    :param adaptive: does nothing (for now)
    :type adaptive: bool
    """

    def __init__(self, deg, adaptive=False):
        super().__init__()
        self.deg, self.n_eig = deg, 1

    def forward(self, n_range, s):
        """Evaluate T_0(s), T_1(s), ... via the three-term recurrence.

        The number of terms is ``int(n_range[-1])`` (at least one).

        :param n_range: per-degree indices; only its last entry is read
        :param s: single-element tensor (dilated/shifted depth variable)
        :return: one-element list holding a 1-D tensor cast to `n_range`'s dtype/device
        """
        max_order = n_range[-1].int().item()
        # Collapse `s` to a 0-d tensor (requires one element, like the former
        # `s.item()`) but stay inside autograd so dilation/shift receive gradients.
        s = s.reshape(())
        basis = [torch.ones_like(s)]
        # Based on numpy's Cheb code: T_k = 2 s T_{k-1} - T_{k-2}
        if max_order > 0:
            s2 = 2 * s
            basis.append(s)
            for _ in range(2, max_order):
                basis.append(basis[-1] * s2 - basis[-2])
        return [torch.stack(basis).to(n_range)]
class GalLayer(nn.Module):
    """Galerkin layer template. Introduced in https://arxiv.org/abs/2002.08071

    Subclasses define ``self.coeffs`` with shape ``(n_weights, deg, k)``,
    where ``k >= n_eig``. Only the first ``n_eig`` slices are used.

    :param bias: not used by the template itself; kept so subclasses share the signature
    :type bias: bool
    :param expfunc: eigenbasis module exposing ``deg``, ``n_eig`` and ``forward(n_range, s)``.
        Defaults to a fresh ``Fourier(5)`` per layer.
    :type expfunc: nn.Module
    :param dilation: whether to learn a multiplicative scale applied to ``t``
    :type dilation: bool
    :param shift: whether to learn an additive offset applied to ``t``
    :type shift: bool
    """

    def __init__(self, bias=True, expfunc=None, dilation=True, shift=True):
        super().__init__()
        if expfunc is None:
            # Build the default per layer rather than sharing one instance created at import.
            expfunc = Fourier(5)
        self.dilation = nn.Parameter(torch.ones(1)) if dilation else torch.ones(1)
        self.shift = nn.Parameter(torch.zeros(1)) if shift else torch.zeros(1)
        self.expfunc = expfunc
        self.n_eig = self.expfunc.n_eig
        self.deg = self.expfunc.deg

    def reset_parameters(self):
        """Zero-initialise the subclass-defined ``self.coeffs``."""
        torch.nn.init.zeros_(self.coeffs)

    def calculate_weights(self, t):
        """Expand `t` in the chosen eigenbasis and return the flat weight vector.

        For each output weight ``p`` this computes
        ``sum_{k, i} basis[i][k] * coeffs[p, k, i]``. The previous code built one
        dense ``deg x deg`` diagonal matrix per eigenfunction to get the same result.

        :param t: depth variable (scalar tensor)
        :return: 1-D tensor of length ``coeffs.shape[0]``
        """
        device = self.coeffs.device
        n_range = torch.linspace(0, self.deg, self.deg).to(device)
        s = t * self.dilation.to(device) + self.shift.to(device)
        basis = self.expfunc(n_range, s)
        # (deg, n_eig): column i holds eigenfunction i; expand mirrors the broadcasting
        # the old diagonal assignment applied to scalar / length-1 bases.
        basis = torch.stack([b.expand(self.deg) for b in basis[:self.n_eig]], dim=-1)
        coeffs = self.coeffs[:, :, :self.n_eig]
        return torch.einsum('ki,pki->p', basis.to(coeffs), coeffs)
class GalLinear(GalLayer):
    """Linear Galerkin layer for depth-variant neural differential equations.

    Introduced in https://arxiv.org/abs/2002.08071

    :param in_features: input dimensions
    :type in_features: int
    :param out_features: output dimensions
    :type out_features: int
    :param bias: include bias parameter vector in the layer computation
    :type bias: bool
    :param expfunc: eigenfunction expansion module, e.g. ``Fourier``, ``Polynomial``,
        ``Chebychev``, ``VanillaRBF``, ``MultiquadRBF`` or ``GaussianRBF`` instance.
        Defaults to ``Fourier(5)``.
    :type expfunc: nn.Module
    :param dilation: whether to optimize for `dilation` parameter. Allows the GalLayer to
        dilate the eigenfunction period.
    :type dilation: bool
    :param shift: whether to optimize for `shift` parameter. Allows the GalLayer to shift
        the eigenfunction period.
    :type shift: bool
    """

    def __init__(self, in_features, out_features, bias=True, expfunc=None, dilation=True, shift=True):
        if expfunc is None:
            # Fresh default per layer instead of one instance shared at import time.
            expfunc = Fourier(5)
        super().__init__(bias, expfunc, dilation, shift)
        self.in_features, self.out_features = in_features, out_features
        # Placeholder; recomputed from the eigenbasis on every forward pass.
        self.weight = torch.Tensor(out_features, in_features)
        if bias:
            self.bias = torch.Tensor(out_features)
        else:
            self.register_parameter('bias', None)
        # Rows: in*out weight entries followed by out bias entries.
        self.coeffs = torch.nn.Parameter(torch.Tensor((in_features + 1) * out_features, self.deg, self.n_eig))
        self.reset_parameters()

    def forward(self, input):
        """Apply the depth-dependent affine map.

        :param input: 2-D tensor whose last column carries the depth variable `t`
        :return: ``linear(input[:, :-1], W(t), b(t))``
        """
        # For the moment, GalLayers rely on DepthCat to access the `t` variable.
        t = input[-1, -1]
        input = input[:, :-1]
        w = self.calculate_weights(t)
        n_weight = self.in_features * self.out_features
        self.weight = w[:n_weight].reshape(self.out_features, self.in_features)
        # With bias=False, `bias` is a registered-None parameter: assigning a plain
        # tensor to it raises TypeError, and no bias should be applied anyway.
        if self.bias is not None:
            self.bias = w[n_weight:n_weight + self.out_features].reshape(self.out_features)
        return torch.nn.functional.linear(input, self.weight, self.bias)
class GalConv2d(GalLayer):
    """2D convolutional Galerkin layer for depth-variant neural differential equations.

    Introduced in https://arxiv.org/abs/2002.08071

    :param in_channels: number of channels in the input image
    :type in_channels: int
    :param out_channels: number of channels produced by the convolution
    :type out_channels: int
    :param kernel_size: size of the convolving kernel
    :type kernel_size: int
    :param stride: stride of the convolution. Default: 1
    :type stride: int
    :param padding: zero-padding added to both sides of the input. Default: 0
    :type padding: int
    :param bias: include bias parameter vector in the layer computation
    :type bias: bool
    :param expfunc: eigenfunction expansion module, e.g. ``Fourier``, ``Polynomial``,
        ``Chebychev``, ``VanillaRBF``, ``MultiquadRBF`` or ``GaussianRBF`` instance.
        Defaults to ``Fourier(5)``.
    :type expfunc: nn.Module
    :param dilation: whether to optimize for `dilation` parameter. Allows the GalLayer to
        dilate the eigenfunction period.
    :type dilation: bool
    :param shift: whether to optimize for `shift` parameter. Allows the GalLayer to shift
        the eigenfunction period.
    :type shift: bool
    """

    __constants__ = ['bias', 'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'deg']

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True,
                 expfunc=None, dilation=True, shift=True):
        if expfunc is None:
            # Fresh default per layer instead of one instance shared at import time.
            expfunc = Fourier(5)
        super().__init__(bias, expfunc, dilation, shift)
        self.ic, self.oc, self.ks = in_channels, out_channels, kernel_size
        self.pad, self.stride = padding, stride
        # Placeholder; recomputed from the eigenbasis on every forward pass.
        self.weight = torch.Tensor(out_channels, in_channels, kernel_size, kernel_size)
        if bias:
            self.bias = torch.Tensor(out_channels)
        else:
            self.register_parameter('bias', None)
        n_rows = out_channels * in_channels * kernel_size ** 2 + out_channels
        # One coefficient slice per eigenfunction (previously hard-coded to 2, which
        # broke expansions with n_eig > 2 and wasted a slice when n_eig == 1).
        self.coeffs = torch.nn.Parameter(torch.Tensor(n_rows, self.deg, self.n_eig))
        self.reset_parameters()

    def forward(self, input):
        """Apply the depth-dependent convolution.

        :param input: 4-D tensor whose last channel carries the depth variable `t`
            (e.g. concatenated by DepthCat)
        :return: ``conv2d(input[:, :-1], W(t), b(t))``
        """
        t = input[-1, -1, 0, 0]
        input = input[:, :-1]
        w = self.calculate_weights(t)
        n = self.oc * self.ic * self.ks * self.ks
        self.weight = w[:n].reshape(self.oc, self.ic, self.ks, self.ks)
        # With bias=False, `bias` is a registered-None parameter: assigning a plain
        # tensor to it raises TypeError, and no bias should be applied anyway.
        if self.bias is not None:
            self.bias = w[n:n + self.oc].reshape(self.oc)
        return torch.nn.functional.conv2d(input, self.weight, self.bias, stride=self.stride, padding=self.pad)