[2]:
import math
import numpy as np
import scipy.sparse as sp
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import dgl
import dgl.function as fn

import dgl.data
import networkx as nx

from torchdyn.models import *; from torchdyn.datasets import *
from torchdyn import *
Using backend: pytorch

Neural Graph Differential Equations

Semi-supervised node classification

This notebook introduces Neural GDEs as a general, high-performance model for graph-structured data. Notebook 07_graph_differential_equations is designed from the ground up as an introduction to Neural GDEs and therefore contains ample comments to provide insight into some of our design choices. To remain accessible to practitioners/researchers without prior experience with GNNs, we also discuss some features of dgl, one of the PyTorch-based libraries for geometric deep learning.

Data preparation

[3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# seed for repeatability
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.manual_seed(0)
np.random.seed(0)
[4]:
# dgl offers convenient access to GNN benchmark datasets via `dgl.data`...
# other standard datasets (e.g. Citeseer / Pubmed) are also accessible via the dgl.data
# API. The rest of the notebook is compatible with Cora / Citeseer / Pubmed with minimal
# modification required.
data = dgl.data.CoraDataset()
Downloading /root/.dgl/cora.zip from https://data.dgl.ai/dataset/cora_raw.zip...
Extracting file to /root/.dgl/cora
[5]:
# Cora is a node-classification dataset with 2708 nodes
X = torch.FloatTensor(data.features).to(device)
Y = torch.LongTensor(data.labels).to(device)

# In transductive semi-supervised node classification tasks on graphs, the model has access to all
# node features but only a masked subset of the labels
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)

num_feats = X.shape[1]
n_classes = data.num_labels

# 140 training samples, 300 validation, 1000 test
n_classes, train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item()
[5]:
(7, 140, 300, 1000)
[6]:
# add a self-loop to each node
g = data.graph
g.remove_edges_from(nx.selfloop_edges(g))
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = dgl.DGLGraph(g)
edges = g.edges()
n_edges = g.number_of_edges()

n_edges
[6]:
13264
[7]:
# compute the diagonal of the symmetric normalization matrix D^{-1/2} used by the standard GCN formula
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
# add to dgl.Graph in order for the norm to be accessible at training time
g.ndata['norm'] = norm.unsqueeze(1).to(device)
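
The per-node scaling above, combined with the scale -> aggregate -> scale pattern of the GCN layer defined below, implements the standard symmetric normalization H' = D^{-1/2} (A + I) D^{-1/2} H W. A minimal check of this equivalence on a toy dense matrix (an illustrative sketch, not part of the original notebook; it only relies on numpy, already imported above):

# illustrative check: scaling features by deg^{-1/2} before and after neighborhood
# aggregation equals multiplying by the dense normalized adjacency D^{-1/2} (A + I) D^{-1/2}
A_toy = np.array([[0., 1.], [1., 0.]]) + np.eye(2)        # toy adjacency with self-loops
d_inv_sqrt = 1. / np.sqrt(A_toy.sum(1))                   # deg^{-1/2} per node
H_toy = np.random.rand(2, 3)                              # toy node features
dense = np.diag(d_inv_sqrt) @ A_toy @ np.diag(d_inv_sqrt) @ H_toy
scaled = d_inv_sqrt[:, None] * (A_toy @ (d_inv_sqrt[:, None] * H_toy))  # scale -> aggregate -> scale
assert np.allclose(dense, scaled)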

Neural GCDE

As with Neural ODEs, Neural GDEs require the specification of an ODE function (ODEFunc), representing the set of layers that will be called repeatedly by the ODE solver.

Here, we use the convolutional variant of Neural GDEs based on GCNs: Neural GCDEs. The only difference from alternative Neural GDEs lies in the type of GNN layer utilized in the ODEFunc.

For adaptive-step GDEs (e.g. dopri5) we increase the hidden dimension to 64 to reduce the stiffness of the ODE and therefore the number of ODEFunc evaluations (NFE: number of function evaluations).

First, we define the auxiliary GNN model as a standard GCN. Luckily, in this example the graph is static and can thus be assigned during initialization. For varying graphs, additional bookkeeping is required.

[8]:
def accuracy(y_hat:torch.Tensor, y:torch.Tensor):
    preds = torch.max(y_hat, 1)[1]
    return torch.mean((y == preds).float())
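
A quick illustrative check of the helper (hypothetical values, not from the original notebook):

# both rows are classified correctly, so accuracy should be 1.0
assert accuracy(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])).item() == 1.0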
[9]:
class GCNLayer(nn.Module):
    def __init__(self, g:dgl.DGLGraph, in_feats:int, out_feats:int, activation,
                 dropout:float, bias:bool=True):
        super().__init__()
        self.g = g
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, h):
        if self.dropout:
            h = self.dropout(h)
        h = torch.mm(h, self.weight)
        # normalization by square root of src degree
        h = h * self.g.ndata['norm']
        self.g.ndata['h'] = h
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
        h = self.g.ndata.pop('h')
        # normalization by square root of dst degree
        h = h * self.g.ndata['norm']
        # bias
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return h
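
As a quick shape check (a hypothetical snippet, not part of the original notebook), a single GCNLayer maps node features of size in_feats to size out_feats:

# one layer maps (num_nodes, in_feats) -> (num_nodes, out_feats)
layer = GCNLayer(g=g, in_feats=num_feats, out_feats=64, activation=F.relu, dropout=0.).to(device)
print(layer(X).shape)  # expected: torch.Size([2708, 64])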

Then, we construct the Neural GDE as follows:

[34]:
func = nn.Sequential(GCNLayer(g=g, in_feats=64, out_feats=64, activation=nn.Softplus(), dropout=0.9),
                     GCNLayer(g=g, in_feats=64, out_feats=64, activation=None, dropout=0.9)
                     ).to(device)
[35]:
neuralDE = NeuralDE(func, solver='rk4', s_span=torch.linspace(0, 1, 3)).to(device)
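
An adaptive-step variant can be obtained by swapping the solver, as discussed above; a sketch, assuming the installed torchdyn version accepts dopri5 together with atol/rtol tolerance arguments:

# adaptive-step alternative (sketch; keyword availability depends on the torchdyn version)
# neuralDE = NeuralDE(func, solver='dopri5', atol=1e-4, rtol=1e-4,
#                     s_span=torch.linspace(0, 1, 2)).to(device)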
[36]:
m = nn.Sequential(GCNLayer(g=g, in_feats=num_feats, out_feats=64, activation=None, dropout=0.4),
                  neuralDE,
                  GCNLayer(g=g, in_feats=64, out_feats=n_classes, activation=None, dropout=0.)
                  ).to(device)
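
To monitor the NFE mentioned earlier, one option (an assumption, not part of the original notebook) is a plain PyTorch forward hook on func, which should fire once per vector-field evaluation performed by the solver:

# count vector-field evaluations during a single forward pass (sketch)
nfe = [0]
hook = func.register_forward_hook(lambda module, inp, out: nfe.__setitem__(0, nfe[0] + 1))
_ = m(X)
print('NFE per forward pass:', nfe[0])  # rk4 typically performs 4 evaluations per step
hook.remove()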

Training loop

[37]:
class PerformanceContainer(object):
    """ Simple data class for metrics logging."""
    def __init__(self, data:dict):
        self.data = data

    @staticmethod
    def deep_update(x, y):
        for key in y.keys():
            x.update({key: list(x[key] + y[key])})
        return x
[38]:
opt = torch.optim.Adam(m.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
logger = PerformanceContainer(data={'train_loss':[], 'train_accuracy':[],
                                   'test_loss':[], 'test_accuracy':[],
                                   'forward_time':[], 'backward_time':[],
                                   })

[39]:
steps = 5000
verbose_step = 150
num_grad_steps = 0

for i in range(steps): # looping over epochs
    m.train()
    outputs = m(X)
    y_pred = outputs
    loss = criterion(y_pred[train_mask], Y[train_mask])
    opt.zero_grad()

    start_time = time.time()
    loss.backward()

    opt.step()
    num_grad_steps += 1

    with torch.no_grad():
        m.eval()

        # calculating outputs again with zeroed dropout
        y_pred = m(X)

        train_loss = loss.item()
        train_acc = accuracy(y_pred[train_mask], Y[train_mask]).item()
        test_acc = accuracy(y_pred[test_mask], Y[test_mask]).item()
        test_loss = criterion(y_pred[test_mask], Y[test_mask]).item()
        logger.deep_update(logger.data, dict(train_loss=[train_loss], train_accuracy=[train_acc],
                           test_loss=[test_loss], test_accuracy=[test_acc])
                          )

    if num_grad_steps % verbose_step == 0:
        print('[{}], Loss: {:3.3f}, Train Accuracy: {:3.3f}, Test Accuracy: {:3.3f}'.format(num_grad_steps,
                                                                                                    train_loss,
                                                                                                    train_acc,
                                                                                                    test_acc,
                                                                                                    ))
[150], Loss: 1.457, Train Accuracy: 0.514, Test Accuracy: 0.377
[300], Loss: 0.730, Train Accuracy: 0.907, Test Accuracy: 0.731
[450], Loss: 0.542, Train Accuracy: 0.921, Test Accuracy: 0.766
[600], Loss: 0.416, Train Accuracy: 0.950, Test Accuracy: 0.816
[750], Loss: 0.557, Train Accuracy: 0.943, Test Accuracy: 0.810
[900], Loss: 0.353, Train Accuracy: 0.964, Test Accuracy: 0.819
[1050], Loss: 0.265, Train Accuracy: 0.971, Test Accuracy: 0.807
[1200], Loss: 0.340, Train Accuracy: 0.964, Test Accuracy: 0.828
[1350], Loss: 0.201, Train Accuracy: 0.971, Test Accuracy: 0.828
[1500], Loss: 0.368, Train Accuracy: 0.971, Test Accuracy: 0.824
[1650], Loss: 0.255, Train Accuracy: 0.979, Test Accuracy: 0.812
[1800], Loss: 0.241, Train Accuracy: 0.971, Test Accuracy: 0.820
[1950], Loss: 0.304, Train Accuracy: 0.979, Test Accuracy: 0.821
[2100], Loss: 0.248, Train Accuracy: 0.971, Test Accuracy: 0.828
[2250], Loss: 0.223, Train Accuracy: 0.979, Test Accuracy: 0.815
[2400], Loss: 0.180, Train Accuracy: 0.979, Test Accuracy: 0.834
[2550], Loss: 0.321, Train Accuracy: 0.986, Test Accuracy: 0.825
[2700], Loss: 0.166, Train Accuracy: 0.986, Test Accuracy: 0.808
[2850], Loss: 0.171, Train Accuracy: 0.986, Test Accuracy: 0.821
[3000], Loss: 0.190, Train Accuracy: 0.986, Test Accuracy: 0.827
[3150], Loss: 0.207, Train Accuracy: 0.993, Test Accuracy: 0.823
[3300], Loss: 0.159, Train Accuracy: 0.986, Test Accuracy: 0.817
[3450], Loss: 0.183, Train Accuracy: 0.993, Test Accuracy: 0.829
[3600], Loss: 0.161, Train Accuracy: 0.986, Test Accuracy: 0.831
[3750], Loss: 0.143, Train Accuracy: 0.986, Test Accuracy: 0.826
[3900], Loss: 0.182, Train Accuracy: 0.986, Test Accuracy: 0.826
[4050], Loss: 0.156, Train Accuracy: 0.993, Test Accuracy: 0.817
[4200], Loss: 0.177, Train Accuracy: 0.986, Test Accuracy: 0.819
[4350], Loss: 0.160, Train Accuracy: 0.993, Test Accuracy: 0.811
[4500], Loss: 0.182, Train Accuracy: 0.986, Test Accuracy: 0.829
[4650], Loss: 0.130, Train Accuracy: 0.986, Test Accuracy: 0.812
[4800], Loss: 0.143, Train Accuracy: 0.986, Test Accuracy: 0.830
[4950], Loss: 0.207, Train Accuracy: 0.993, Test Accuracy: 0.818
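
The logged metrics can be visualized with matplotlib (imported at the top of the notebook); a minimal plotting sketch:

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].plot(logger.data['train_loss'], label='train')
ax[0].plot(logger.data['test_loss'], label='test')
ax[0].set_xlabel('gradient step'); ax[0].set_ylabel('cross-entropy loss'); ax[0].legend()
ax[1].plot(logger.data['train_accuracy'], label='train')
ax[1].plot(logger.data['test_accuracy'], label='test')
ax[1].set_xlabel('gradient step'); ax[1].set_ylabel('accuracy'); ax[1].legend()
plt.show()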