Density Estimation with FFJORDs¶
Free-form Jacobian of Reversible Dynamics (FFJORD) is a continuous normalizing flow (CNF) variant proposed in Grathwohl et al., 2018. Its core novelty is the use of stochastic trace estimators to improve the scalability of the Jacobian trace computation (\(O(1)\) calls to autograd instead of \(O(D)\)).
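Concretely, Hutchinson's estimator rewrites the trace as an expectation over noise vectors \(\epsilon\) with zero mean and identity covariance,
\[\mathrm{tr}\, J_f(z) = \mathbb{E}_{\epsilon}\big[\epsilon^\top J_f(z)\,\epsilon\big],\]
so a single vector-Jacobian product per sample (one autograd call) replaces the \(D\) calls needed for the exact trace.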
[1]:
import sys ; sys.path.append('../')
import torchdyn; from torchdyn.models import *; from torchdyn.datasets import *
This time, we use a fun dataset: the DiffEqML logo.
[2]:
data = ToyDataset()
n_samples = 1 << 14
n_gaussians = 7
X, yn = data.generate(n_samples, 'diffeqml', noise=5e-2)
X = (X - X.mean())/X.std()
import matplotlib.pyplot as plt
plt.figure(figsize=(3, 3))
plt.scatter(X[:,0], X[:,1], c='olive', alpha=0.3, s=1)
[2]:
<matplotlib.collections.PathCollection at 0x22c749bf3c8>

[3]:
import torch
import torch.utils.data as data
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
X_train = torch.Tensor(X).to(device)
train = data.TensorDataset(X_train)
trainloader = data.DataLoader(train, batch_size=1024, shuffle=True)
The FFJORD model¶
In torchdyn, we decouple CNFs from the Jacobian trace estimators they use. This allows us to easily implement variants without alterations to the original module. Indeed, we can simply define the Hutchinson stochastic estimator separately as follows:
[4]:
def hutch_trace(x_out, x_in, noise=None, **kwargs):
    """Hutchinson's trace Jacobian estimator, O(1) call to autograd"""
    jvp = torch.autograd.grad(x_out, x_in, noise, create_graph=True)[0]
    trJ = torch.einsum('bi,bi->b', jvp, noise)
    return trJ
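For comparison, an exact trace requires one autograd call per input dimension. Below is a minimal sketch of such an estimator, written with the same calling convention as hutch_trace above (the name exact_trace is ours; torchdyn ships its own estimators):

def exact_trace(x_out, x_in, **kwargs):
    """Exact Jacobian trace, O(D) calls to autograd (for comparison only)."""
    trJ = 0.
    for i in range(x_in.shape[1]):
        # gradient of the i-th output component w.r.t. the input; its i-th column is a diagonal entry of J
        trJ = trJ + torch.autograd.grad(x_out[:, i].sum(), x_in, create_graph=True)[0][:, i]
    return trJ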
And then instantiate a CNF as before.
[5]:
f = nn.Sequential(
    nn.Linear(2, 64),
    nn.Softplus(),
    nn.Linear(64, 64),
    nn.Softplus(),
    nn.Linear(64, 64),
    nn.Softplus(),
    nn.Linear(64, 2),
)
from torch.distributions import MultivariateNormal, Uniform, TransformedDistribution, SigmoidTransform, Categorical
prior = MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))
# stochastic estimators require a definition of a distribution where "noise" vectors are sampled from
noise_dist = MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))
# cnf wraps the net as with other energy models
cnf = CNF(f, trace_estimator=hutch_trace, noise_dist=noise_dist)
nde = NeuralDE(cnf, solver='dopri5', s_span=torch.linspace(0, 1, 2), sensitivity='adjoint', atol=1e-4, rtol=1e-4)
Augmenter takes care of setting up the additional scalar dimension for the divergence dynamics. The DEFunc wrapper (implicitly defined when passing f to the NeuralDE) ensures that depth-concatenation and data-control remain compatible with the divergence dimension. Using additional augmented dimensions is also supported, as only the first is used for the Jacobian trace.
[6]:
model = nn.Sequential(Augmenter(augment_idx=1, augment_dims=1),
                      nde).to(device)
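As a quick sanity check, a forward pass through the augmented model returns a [batch, 3] tensor: column 0 carries the integrated Jacobian trace used for the change of variables, while columns 1 and 2 hold the flowed coordinates. A minimal sketch:

out = model(X_train[:4])   # Augmenter prepends the divergence dimension, then the NeuralDE integrates
print(out.shape)           # torch.Size([4, 3]): [trace, z_1, z_2]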
Standard Learner. It is often useful to visualize samples during normalizing flow training, in order to identify issues quickly and stop runs that are not promising. For an example of how to log images using PyTorch Lightning and Wandb, refer to torchdyn’s benchmark notebooks.
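The training_step below maximizes likelihood through the instantaneous change of variables formula of continuous normalizing flows (Chen et al., 2018),
\[\frac{\partial \log p(z(s))}{\partial s} = -\,\mathrm{tr}\, J_f\big(z(s)\big),\]
so the log-density of a data point is the prior log-density of its image under the flow, corrected by the integrated Jacobian trace stored in the augmented dimension (see the comment in the code).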
[7]:
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model
        self.iters = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        self.iters += 1
        x = batch[0]
        xtrJ = self.model(x)
        logprob = prior.log_prob(xtrJ[:,1:]).to(x) - xtrJ[:,0] # logp(z_S) = logp(z_0) - \int_0^S trJ
        loss = -torch.mean(logprob)
        nde.nfe = 0
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=2e-3, weight_decay=1e-5)

    def train_dataloader(self):
        return trainloader
[8]:
learn = Learner(model)
trainer = pl.Trainer(max_epochs=600)
trainer.fit(learn);
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
| Name | Type | Params
-------------------------------------
0 | model | Sequential | 8 K
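With the flow trained, density estimation itself is a single forward pass: run data through the augmented model and combine the prior log-density with the integrated trace, exactly as in training_step. A minimal sketch, evaluated on a batch of training points for illustration:

model[1].s_span = torch.linspace(0, 1, 2)            # data -> latent direction
xtrJ = model(X_train[:1024])
logprob = prior.log_prob(xtrJ[:, 1:]) - xtrJ[:, 0]   # same change of variables as in training_step
print(f'mean log-likelihood: {logprob.mean().item():.3f} nats')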
Visualizing the Samples¶
Sampling from CNFs is easy: we query the prior latent normal and then pass the samples through the z -> x CNF flow. To reverse the flow, we flip s_span:
[9]:
sample = prior.sample(torch.Size([1 << 14]))
# integrating from 1 to 0
model[1].s_span = torch.linspace(1, 0, 2)
new_x = model(sample).cpu().detach()
We then plot samples, flows and density like so:
[10]:
plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.scatter(new_x[:,1], new_x[:,2], s=2.3, alpha=0.2, linewidths=0.1, c='blue', edgecolors='black')
plt.xlim(-2, 2)
plt.ylim(-2, 2)
plt.subplot(122)
plt.scatter(X[:,0], X[:,1], s=3.3, alpha=0.2, c='red', linewidths=0.1, edgecolors='black')
plt.xlim(-2, 2)
plt.ylim(-2, 2)
[10]:
(-2.0, 2.0)

We plot the flows from the prior to the data distribution:
[11]:
traj = model[1].trajectory(Augmenter(1, 1)(sample.to(device)), s_span=torch.linspace(1,0,100)).detach().cpu() ; sample = sample.cpu()
traj = traj[:, :, 1:] # discard the first dimension (the Jacobian trace)
[12]:
n = 2000
plt.figure(figsize=(6,6))
plt.scatter(sample[:n,0], sample[:n,1], s=10, alpha=0.8, c='black')
plt.scatter(traj[:,:n,0], traj[:,:n,1], s=0.2, alpha=0.2, c='olive')
plt.scatter(traj[-1,:n,0], traj[-1,:n,1], s=4, alpha=1, c='blue')
plt.legend(['Prior sample z(S)', 'Flow', 'z(0)'])
[12]:
<matplotlib.legend.Legend at 0x22c1489ac48>

Future tutorials in this series (WIP) will discuss regularization strategies to speed up CNF training, such as those proposed in Finlay et al., 2020.