Higher-Order Neural ODEs

Following "Dissecting Neural ODEs", in this tutorial we showcase how to handle higher-order Neural ODEs in torchdyn.

A higher-order neural ODE can be defined simply as the initial value problem

\[\begin{split}\begin{aligned} &\frac{d^n \mathbf{z}}{ds^n} = f_\theta\left(s, \mathbf{x}, \mathbf{z}, \frac{d \mathbf{z}}{ds}, \frac{d^2 \mathbf{z}}{ds^2}, \dots, \frac{d^{n-1} \mathbf{z}}{ds^{n-1}}\right)\\ &\mathbf{z}(0), \frac{d \mathbf{z}}{ds}(0), \frac{d^2 \mathbf{z}}{ds^2}(0), \dots, \frac{d^{n-1} \mathbf{z}}{ds^{n-1}}(0) = h_{\mathbf{x}}(\mathbf{x})\\ &{\bf y} = h_{\bf y}\left(\mathbf{z}(S), \frac{d \mathbf{z}}{ds}(S), \frac{d^2 \mathbf{z}}{ds^2}(S), \dots, \frac{d^{n-1} \mathbf{z}}{ds^{n-1}}(S)\right) \end{aligned}\end{split}\]
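As in the classical theory, this n-th order problem is solved in practice by recasting it as an equivalent first-order system on the stacked state of \(\mathbf{z}\) and its first n-1 derivatives:

\[\begin{split}\frac{d}{ds}\begin{bmatrix}\mathbf{z}\\ \frac{d \mathbf{z}}{ds}\\ \vdots\\ \frac{d^{n-1} \mathbf{z}}{ds^{n-1}}\end{bmatrix} = \begin{bmatrix}\frac{d \mathbf{z}}{ds}\\ \frac{d^2 \mathbf{z}}{ds^2}\\ \vdots\\ f_\theta\left(s, \mathbf{x}, \mathbf{z}, \frac{d \mathbf{z}}{ds}, \dots, \frac{d^{n-1} \mathbf{z}}{ds^{n-1}}\right)\end{bmatrix}\end{split}\]

Each derivative block is advanced by the next one, while \(f_\theta\) supplies the highest derivative.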
[1]:
import sys ; sys.path.append('../')
from torchdyn.models import *
from torchdyn import *
from torchdyn.datasets import *

Data: we again use the spirals dataset (with some added noise), simply because all of the models considered will be able to solve this binary classification problem.

[2]:
d = ToyDataset()
X, yn = d.generate(n_samples=2048, dataset_type='spirals', noise=.4)
[3]:
import matplotlib.pyplot as plt

colors = ['orange', 'blue']
fig = plt.figure(figsize=(3,3))
ax = fig.add_subplot(111)
ax.scatter(X[:,0], X[:,1], color=[colors[int(yi)] for yi in yn])
../_images/tutorials_06_higher_order_4_0.png
[4]:
import torch
import torch.utils.data as data
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

X_train = torch.Tensor(X).to(device)
y_train = yn.long().to(device)
train = data.TensorDataset(X_train, y_train)
trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

Learner

[5]:
import torch.nn as nn
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model
        self.c = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        return trainloader

Train a simple Second-Order Model

Second-order models have also been considered in the literature [CITE].

We can train a second-order model for classification out of the box as follows:

[6]:
# vector field parametrized by a NN: maps the 4-dim state
# (z and dz/ds) to the 2-dim second derivative d^2 z / ds^2
f = nn.Sequential(
        nn.Linear(4, 64),
        nn.Tanh(),
        nn.Linear(64, 2))

# Neural ODE
model = NeuralDE(f,
                 order=2,
                 solver='dopri5',
                 sensitivity='adjoint').to(device)

seq = nn.Sequential(Augmenter(1, 2, order='last'), model, nn.Linear(4, 2)).to(device)
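Here Augmenter(1, 2, order='last') plays the role of h_x: with its default (zero) augmentation it concatenates 2 zero-initialized dimensions to each 2-dim input, providing the initial condition for the extra derivative state. A quick shape check (a minimal sketch, assuming the default zero augmentation):

# the augmented state has 2 data dims + 2 zero-initialized derivative dims
x_check = torch.randn(5, 2).to(device)   # batch of five 2-dim points
print(seq[0](x_check).shape)             # torch.Size([5, 4])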
[7]:
# train the neural ODE
learn = Learner(seq)
trainer = pl.Trainer(min_epochs=600, max_epochs=1200)
trainer.fit(learn)
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 460

[7]:
1
[9]:
# Evaluate the trajectories of every 10th data point over the depth domain
s_span = torch.linspace(0, 1, 100)
X_d = seq[0](X_train[::10, :])
trajectory = model.trajectory(X_d, s_span).detach().cpu()
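The returned trajectory is laid out as (depth, batch, state): len(s_span) integration points for each sampled input, with the 4 state dimensions split into two 2-dim blocks (the state and its first derivative). A quick check:

print(trajectory.shape)  # (len(s_span), n_points, 4)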
[10]:
# Trajectories in the depth domain
plot_2D_depth_trajectory(s_span, trajectory[:,:,:2], yn[::10], len(X)//10)
plot_2D_depth_trajectory(s_span, trajectory[:,:,2:4], yn[::10], len(X)//10)
../_images/tutorials_06_higher_order_12_0.png
../_images/tutorials_06_higher_order_12_1.png
[11]:
# Trajectories in the state-space
plot_2D_state_space(trajectory[:,:,-2:], yn[::10], len(X)//10)
../_images/tutorials_06_higher_order_13_0.png

Let us now train a higher-order model, e.g. a 10th-order one, on the same task.

Showcase of Higher-Order Models (10th order Neural ODE)

Here, we introduce an integral regularization term, leveraging the generalized adjoint for Neural ODEs.
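Concretely, the generalized adjoint of "Dissecting Neural ODEs" allows gradients to flow through objectives of the form

\[\ell = L\big(\mathbf{z}(S)\big) + \int_0^S g\big(s, \mathbf{z}(s)\big)\, ds\]

where L is the terminal cross-entropy loss minimized by the Learner, and the integrand g is the small L1 penalty on the vector field implemented by the module below.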

[12]:
class IntegralWReg(nn.Module):
    """Integrand g(s, z) of the integral regularization term:
    a small L1 penalty on the magnitude of the vector field."""
    def __init__(self, f):
        super().__init__()
        self.f = f
    def forward(self, s, x):
        # per-sample penalty, summed over state dimensions
        return 1e-6 * torch.abs(self.f(x)).sum(1)

With torchdyn, going beyond orders 1 and 2 is easy: simply pass the desired order argument when instantiating NeuralDE, as sketched below.
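For intuition, here is a pure-PyTorch sketch of the first-order system induced by order=10; it follows the mathematical packing [z, dz/ds, ..., d^9 z/ds^9] and is not necessarily torchdyn's internal layout:

def higher_order_vf(f, z, dim=2):
    # z packs the state and its first 9 derivatives, each `dim`-dimensional;
    # the derivative of each block is the following block, and f predicts
    # the 10th derivative from the full 20-dim state
    lower = z[..., dim:]            # shift: blocks 1..9 become d/ds of blocks 0..8
    top = f(z)                      # f: 20 -> 2, the highest derivative
    return torch.cat([lower, top], dim=-1)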

[13]:
# vector field parametrized by a NN: maps the 20-dim state
# (z and its first 9 derivatives) to the 2-dim 10th derivative
f = nn.Sequential(
        nn.Linear(20, 128),
        nn.Tanh(),
        nn.Linear(128, 2))

# Neural ODE
model = NeuralDE(f,
                 order=10,
                 solver='dopri5',
                 sensitivity='adjoint',
                 intloss=IntegralWReg(f)).to(device)

seq = nn.Sequential(Augmenter(1, 18, order='last'), model, nn.Linear(20,2)).to(device)

Note that training will be slower than for order=1 Neural ODEs: higher-order dynamics are often more challenging to integrate, and adaptive-step solvers will require a larger number of function evaluations (NFEs).
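To see this in practice, you can inspect the solver's function-evaluation counter after a forward pass; where exactly it lives depends on the torchdyn version (an assumption here: it is commonly exposed as model.nfe, possibly via model.defunc.nfe):

print(model.nfe)  # number of vector-field evaluations in the latest solve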

[14]:
# train the neural ODE
learn = Learner(seq)
trainer = pl.Trainer(min_epochs=600, max_epochs=1200)
trainer.fit(learn)
# Don't be alarmed by oscillations of the "terminal" loss printed at every
# iteration: since there is also an "integral" loss term, it is the sum of
# the two that matters.
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 2 K

[14]:
1

Plots

[15]:
# Evaluate the data trajectories
s_span = torch.linspace(0,1,100)
X_d = seq[0](X_train[::10,:])
trajectory = model.trajectory(X_d, s_span).detach().cpu()
[16]:
# Trajectories in the depth domain: one plot per 2-dim derivative block
for i in range(0, 20, 2):
    plot_2D_depth_trajectory(s_span, trajectory[:, :, i:i+2], yn[::10], len(X)//10)
../_images/tutorials_06_higher_order_24_0.png
../_images/tutorials_06_higher_order_24_1.png
../_images/tutorials_06_higher_order_24_2.png
../_images/tutorials_06_higher_order_24_3.png
../_images/tutorials_06_higher_order_24_4.png
../_images/tutorials_06_higher_order_24_5.png
../_images/tutorials_06_higher_order_24_6.png
../_images/tutorials_06_higher_order_24_7.png
../_images/tutorials_06_higher_order_24_8.png
../_images/tutorials_06_higher_order_24_9.png
[17]:
# Trajectories in the state-space
plot_2D_state_space(trajectory[:,:,-2:], yn[::10], len(X)//10)
../_images/tutorials_06_higher_order_25_0.png