Density Estimation with FFJORDs

Free-form Jacobian of Reversible Dynamics (FFJORD) is a continuous normalizing flow (CNF) variant proposed in Grathwohl et al., 2018. Its core novelty is the use of stochastic trace estimators to make the Jacobian trace computation scalable: \(O(1)\) calls to autograd instead of \(O(D)\), where \(D\) is the data dimension.
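
For reference, the two identities behind this claim can be written compactly. The notation is an assumption of this note rather than the tutorial's: \(f\) is the learned vector field, \(z(s)\) the state at depth \(s\), \(J = \partial f / \partial z\) its Jacobian, and \(\epsilon\) a noise vector with zero mean and identity covariance. The first identity is the instantaneous change of variables used by CNFs; the second is Hutchinson's estimator, which replaces the \(D\) diagonal entries of \(J\) with a single vector-Jacobian product per noise sample.

\[
\frac{\partial \log p(z(s))}{\partial s} = -\mathrm{Tr}\left(\frac{\partial f}{\partial z(s)}\right),
\qquad
\mathrm{Tr}(J) = \mathbb{E}_{\epsilon}\left[\epsilon^{\top} J \epsilon\right]
\]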

[1]:
import sys; sys.path.append('../')
import torch.nn as nn  # explicit, rather than relying on the star imports below
import torchdyn
from torchdyn.models import *; from torchdyn.datasets import *

This time, we use a fun dataset: the DiffEqML logo.

[2]:
data = ToyDataset()
n_samples = 1 << 14
n_gaussians = 7

X, yn = data.generate(n_samples, 'diffeqml', noise=5e-2)
X = (X - X.mean())/X.std()

import matplotlib.pyplot as plt
plt.figure(figsize=(3, 3))
plt.scatter(X[:,0], X[:,1], c='olive', alpha=0.3, s=1)
../_images/tutorials_07b_ffjord_4_1.png
[3]:
import torch
import torch.utils.data as data
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
X_train = torch.Tensor(X).to(device)
train = data.TensorDataset(X_train)
trainloader = data.DataLoader(train, batch_size=1024, shuffle=True)

The FFJORD model

In torchdyn, we decouple CNFs from the Jacobian trace estimators they use. This makes it easy to implement variants without altering the original module. For example, the Hutchinson stochastic estimator can be defined separately as follows:

[4]:
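def hutch_trace(x_out, x_in, noise=None, **kwargs):
    """Hutchinson's Jacobian trace estimator: a single (O(1)) call to autograd"""
    # with grad_outputs=noise, autograd returns the vector-Jacobian product noise^T J
    vjp = torch.autograd.grad(x_out, x_in, noise, create_graph=True)[0]
    # per-sample estimate noise^T J noise
    trJ = torch.einsum('bi,bi->b', vjp, noise)
    return trJ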
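
To make the trade-off concrete, the sketch below compares the stochastic estimate with an exact trace that needs one autograd call per dimension. The helpers `exact_trace` and `hutch_estimate`, the toy map `f_toy`, and all constants are assumptions made for illustration and are not part of torchdyn. The estimator logic is repeated so that the block runs on its own.

```python
import torch

torch.manual_seed(0)

def exact_trace(x_out, x_in):
    """Exact Jacobian trace: one autograd call per dimension, O(D) in total."""
    trJ = torch.zeros(x_in.shape[0], dtype=x_in.dtype)
    for i in range(x_in.shape[1]):
        # gradient of the i-th output w.r.t. all inputs; keep its i-th entry (J_ii)
        grad_i = torch.autograd.grad(x_out[:, i].sum(), x_in, retain_graph=True)[0]
        trJ = trJ + grad_i[:, i]
    return trJ

def hutch_estimate(x_out, x_in, noise):
    """Hutchinson estimate noise^T J noise: a single autograd call."""
    vjp = torch.autograd.grad(x_out, x_in, noise, retain_graph=True)[0]
    return torch.einsum('bi,bi->b', vjp, noise)

# hypothetical toy vector field: f(x) = tanh(x W^T)
W = torch.tensor([[0.5, 1.0], [1.5, 2.0]])
f_toy = lambda x: torch.tanh(x @ W.T)

# repeat one point many times so that the batch mean averages over noise draws
n_draws = 20000
x_in = torch.tensor([[0.1, -0.2]]).repeat(n_draws, 1).requires_grad_(True)
x_out = f_toy(x_in)

noise = torch.randn(n_draws, 2)
exact = exact_trace(x_out, x_in)[0].item()
estimate = hutch_estimate(x_out, x_in, noise).mean().item()
print(f"exact trace: {exact:.3f}, Hutchinson mean over {n_draws} draws: {estimate:.3f}")
```

A single draw is noisy; the estimate is unbiased, so during training the noise is averaged out over batches and optimization steps rather than within one evaluation.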

And then instantiate a CNF as before.

[5]:
f = nn.Sequential(
        nn.Linear(2, 64),
        nn.Softplus(),
        nn.Linear(64, 64),
        nn.Softplus(),
        nn.Linear(64, 64),
        nn.Softplus(),
        nn.Linear(64, 2),
    )

from torch.distributions import MultivariateNormal, Uniform, TransformedDistribution, SigmoidTransform, Categorical
prior = MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))

# stochastic estimators require a distribution from which the "noise" vectors are sampled
noise_dist = MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))
# CNF wraps the net, as with other energy models
cnf = CNF(f, trace_estimator=hutch_trace, noise_dist=noise_dist)
nde = NeuralDE(cnf, solver='dopri5', s_span=torch.linspace(0, 1, 2), sensitivity='adjoint', atol=1e-4, rtol=1e-4)

Augmenter takes care of setting up the additional scalar dimension for the divergence dynamics. The DEFunc wrapper (implicitly defined when passing f to the NeuralDE) will ensure compatibility of depth-concatenation and data-control with the divergence dimension.

Additional augmented dimensions can also be used, since only the first is reserved for the Jacobian trace.

[6]:
model = nn.Sequential(Augmenter(augment_idx=1, augment_dims=1),
                      nde).to(device)
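
As a rough picture of what the augmentation step produces, the sketch below prepends a zero-initialized column to a batch, so that column 0 can accumulate the divergence while the remaining columns carry the state. `augment_with_zeros` is a hypothetical stand-in written for this note and makes no claim about how torchdyn's Augmenter is implemented. The column layout is inferred from the indexing used in the training step (`xtrJ[:,0]` for the trace, `xtrJ[:,1:]` for the state).

```python
import torch

def augment_with_zeros(x, augment_dims=1):
    """Hypothetical stand-in: prepend `augment_dims` zero columns along dim 1."""
    zeros = torch.zeros(x.shape[0], augment_dims, dtype=x.dtype, device=x.device)
    return torch.cat([zeros, x], dim=1)

x = torch.randn(4, 2)
x_aug = augment_with_zeros(x)
print(x_aug.shape)  # batch of 4: 1 divergence column followed by 2 state columns
```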

We use a standard Learner. It is often useful to visualize samples during normalizing flow training, in order to identify issues quickly and stop unpromising runs early. For an example of how to log images using PyTorch Lightning and Wandb, refer to torchdyn’s benchmark notebooks.

[7]:
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model
        self.iters = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        self.iters += 1
        x = batch[0]
        xtrJ = self.model(x)
        logprob = prior.log_prob(xtrJ[:,1:]).to(x) - xtrJ[:,0] # logp(z_S) = logp(z_0) - \int_0^S trJ
        loss = -torch.mean(logprob)
        nde.nfe = 0
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=2e-3, weight_decay=1e-5)

    def train_dataloader(self):
        return trainloader
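
The change-of-variables formula noted in the `training_step` comment can be sanity-checked on a flow whose density is known in closed form. The sketch below assumes a linear vector field \(f(z) = Az\) (a hypothetical choice, as are `A`, `base`, and the evaluation points), for which the flow over \(s \in [0, 1]\) is \(z(1) = e^{A} z(0)\) and the Jacobian trace is the constant \(\mathrm{Tr}(A)\). It does not use torchdyn.

```python
import torch
from torch.distributions import MultivariateNormal

dtype = torch.float64
A = torch.tensor([[0.3, -0.5], [0.2, 0.1]], dtype=dtype)  # hypothetical linear field f(z) = A z
x = torch.tensor([[0.5, -1.0], [1.2, 0.3]], dtype=dtype)  # points at which to evaluate log p(x)

base = MultivariateNormal(torch.zeros(2, dtype=dtype), torch.eye(2, dtype=dtype))

# flow the data to the latent space: z(1) = exp(A) x
flow = torch.matrix_exp(A)
z1 = x @ flow.T

# change of variables: log p(x) = log p(z(1)) + \int_0^1 Tr(df/dz) ds, with Tr(df/dz) = Tr(A) here
logp_flow = base.log_prob(z1) + torch.trace(A)

# closed form: x = exp(-A) z with z ~ N(0, I), hence x ~ N(0, exp(-A) exp(-A)^T)
inv_flow = torch.matrix_exp(-A)
logp_exact = MultivariateNormal(torch.zeros(2, dtype=dtype), inv_flow @ inv_flow.T).log_prob(x)

print(torch.allclose(logp_flow, logp_exact))
```

In the learned model the trace integral has no closed form, which is why it is carried along as an extra state dimension and integrated by the ODE solver.
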
[8]:
learn = Learner(model)
trainer = pl.Trainer(max_epochs=600)
trainer.fit(learn);
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 8 K

Visualizing the Samples

Sampling from CNFs is easy: we draw samples from the latent normal prior and pass them through the z -> x direction of the flow. To reverse the flow, we flip s_span:

[9]:
sample = prior.sample(torch.Size([1 << 14]))
# integrating from 1 to 0
model[1].s_span = torch.linspace(1, 0, 2)
new_x = model(sample).cpu().detach()
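
Why flipping the span inverts the flow can be illustrated without torchdyn. The sketch below integrates a fixed toy vector field forward with a hand-written RK4 loop and then backward over the reversed span, recovering the starting points up to solver error. `rk4_integrate`, `field`, and the constants are hypothetical and stand in for the adaptive `dopri5` solver used above.

```python
import torch

def rk4_integrate(field, z, s0, s1, steps=100):
    """Fixed-step RK4 for the autonomous ODE dz/ds = field(z), from s0 to s1."""
    h = (s1 - s0) / steps  # negative when s1 < s0, i.e. when the span is reversed
    for _ in range(steps):
        k1 = field(z)
        k2 = field(z + 0.5 * h * k1)
        k3 = field(z + 0.5 * h * k2)
        k4 = field(z + h * k3)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return z

torch.manual_seed(0)
M = torch.tensor([[0.0, 1.0], [-1.0, 0.5]], dtype=torch.float64)
field = lambda z: torch.tanh(z @ M.T)  # hypothetical fixed vector field

z_start = torch.randn(8, 2, dtype=torch.float64)
z_end = rk4_integrate(field, z_start, 0.0, 1.0)  # forward span
z_back = rk4_integrate(field, z_end, 1.0, 0.0)   # reversed span
print((z_back - z_start).abs().max().item())     # round-trip error
```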

We then compare the generated samples (left) with the training data (right):

[10]:
plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.scatter(new_x[:,1], new_x[:,2], s=2.3, alpha=0.2, linewidths=0.1, c='blue', edgecolors='black')
plt.xlim(-2, 2)
plt.ylim(-2, 2)

plt.subplot(122)
plt.scatter(X[:,0], X[:,1], s=3.3, alpha=0.2, c='red',  linewidths=0.1, edgecolors='black')
plt.xlim(-2, 2)
plt.ylim(-2, 2)
../_images/tutorials_07b_ffjord_20_1.png

Finally, we plot the flows from the prior to the data distribution:

[11]:
traj = model[1].trajectory(Augmenter(1, 1)(sample.to(device)), s_span=torch.linspace(1,0,100)).detach().cpu()
sample = sample.cpu()
traj = traj[:, :, 1:] # drop the first dimension (the Jacobian trace)
[12]:
n = 2000
plt.figure(figsize=(6,6))
plt.scatter(sample[:n,0], sample[:n,1], s=10, alpha=0.8, c='black')
plt.scatter(traj[:,:n,0], traj[:,:n,1], s=0.2, alpha=0.2, c='olive')
plt.scatter(traj[-1,:n,0], traj[-1,:n,1], s=4, alpha=1, c='blue')
plt.legend(['Prior sample z(S)', 'Flow', 'z(0)'])
../_images/tutorials_07b_ffjord_23_1.png

Upcoming tutorials in this series (WIP) will discuss regularization strategies for speeding up CNF training, such as those of Finlay et al., 2020.