Shortcuts

Source code for torchdyn.models.energy

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch
from torch import Tensor
import torch.nn as nn
from torch.autograd import grad
from torch.autograd.functional import hessian, jacobian


class ConservativeLinearSNF(nn.Module):
    def __init__(self, energy, J):
        """Stable Neural Flows: https://arxiv.org/abs/2003.08063

        A generalization of Hamiltonian Neural Networks and other energy-based
        parametrizations of Neural ODEs. Conservative version with energy
        preservation. Input assumed to be of dimensions `batch, dim`.

        The vector field is ``J_s(x) @ dE/dx``, where ``J_s = (M - M^T) / 2`` is
        the skew-symmetric part of the matrix ``M`` produced by ``J``. Because
        ``J_s`` is skew-symmetric, ``dE/dx . (J_s dE/dx) = 0``, i.e. the energy
        is conserved along the flow.

        Args:
            energy: function parametrizing the energy. Its output on a
                `(batch, dim)` input is summed to obtain the total energy, so
                per-sample energies of shape `(batch,)` or `(batch, 1)` work.
            J: network parametrizing the interconnection matrix. Its output on a
                `(batch, dim)` input must contain `batch * dim * dim` elements;
                it is reshaped to `(batch, dim, dim)` and skew-symmetrized.
        """
        super().__init__()
        self.energy = energy
        self.J = J

    def forward(self, x: Tensor):
        """Evaluate the conservative vector field ``J_s(x) @ dE/dx``.

        Args:
            x: state tensor of shape `(batch, dim)`.

        Returns:
            Tensor of shape `(batch, dim)`.
        """
        # Gradients of the energy are needed even if the caller runs under no_grad.
        with torch.set_grad_enabled(True):
            # Kept for backward compatibility: callers may read `self.n`.
            self.n = x.shape[1] // 2
            x = x.requires_grad_(True)
            dHdx = torch.autograd.grad(self.energy(x).sum(), x, create_graph=True)[0]
            # Batched matrix-vector product: out[b] = J_s(x)[b] @ dHdx[b].
            dHdx = torch.einsum('ijk,ik->ij', self._generate_skew(x), dHdx)
        return dHdx

    def _generate_skew(self, x):
        """Return the skew-symmetric part of `J(x)` as a `(batch, dim, dim)` tensor."""
        dim = x.shape[1]
        M = self.J(x).reshape(-1, dim, dim)
        # Tensor.transpose swaps exactly two dims; swap the two matrix axes.
        return (M - M.transpose(1, 2)) / 2
class GNF(nn.Module):
    def __init__(self, energy: nn.Module):
        """Gradient Neural Flows version of SNFs: https://arxiv.org/abs/2003.08063

        The vector field is the negative gradient of a learned energy,
        ``-dE/dx``, where the energy output is summed over the batch before
        differentiation.

        Args:
            energy (nn.Module): function parametrizing the energy.
        """
        super().__init__()
        self.energy = energy

    def forward(self, x):
        """Return ``-dE/dx`` evaluated at `x` (same shape as `x`)."""
        # Force grad mode so the field can be evaluated inside no_grad contexts.
        with torch.set_grad_enabled(True):
            x = x.requires_grad_(True)
            total_energy = self.energy(x).sum()
            (dEdx,) = torch.autograd.grad(
                total_energy, x, allow_unused=False, create_graph=True
            )
            return -dEdx
class HNN(nn.Module):
    def __init__(self, net: nn.Module):
        """Hamiltonian Neural ODE

        Args:
            net (nn.Module): function parametrizing the Hamiltonian; the vector
                field is built from the gradient of its (batch-summed) output.
        """
        super().__init__()
        self.net = net

    def forward(self, x):
        """Return the Hamiltonian vector field at `x`.

        `x` is split at the midpoint of dim 1 into (q, p); the output is
        ``[dH/dp, -dH/dq]`` concatenated along dim 1.
        """
        with torch.set_grad_enabled(True):
            half = x.shape[1] // 2
            x = x.requires_grad_(True)
            dH = grad(self.net(x).sum(), x, create_graph=True)[0]
            dH_dq = dH[:, :half]
            dH_dp = dH[:, half:]
            # Hamilton's equations: dq/dt = dH/dp, dp/dt = -dH/dq.
            field = torch.cat((dH_dp, -dH_dq), dim=1)
            return field.to(x)
class LNN(nn.Module):
    def __init__(self, net):
        """Lagrangian Neural Network.

        Args:
            net (nn.Module): function parametrizing the Lagrangian; it is called
                on a single 1-D state vector (q, qd) and its output is summed to
                a scalar.

        Notes:
            LNNs are currently quite slow. Improvements will be made whenever
            `functorch` is either merged upstream or included as a dependency.
        """
        super().__init__()
        self.net = net

    def forward(self, x):
        """Return ``[qd, qdd]`` for a batch of states ``x = [q, qd]``.

        The Lagrangian's Jacobian and Hessian are computed per sample, then the
        Euler-Lagrange equations are solved for the accelerations `qdd`.
        """
        # `self.n` is read by `_qdd` to split the state vector.
        self.n = n = x.shape[1] // 2
        x = x.requires_grad_(True)
        samples = list(x.unbind(0))
        jacs = [jacobian(self._lagrangian, s, create_graph=True) for s in samples]
        hessians = [hessian(self._lagrangian, s, create_graph=True) for s in samples]
        accelerations = [self._qdd(triple) for triple in zip(jacs, hessians, samples)]
        velocities = x[:, n:]
        return torch.cat([velocities, torch.stack(accelerations)], 1)

    def _lagrangian(self, qqd):
        # Scalar Lagrangian of a single (q, qd) state vector.
        return self.net(qqd).sum()

    def _qdd(self, inp):
        """Solve the Euler-Lagrange equations for one sample.

        qdd = pinv(d2L/dqd2) @ (dL/dq - d2L/(dqd dq) @ qd)
        """
        jac, hess, qqd = inp
        n = self.n
        dL_dq = jac[:n]
        hess_qd_qd = hess[n:, n:]
        hess_qd_q = hess[n:, :n]
        qd = qqd[n:]
        return hess_qd_qd.pinverse() @ (dL_dq - hess_qd_q @ qd)

© Copyright 2020, Stefano Massaroli & Michael Poli. Revision 7b05e463.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.