Stable Neural ODEs (Stable Neural Flows)¶
First introduced in Massaroli, Poli et al., 2020, Stable Neural Flows are a stable variant of Neural ODEs. Their simplest realization has the general neural ODE form

\[\dot{z} = -\nabla_z \varepsilon(x, z, \theta)\]

where \(\varepsilon(x, z, \theta)\) is a neural network.
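In this form \(\varepsilon\) acts as an energy: along any trajectory of the flow (for fixed input \(x\)),

\[\dot{\varepsilon} = \nabla_z \varepsilon^\top \dot{z} = -\|\nabla_z \varepsilon\|^2 \leq 0,\]

so the energy never increases, which is the sense in which the flow is stable.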
They can be used as general-purpose modules (e.g. for classification or continuous normalizing flows) or, thanks to their unique structure, to learn dynamical systems in a fashion similar to Lagrangian/Hamiltonian-inspired models.
[1]:
import sys
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchdyn.models import *
from torchdyn import *
from torchdyn.datasets import *
[2]:
# Vanilla version of stable neural flows
class Stable(nn.Module):
    """Stable Neural Flow"""
    def __init__(self, net, depthvar=False, controlled=False):
        super().__init__()
        self.net, self.depthvar, self.controlled = net, depthvar, controlled

    def forward(self, x):
        # gradients are needed even at inference time to evaluate the field
        with torch.set_grad_enabled(True):
            x = x.requires_grad_(True)
            eps = self.net(x).sum()
            # the vector field is the negative gradient of the scalar energy eps
            out = -torch.autograd.grad(eps, x, allow_unused=False, create_graph=True)[0]
        # drop the depth / control variables: they are inputs to eps, not states
        out = out[:, :-1] if self.depthvar else out
        out = out[:, :-2] if self.controlled else out
        return out
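As a minimal usage sketch (the toy energy network and batch below are illustrative, not part of the reference implementation), Stable can be wrapped in a NeuralDE like any other vector field:

# Minimal usage sketch: integrate the gradient flow of a toy energy network
energy_net = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 1))
flow = NeuralDE(Stable(energy_net), solver='dopri5')
z0 = torch.randn(128, 2)                               # batch of initial conditions
traj = flow.trajectory(z0, torch.linspace(0, 1, 20))   # (20, 128, 2): time x batch x state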
Learning Dynamical Systems¶
Stable neural flow variants in (autonomous) port–Hamiltonian form,

\[\dot{z} = \mathbf{F}(z)\nabla_z \varepsilon(z),\]

generalize the Hamiltonian paradigm to modeling multi-physics systems. They obey the power balance equation

\[\dot{\varepsilon} = \nabla_z \varepsilon^\top \mathbf{F}(z)\, \nabla_z \varepsilon.\]
Therefore, if one wants to learn e.g. some conservative process (of any nature), it is sufficient to introduce the inductive bias that \(\mathbf{F}\) be a skew-symmetric matrix, so that \(\dot{\varepsilon} = 0\).
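This works because the power balance is a quadratic form, and any quadratic form in a skew-symmetric matrix vanishes identically:

\[\nabla_z \varepsilon^\top \mathbf{F}\,\nabla_z \varepsilon = \left(\nabla_z \varepsilon^\top \mathbf{F}\,\nabla_z \varepsilon\right)^\top = \nabla_z \varepsilon^\top \mathbf{F}^\top \nabla_z \varepsilon = -\nabla_z \varepsilon^\top \mathbf{F}\,\nabla_z \varepsilon \implies \dot{\varepsilon} = 0.\]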
Below, we showcase the capabilities of stable neural flows (in port–Hamiltonian form) on such tasks.
[3]:
# Conservative variant of stable neural flows
class ConservativeStable(nn.Module):
    """Conservative Stable Neural Flow"""
    def __init__(self, net, depthvar=False, controlled=False):
        super().__init__()
        self.net, self.depthvar, self.controlled = net, depthvar, controlled
        self.M = nn.Parameter(torch.randn(2, 2))

    def Skew(self):
        # impose the system matrix to be skew-symmetric: F = .5*(M - M^T)
        return .5 * (self.M - self.M.T)

    def forward(self, x):
        with torch.set_grad_enabled(True):
            x = x.requires_grad_(True)
            eps = self.net(x).sum()
            out = -torch.autograd.grad(eps, x, allow_unused=False, create_graph=True)[0]
        out = out[:, :-1] if self.depthvar else out
        out = out[:, :-2] if self.controlled else out
        # rotate the gradient field with the skew-symmetric matrix: eps is conserved
        return out @ self.Skew()
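As a quick sanity check (the toy network below is illustrative), the directional derivative of \(\varepsilon\) along the field returned by ConservativeStable should vanish up to floating-point precision, even before training:

# Sanity check (illustrative): eps is conserved along the field,
# since grad(eps)^T F grad(eps) = 0 for skew-symmetric F
f_test = ConservativeStable(nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1)))
x_test = torch.randn(8, 2, requires_grad=True)
grad_eps = torch.autograd.grad(f_test.net(x_test).sum(), x_test)[0]
print((grad_eps * f_test(x_test)).sum(1))  # ~0 for every sample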
We aim at using a stable neural ODE to learn the following conservative nonlinear dynamical system:

\[\begin{aligned} \dot{q} &= p \\ \dot{p} &= \pi\cos\left(\pi q - \tfrac{\pi}{2}\right) - \pi q \end{aligned}\]
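This system is indeed conservative: since \(\cos(\pi q - \pi/2) = \sin(\pi q)\), it is generated by the Hamiltonian

\[H(q, p) = \frac{1}{2}p^2 + \frac{\pi}{2}q^2 + \cos(\pi q),\]

with \(\dot{q} = \partial H / \partial p\) and \(\dot{p} = -\partial H / \partial q\), so \(H\) is constant along trajectories.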
[4]:
# We use this class to simulate the above nonlinear system through torchdyn
class odefunc(nn.Module):
    def __init__(self, sys):
        super().__init__()
        self.sys = sys

    def forward(self, x):
        return self.sys(x)

# nonlinear conservative vector field
def sys(x):
    dxdt = x[:, 1]
    dydt = np.pi*torch.cos(np.pi*x[:, 0] - np.pi/2) - np.pi*x[:, 0]  # - .5*np.pi*x[:,1]
    return torch.cat([dxdt[:, None], dydt[:, None]], 1)
[5]:
# define the system model just like a neural ODE
system = NeuralDE(odefunc(sys))
x0, t_span = torch.randn(1000, 2), torch.linspace(0, 2, 100)
# simulate the system
traj = system.trajectory(x0, t_span)
# plot the trajectories
for i in range(len(x0)):
    plt.plot(traj[:, i, 0], traj[:, i, 1], color='blue', alpha=.1)

Train the conservative stable neural flow
[6]:
import torch.utils.data as data
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Data
vf = odefunc(sys)
X = 4*torch.rand(2048,2).to(device)
y = vf(X)
train = data.TensorDataset(X, y)
trainloader = data.DataLoader(train, batch_size=len(X), shuffle=False)
[7]:
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model.defunc(0, x)

    def loss(self, y, y_hat):
        return ((y - y_hat)**2).sum(1).mean()

    def training_step(self, batch, batch_idx):
        # resample fresh data at each step instead of using the (static) batch
        x = torch.randn(2048, 2).to(device)
        y = vf(x)
        y_hat = self.model.defunc(0, x)
        loss = self.loss(y_hat, y)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

    def train_dataloader(self):
        return trainloader
[8]:
# vector field parametrized by a NN
h_dim = 128
f = ConservativeStable(nn.Sequential(
        nn.Linear(2, h_dim),
        nn.Tanh(),
        nn.Linear(h_dim, h_dim),
        nn.Tanh(),
        nn.Linear(h_dim, h_dim),
        nn.Tanh(),
        nn.Linear(h_dim, 1)))

# neural ODE
model = NeuralDE(f,
                 order=1,
                 solver='dopri5',
                 sensitivity='adjoint').to(device)
seq = nn.Sequential(model).to(device)
[ ]:
learn = Learner(model)
trainer = pl.Trainer(min_epochs=500, max_epochs=1000)
trainer.fit(learn)
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
| Name | Type | Params
-----------------------------------
0 | model | NeuralDE | 33 K
[10]:
# Sample random initial conditions
X_t = torch.randn(1000, 2).to(device)

# Evaluate the model's trajectories
s_span = torch.linspace(0, 5, 100)
traj = model.trajectory(X_t, s_span).detach().cpu()
sys_traj = system.trajectory(X_t, s_span).detach().cpu()

# Plot the trajectories with random ICs
fig = plt.figure(figsize=(10, 3))
ax = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
for i in range(len(X_t)):
    ax.plot(traj[:, i, 0], traj[:, i, 1], color='blue', alpha=0.1)
ax.set_xlim([-3, 3]); ax.set_ylim([-3, 3])
ax.set_xlabel(r"$q$"); ax.set_ylabel(r"$p$")
ax.set_title("Reconstructed")
for i in range(len(X_t)):
    ax2.plot(sys_traj[:, i, 0], sys_traj[:, i, 1], color='blue', alpha=0.1)
ax2.set_xlim([-3, 3]); ax2.set_ylim([-3, 3])
ax2.set_xlabel(r"$q$"); ax2.set_ylabel(r"$p$")
ax2.set_title("Nominal")
[10]:
Text(0.5, 1.0, 'Nominal')

[11]:
# Compare the learned vector field to the nominal one
fig = plt.figure(figsize=(10, 3))
ax0 = fig.add_subplot(121)
ax1 = fig.add_subplot(122)

n_grid = 25
q = torch.linspace(-3, 3, n_grid)
Q, P = torch.meshgrid(q, q)
H, U, V = torch.zeros(Q.shape), torch.zeros(Q.shape), torch.zeros(Q.shape)
Ur, Vr = torch.zeros(Q.shape), torch.zeros(Q.shape)
for i in range(n_grid):
    for j in range(n_grid):
        x = torch.cat([Q[i, j].reshape(1, 1), P[i, j].reshape(1, 1)], 1).to(device)
        H[i, j] = model.defunc.m.net(x).detach().cpu()   # learned energy eps
        O = model.defunc(0, x).detach().cpu()            # learned vector field
        U[i, j], V[i, j] = O[0, 0], O[0, 1]
        Ur[i, j], Vr[i, j] = vf(x)[0, 0].detach().cpu(), vf(x)[0, 1].detach().cpu()

ax0.contourf(Q, P, H, 100, cmap='inferno')
ax0.streamplot(Q.T.numpy(), P.T.numpy(), U.T.numpy(), V.T.numpy(), color='white')
ax1.streamplot(Q.T.numpy(), P.T.numpy(), Ur.T.numpy(), Vr.T.numpy(), color='black')
ax0.set_xlim([Q.min(), Q.max()]); ax1.set_xlim([Q.min(), Q.max()])
ax0.set_ylim([P.min(), P.max()]); ax1.set_ylim([P.min(), P.max()])
ax0.set_xticks([]); ax1.set_xticks([])
ax0.set_yticks([]); ax1.set_yticks([])
ax0.set_title("Learned Energy & Vector Field"); ax1.set_title("Nominal Vector Field")
[11]:
Text(0.5, 1.0, 'Nominal Vector Field')

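The visual comparison can be complemented with a quantitative one; a short sketch reusing the grid tensors U, V, Ur, Vr from the cell above:

# Quantify the field mismatch on the evaluation grid (reuses U, V, Ur, Vr)
err = ((U - Ur)**2 + (V - Vr)**2).mean().sqrt()
scale = (Ur**2 + Vr**2).mean().sqrt()
print(f"RMS field error: {err.item():.3f} (relative: {(err/scale).item():.1%})")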