Molecular graph generation with PyTorch and PyGeometric
We use GraphVAE for molecular generation with one shot generation of a probabilistic graph with predefined maximum size.
#collapse-hide
#Initial imports
import numpy as np
import torch
import matplotlib.pyplot as plt
from glob import glob
import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from torch.utils.data import Dataset, DataLoader
import torch
from torch_geometric.data import Batch
from torch.utils.data import random_split
Introduction
We represent a molecule as graph $G = (\mathcal{X, A})$ using PyGeometric framework. Each molecule is represented by a feature matrix $\mathcal{X}$ and adjacency matrix $\mathcal{A}$. We use QM9 dataset from MoleculeNet:A Benchmark for Molecular Machine Learning implemented in torch_geometric.datasets.QM9
. PyGeometric relies on rdkit to process the SMILES string and convert them into graphs.
We modify the data processing script in two ways:
- We strip hydrogen atoms from the molecules to keep only the heavy atoms
- We kekulize the molecules to convert aromatic rings to Kekule form
The modified script can be found here
After processing the dataset, we have a set of molecules with 4 heavy atoms (C, N, O, F) and 3 bond types (SINGLE, DOUBLE and TRIPLE) with maximum graph size of 9.
The decoder outputs the graph as one-hot encoded vectors for atoms [9 x 5]
and bonds [9 x 4]
. The label 0 represents empty atom or edge.
#Imports for data pre-processing
import torch_geometric
from qm9_modified import QM9
from torch_geometric.utils.convert import to_networkx
import networkx
# Setting up variables for the dataset
MAX_ATOM = 5
MAX_EDGE = 4
path = '/scratch/project_2002655/datasets/qm9_noH' # Change the path for your local directory structure
dataset = QM9(path)
# Store the max. graph size
MAX_N = -1
for data in dataset:
if MAX_N < data.x.shape[0]: MAX_N = data.x.shape[0]
MAX_E = int(MAX_N * (MAX_N - 1))
print('MAX ATOMS: {}'.format(MAX_N)) # Maximum number of atoms in a graph in the dataset
print('MAX EDGE: {}'.format(MAX_E)) # Corresponding size of upper triangle adjacency matrix
torch_geometric
stores the graph as torch_geometric.data.Data
and we generate the one-hot representation of the graph $G$ as described above. For each graph $G$, we create a vector $\mathcal{X}$ as one-hot encoded for atom of dimension [MAX_N x MAX_ATOM]
and vector bond of dimension [MAX_E x MAX_EDGE]
.
# We create a matrix to map the index of the edge vector $\mathcal{A}$ to the upper triangular adjacency matrix.
index_array = torch.zeros([MAX_N, MAX_N], dtype=int)
idx = 0
for i in range(MAX_N):
for j in range(MAX_N):
if i < j:
index_array[i, j] = idx
idx+=1
print(index_array)
We process the torch_geometric.dataset
to generate matrices $\mathcal{X}$ and $\mathcal{A}$ which act as the ground truth for our decoder. We will also setup utility functions to convert between the vector representation $(\mathcal{X}, \mathcal{A})$ and torch_geometric.data
representation $\mathcal{G}$. We use the following key for atoms and bonds:
C: 1 SINGLE: 1
N: 2 DOUBLE: 2
O: 3 TRIPLE: 3
F: 4
0
is the placeholder label for empty entry.
# Initialize the labels with -1
edge_labels = torch.ones(len(dataset), MAX_E) * -1
atom_labels = torch.ones(len(dataset), MAX_N) * -1
idx = 0
for data in dataset:
edge_attr = data.edge_attr # One hot encoded bond labels
edge_index = data.edge_index # Bond indices as one hot adjacency list
upper_index = edge_index[0] < edge_index[1] # Bond indices as upper triangular adjacency matrix
_, edge_label = torch.max(edge_attr, dim=-1)# Bond labels from one hot vectors
x = data.x[:, 1:5] # One hot encoded atom labels
_, atom_label = torch.max(x, dim=-1) # Atom labels from one hot vectors
# Expand the label vectors to size [MAX_N x MAX_ATOM] and [MAX_E x MAX_EDGE]
atom_labels[idx][:len(atom_label)] = atom_label
a0 = edge_index[0,upper_index]
a1 = edge_index[1,upper_index]
up_idx = index_array[a0, a1]
edge_labels[idx][up_idx] = edge_label[upper_index].float()
idx += 1
atom_labels += 1
edge_labels += 1
Now that we have the dataset represented as $(\mathcal{X}, \mathcal{A})$ let's plot some graphs to visually check if the molecules are as we expected. We use rdkit
to plot the molecules which does a lot of having lifting for us. The function graphToMol
takes in the vectors $(\mathcal{X}, \mathcal{A})$ and returns an object of type rdkit.Mol
. We can also obtain visualizations for the graphs $\mathcal{G}$ by using torch_geometric.utils.convert.to_networkx
and then ploting the netowrkx graph. But rdkit
plots the molecules in a canonical orientation and is built to minimize intramolecular clashes, i.e. to maximize the clarity of the drawing.
#collapse-hide
def get_index(index, index_array):
for i in range(9):
for j in range(9):
if i < j:
if(index_array[i, j] == index):
return [i, j]
def graphToMol(atom, edge):
possible_atoms = {
0: 'H',
1: 'C',
2: 'N',
3: 'O',
4: 'F'
}
possible_edges = {
1: Chem.rdchem.BondType.SINGLE,
2: Chem.rdchem.BondType.DOUBLE,
3: Chem.rdchem.BondType.TRIPLE
}
max_n = 9
mol = Chem.RWMol()
rem_idxs = []
for a in atom:
atom_symbol = possible_atoms[a.item()]
mol.AddAtom(Chem.Atom(atom_symbol))
for a in mol.GetAtoms():
if a.GetAtomicNum() == 1:
rem_idxs.append(a.GetIdx())
for i, e in enumerate(edge):
e = e.item()
if e != 0:
a0, a1 = get_index(i, index_array)
if a0 in rem_idxs or a1 in rem_idxs:
return None
bond_type = possible_edges[e]
mol.AddBond(a0, a1, order=bond_type)
rem_idxs.sort(reverse=True)
for i in rem_idxs:
mol.RemoveAtom(i)
return mol
# We pick 9 random molecules from QM9 dataset to plot
mols = []
for i in np.random.randint(0, 100, size=9):
mols.append(graphToMol(atom_labels[i], edge_labels[i]))
Chem.Draw.IPythonConsole.ShowMols(mols)
The final step is to create a custom dataset for PyTorch which combines the two representations $\mathcal{G}$ and $(\mathcal{X}, \mathcal{A})$. The dataset returns a list containing the graph, atom label and edge label. We also define a collate_fn
which creates mini batches for the data loader that we will need for training and testing our model.
#collapse-hide
# Utility function to combine the torch_geometric.data and torch.tensor into one mini batch
def collate_fn(datas):
graphs, atoms, edges = [], [], []
for data in datas:
graph, atom, edge = data
graphs.append(graph)
atoms.append(atom)
edges.append(edge)
graphs = Batch.from_data_list(graphs)
atoms = torch.stack(atoms)
edges = torch.stack(edges)
return graphs, atoms, edges
# Returns the dataloader with a given batch size
def get_dataloader(dataset, batch_size=32, shuffle=True):
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
# Create the dataset from the QM9 dataset from PyGeometric and the label matrices
class GraphTensorDataset(Dataset):
def __init__(self):
self.graph_dataset = QM9('/scratch/project_2002655/datasets/qm9_noH')
self.atom_labels = torch.load('/scratch/project_2002655/drug_discovery/saved/atom_labels_noH.tensor')
self.edge_labels = torch.load('/scratch/project_2002655/drug_discovery/saved/edge_labels_noH.tensor')
def __len__(self):
return len(self.graph_dataset)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
graphs = self.graph_dataset[idx]
atoms = self.atom_labels[idx]
edges = self.edge_labels[idx]
return graphs, atoms, edges
# return the subset of data with the specified size
def get_splits(self, train_size, test_size, cv_size=None):
total_data = len(self.graph_dataset)
if cv_size is not None:
trainset, testset = random_split(self, [train_size+cv_size, total_data - train_size - cv_size])
trainset, cvset = random_split(trainset, [train_size, cv_size])
testset, _ = random_split(testset, [test_size, len(testset) - test_size])
self.trainset = trainset
self.cvset = cvset
self.testset = testset
return self.trainset, self.testset, self.cvset
else:
trainset, testset = random_split(self, [train_size, total_data - train_size])
testset, _ = random_split(testset, [test_size, len(testset) - test_size])
self.trainset = trainset
self.testset = testset
return self.trainset, self.testset
Now we wrap up this section by creating the train and test loaders with a mini batch of size 32
. We train the model on a subset of 500 molecules and test the reconstruction on 10000 molecules. To monitor training we look at three metrics averaged over the minibatch:
- cross-entropy loss over the atom and edge labels
- Atom label accuracy
- Edge label accuracy
dataset = GraphTensorDataset()
train_size = int(500)
test_size = int(10000)
torch.manual_seed(0) # We set the random seed before the random split so we can get the same subset for reproducibility
trainset, testset = dataset.get_splits(train_size, test_size)
trainloader = get_dataloader(trainset)
testloader = get_dataloader(testset)
Model
Our generative model is a VAE with an encoder $q_\phi (\textbf{z} | \mathcal{G})$ and a decoder $p_\theta (\mathcal{X}, \mathcal{A} | \textbf{z})$ parameterized by $\phi$ and $\theta$ respectively. We place a prior over the latent vector $\textbf{z} \sim \mathcal{N}(0, I)$ which act as a regularizer to our model. Our goal is to map the graph $\mathcal{G}$ into the latent vector $\textbf{z}$ and generate new graphs by using the decoder to map new latent vectors sampled from the prior.
Encoder: For a graph $\mathcal{G}(\mathcal{V}, \mathcal{E})$ with upto N nodes, we aim to map the graph to a latent vector $\textbf{z} \in \mathbb{R}^D$. We use graph convolution network(GCN) to learn graph level feature vector $h_{\mathcal{G}}$. For each node $\mathcal{v} \in \mathcal{V}$, GCN update the node level features as, \begin{equation} h_{\mathcal{v}} = \sum_{\mathcal{u} \in Nbd(\mathcal{v}, \mathcal{G})}\Phi (W_l^T X_{\mathcal{u}}) \end{equation} where $Nbd(\mathcal{v}, \mathcal{G})$ is the neighborhood of node $\mathcal{v}$ in graph $\mathcal{G}$, $W_l$ are the shared weights analogous to the kernel in convolution layer and $X_\mathcal{u}$ is the feature vector of node $\mathcal{u}$ with $\Phi$ as the non-linearity.
After 2 graph convolution layers, we create a graph level vector representation by taking the mean of all node feature vectors $h_{\mathcal{v}}$. \begin{equation} h_{\mathcal{G}} = 1 / |\mathcal{V}| \sum_{\mathcal{v} \in \mathcal{V}} h_{\mathcal{v}} \end{equation} Now we use fully connected layers to learn the latent vector $\textbf{z}$ with the inference model as, \begin{equation} q(\textbf{z}|\mathcal{G}) = \mathcal{N}(\textbf{z}|\mu(h_{\mathcal{G}}), \sigma^2(h_{\mathcal{G}})) \end{equation} where $\mu(h_{\mathcal{G}})$ and $\sigma^2(h_{\mathcal{G}})$ are parameterized by neural networks to learn the latent space.
Decoder: The primary problem with graph generation is that order of nodes can be arbritary which makes comparing two graphs a combatorial problem. We sidestep this problem by framing the graph generation problem as multi class classification problem. Keeping in the spirit of a toy example, we consider two independent classification problems.
We train a decoder to predict atom and edge label vectors $(\mathcal{X}, \mathcal{A})$ respectively. We can train the model by using the standard VAE loss, \begin{equation} \mathcal{L} = \alpha_{\mathcal{a}} \mathcal{L}_{\mathcal{a}} + \alpha_\mathcal{e} \mathcal{L}_{\mathcal{e}} + \alpha_{KL} KL[q(\textbf{z}|\mathcal{G}) | p(\textbf{z})] \end{equation}
$\mathcal{L}_{\mathcal{a}}$ and $\mathcal{L}_{\mathcal{e}}$ are the cross entropy loss for multiclass vectors for atoms and edges respectively. This restricts the decoder to memorize the ordering for the atom labels and independently learn the edges based only on the latent vector $\textbf{z}$. We set $\alpha_\mathcal{a} = 1/|\mathcal{X}|$, $\alpha_\mathcal{e} = 1/|{\mathcal{A}}|$ and $\alpha_KL = 1/ 128$ so all three loses are on the same scale. $\mathcal{X}$ is the number of atoms in a mini batch, $|\mathcal{A}|$ is the number of edges in a mini batch and we set the dimension of the latent vector as $128$.
We use a single sample from the latent distribution to train the VAE and 100 samples during testing.
#collapse-hide
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv, BatchNorm, GlobalAttention
from torch_geometric.utils import add_self_loops
class VAE(nn.Module):
def __init__(self, max_n, max_atoms, max_edges, in_features):
super(VAE, self).__init__()
n_output = int(max_n * (max_n - 1) / 2)
self.n_output = n_output
self.max_n = max_n
self.max_atoms = max_atoms
self.max_edges = max_edges
self.gc1 = GCNConv(in_features, 32)
self.gc2 = GCNConv(32, 64)
self.gc1_bn = nn.BatchNorm1d(32)
self.gc2_bn = nn.BatchNorm1d(64)
attn1 = nn.Linear(64, 1)
attn2 = nn.Linear(64, 64)
self.graph_pool = GlobalAttention(attn1, attn2)
self.fc1 = nn.Linear(64, 128)
self.fc1_bn = nn.BatchNorm1d(128)
self.mu = nn.Linear(128, 128)
self.logv = nn.Linear(128, 128)
self.dec1 = nn.Linear(128, 128)
self.dec1_bn = nn.BatchNorm1d(128)
self.dec2 = nn.Linear(128, 256)
self.dec2_bn = nn.BatchNorm1d(256)
self.dec3 = nn.Linear(256, 512)
self.dec3_bn = nn.BatchNorm1d(512)
self.atoms = nn.Linear(512, max_atoms*max_n)
self.edges = nn.Linear(512, max_edges*n_output)
def encode(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
edge_index, _ = add_self_loops(edge_index)
x = self.gc1(x, edge_index)
x = F.relu(self.gc1_bn(x))
x = self.gc2(x, edge_index)
x = F.relu(self.gc2_bn(x))
#x = self.graph_pool(x, batch)
#x = F.tanh(x)
x = torch_geometric.nn.global_add_pool(x, data.batch)
x = self.fc1(x)
x = F.relu(self.fc1_bn(x))
return self.mu(x), self.logv(x)
def sample(self, mu, logv):
eps = torch.randn_like(logv).to(mu.device)
return mu + eps * torch.exp(logv)
def decode(self, sample):
x = self.dec1(sample)
x = F.relu(self.dec1_bn(x))
x = self.dec2(x)
x = F.relu(self.dec2_bn(x))
x = self.dec3(x)
x = F.relu(self.dec3_bn(x))
atoms = self.atoms(x)
atoms = atoms.view(-1, self.max_n, self.max_atoms)
atoms = F.softmax(atoms, dim=-1)
edges = self.edges(x)
edges = edges.view(-1, self.n_output, self.max_edges)
edges = F.softmax(edges, dim=-1)
return atoms, edges
def forward(self, data, num_samples=1):
mu, logv = self.encode(data)
_atoms, _edges = [], []
for i in range(num_samples):
sample = self.sample(mu, logv)
atoms, edges = self.decode(sample)
_atoms.append(atoms)
_edges.append(edges)
return mu, logv, _atoms, _edges
Now that we have the dataloader and the model, we can setup utility functions train
and test
along with the metrics that we need to evaluate the model performance. We use multi class accuracy to monitor as an indicator for model performance.
#collapse-hide
from pathlib import Path
# KL loss for the latent vector
def get_vae_loss(mu, logv):
kl_loss = 0.5 * torch.sum(torch.exp(logv) + mu**2 - 1.0 - logv)
return kl_loss/128
# Reconstruction loss for atom and edge vectors
def get_recon_loss(atom_true, atom_prob, edge_true, edge_prob):
ashape = atom_prob.shape
num_atoms = (atom_true!=0).int().sum()
atom_prob = torch.log(atom_prob)
atom_loss = F.nll_loss(atom_prob.view(-1, ashape[-1]),
atom_true.view(-1).long(), reduction='sum') / num_atoms
eshape = edge_prob.shape
edge_prob = torch.log(edge_prob)
num_edges = (edge_true!=0).int().sum()
edge_loss = F.nll_loss(edge_prob.view(-1, eshape[-1]),
edge_true.view(-1).long(), reduction='sum') / num_edges
return atom_loss, edge_loss
# Return the correct predictions for multi class label classification
def get_accuracy(y_true, y_prob, threshold = 0.5):
if len(y_prob.shape) <=2:
y_pred = (y_prob >= threshold).int()
else:
_, y_pred = torch.max(y_prob, dim=-1)
index = y_true != 0
correct = y_true[index] == y_pred[index]
return correct.int().sum(), index.int().sum()
# Class to run and store training statistics we can monitor for model performance
class Experiment():
def __init__(self, model, optimizer, path = './saved/tmp/'):
self.model = model
self.optimizer = optimizer
self.path = path
Path(path+'models/').mkdir(parents=True, exist_ok=True)
Path(path+'metrics/').mkdir(parents=True, exist_ok=True)
Path(path+'figs/').mkdir(parents=True, exist_ok=True)
self.trainLoss = []
self.trainAcc = []
self.trainAUC = []
self.trainLossBatch = []
self.trainAccBatch = []
self.betaBatch = []
self.testLoss = []
self.testAcc = []
self.testAUC = []
self.testLossBatch = []
self.testAccBatch = []
def train(self, dataloader, epoch, cyclic = False, device = 'cuda'):
self.model.train()
total_data = len(dataloader.dataset)
total_batches = len(dataloader)
batch_idx = epoch * total_batches
if cyclic:
batches_cycle = int(total_batches / 3)
batches_half_cycle = int(batches_cycle / 2)
betas = np.ones(batches_cycle)
betas[:batches_half_cycle] = np.linspace(0, 1, num=batches_half_cycle)
atom_correct = atom_total = edge_correct = edge_total = 0.
atom_auc = edge_auc = 0.
kl_loss = atom_loss = edge_loss = 0.
beta = 1
for data in dataloader:
graph = data[0].to(device)
atom_true = data[1].to(device)
edge_true = data[2].to(device)
batch_size = atom_true.shape[0]
# Core Training
self.optimizer.zero_grad()
mu, logv, atom_probs, edge_probs = self.model(graph)
_kl_loss = get_vae_loss(mu, logv)
aloss = eloss = 0.
for atom_prob, edge_prob in zip(atom_probs, edge_probs):
_atom_loss, _edge_loss = get_recon_loss(atom_true, atom_prob, edge_true, edge_prob)
aloss += _atom_loss
eloss += _edge_loss
C = len(atom_probs)
_atom_loss = aloss / C
_edge_loss = eloss / C
_recon_loss = _atom_loss + _edge_loss
if cyclic:
beta_idx = batch_idx % batches_cycle
beta = betas[beta_idx]
loss = _kl_loss * beta + _recon_loss
loss.backward()
self.optimizer.step()
batch_idx += 1
with torch.no_grad():
self.trainLossBatch.append([_kl_loss, _atom_loss, _edge_loss])
self.betaBatch.append(beta)
kl_loss += _kl_loss.item() * batch_size
atom_loss += _atom_loss.item() * batch_size
edge_loss += _edge_loss.item() * batch_size
_atom_correct = _edge_correct = _atom_auc = _edge_auc = 0.
C = len(atom_probs)
for atom_prob, edge_prob in zip(atom_probs, edge_probs):
acorrect, _atom_total = get_accuracy(atom_true, atom_prob)
ecorrect, _edge_total = get_accuracy(edge_true, edge_prob)
aauc = get_multi_auc(atom_true.cpu(), atom_prob.cpu(), atom_prob.shape[-1]) * batch_size
eauc = get_multi_auc(edge_true.cpu(), edge_prob.cpu(), edge_prob.shape[-1]) * batch_size
_atom_correct += acorrect
_edge_correct += ecorrect
_atom_auc += aauc
_edge_auc += eauc
atom_auc += _atom_auc.item() / C
edge_auc += _edge_auc.item() / C
atom_correct += _atom_correct.item() / C
atom_total += _atom_total.item()
edge_correct += _edge_correct.item() / C
edge_total += _edge_total.item()
self.trainAccBatch.append([_atom_correct/C/_atom_total, _edge_correct/C/_edge_total])
self.trainLoss.append([kl_loss/total_data, atom_loss/total_data, edge_loss/total_data])
self.trainAcc.append([atom_correct/atom_total, edge_correct/edge_total])
self.trainAUC.append([atom_auc/total_data, edge_auc/total_data])
def test(self, dataloader, epoch, num_samples = 1, device = 'cuda'):
self.model.eval()
with torch.no_grad():
kl_loss = 0.
atom_loss = 0.
edge_loss = 0.
total_data = len(dataloader.dataset)
total_batches = len(dataloader)
batch_idx = epoch * total_batches
atom_correct = atom_total = edge_correct = edge_total = 0.
atom_auc = edge_auc = 0.
for data in dataloader:
graph = data[0].to(device)
atom_true = data[1].to(device)
edge_true = data[2].to(device)
batch_size = atom_true.shape[0]
mu, logv, atom_probs, edge_probs = self.model(graph, num_samples)
_kl_loss = get_vae_loss(mu, logv)
aloss = eloss = 0.
for atom_prob, edge_prob in zip(atom_probs, edge_probs):
_atom_loss, _edge_loss = get_recon_loss(atom_true, atom_prob, edge_true, edge_prob)
aloss += _atom_loss
eloss += _edge_loss
C = len(atom_probs)
_atom_loss = aloss / C
_edge_loss = eloss / C
kl_loss += _kl_loss.item() * batch_size
atom_loss += _atom_loss.item() * batch_size
edge_loss += _edge_loss.item() * batch_size
_atom_correct = _edge_correct = _atom_auc = _edge_auc = 0.
C = len(atom_probs)
for atom_prob, edge_prob in zip(atom_probs, edge_probs):
acorrect, _atom_total = get_accuracy(atom_true, atom_prob)
ecorrect, _edge_total = get_accuracy(edge_true, edge_prob)
aauc = get_multi_auc(atom_true.cpu(), atom_prob.cpu(), atom_prob.shape[-1]) * batch_size
eauc = get_multi_auc(edge_true.cpu(), edge_prob.cpu(), edge_prob.shape[-1]) * batch_size
_atom_correct += acorrect
_edge_correct += ecorrect
_atom_auc += aauc
_edge_auc += eauc
atom_auc += _atom_auc.item() / C
edge_auc += _edge_auc.item() / C
atom_correct += _atom_correct.item() / C
atom_total += _atom_total.item()
edge_correct += _edge_correct.item() / C
edge_total += _edge_total.item()
self.testLossBatch.append([_kl_loss, _atom_loss, _edge_loss])
self.testAccBatch.append([_atom_correct/_atom_total, _edge_correct/_edge_total])
self.testLoss.append([kl_loss/total_data, atom_loss/total_data, edge_loss/total_data])
self.testAcc.append([atom_correct/atom_total, edge_correct/edge_total])
self.testAUC.append([atom_auc/total_data, edge_auc/total_data])
def save_state(self, epoch):
metrics = {}
trainLoss = np.array(self.trainLoss)
trainAcc = np.array(self.trainAcc)
trainAUC = np.array(self.trainAUC)
testLoss = np.array(self.testLoss)
testAcc = np.array(self.testAcc)
testAUC = np.array(self.testAUC)
metrics['Train KL Loss'] = trainLoss[:, 0]
metrics['Train Atom Loss'] = trainLoss[:, 1]
metrics['Train Edge Loss'] = trainLoss[:, 2]
metrics['Train Atom Acc'] = trainAcc[:, 0]
metrics['Train Edge Acc'] = trainAcc[:, 1]
metrics['Train Atom AUC'] = trainAUC[:, 0]
metrics['Train Edge AUC'] = trainAUC[:, 1]
metrics['Test KL Loss'] = testLoss[:, 0]
metrics['Test Atom Loss'] = testLoss[:, 1]
metrics['Test Edge Loss'] = testLoss[:, 2]
metrics['Test Atom Acc'] = testAcc[:, 0]
metrics['Test Edge Acc'] = testAcc[:, 1]
metrics['Test Atom AUC'] = testAUC[:, 0]
metrics['Test Edge AUC'] = testAUC[:, 1]
torch.save(metrics, self.path + 'metrics/epoch_{}.metric'.format(epoch))
batch_metrics = {}
trainLossBatch = np.array(self.trainLossBatch)
trainAccBatch = np.array(self.trainAccBatch)
testLossBatch = np.array(self.testLossBatch)
testAccBatch = np.array(self.testAccBatch)
batch_metrics['Train KL Loss'] = trainLossBatch[:, 0]
batch_metrics['Train Atom Loss'] = trainLossBatch[:, 1]
batch_metrics['Train Edge Loss'] = trainLossBatch[:, 2]
batch_metrics['Train Atom Acc'] = trainAccBatch[:, 0]
batch_metrics['Train Edge Acc'] = trainAccBatch[:, 1]
batch_metrics['Test KL Loss'] = testLossBatch[:, 0]
batch_metrics['Test Atom Loss'] = testLossBatch[:, 1]
batch_metrics['Test Edge Loss'] = testLossBatch[:, 2]
batch_metrics['Test Atom Acc'] = testAccBatch[:, 0]
batch_metrics['Test Edge Acc'] = testAccBatch[:, 1]
torch.save(batch_metrics, self.path + 'metrics/batches.metric')
torch.save(self.model.state_dict(), self.path + 'models/gvae_{}.model'.format(epoch))
We can use many tricks from the VAE literature to improve performance and obtain better results. Cyclic Annealing Schedule for VAE training has been show to significantly improve performance on autoregressive NLP tasks such as language modeling, dialog response generation etc. We can train the model using cyclic annealing schedule by setting the flag cyclic=True
.
We use the mean of 40 samples during testing to calculate the loss and accuracy on the test set.
model = VAE(MAX_N, MAX_ATOM, MAX_EDGE, 13)
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
epochs = 100
device = 'cuda'
model = model.to(device)
path = './saved/test/'
experiment = Experiment(model, optimizer, path)
for epoch in range(epochs):
experiment.train(trainloader, epoch, cyclic=False)
experiment.test(testloader, epoch, 40)
experiment.save_state(epoch)
Results
We train the model on the dataset of 500 molecules and test on a dataset of 10,000 molecules for 100 epochs. The small size of the trainset leads to a large generalization gap and the model overfits on trainset with the atom and edge cross entropy loss decreasing but a flat test loss. The test atom classification accuracy converges at around 80% while the edge classification accuracy converges at around 75%. We only consider the labels with non-zero entry in the ground truth to calculate accuracy.
Now we consider metrics to check the generative performance for molecules. We sample $n_s = 10,000$ latent vectors $\textbf{z}$ from the prior $\mathcal{N}(0, I)$ and use the trained decoder to generate new molecules. We use the following three metrics to test the quality of the learned latent space:
- Validity: Let list $|V|$ represent chemically valid molecules, validity is the ratio $|V|/n_s$
- Uniqueness: The set of chemically valid molecules $U = set(V)$, uniqueness is the ratio $|U|/|V|$
- Novelty: Novelty is the set of elements in $U$ not in the trainset $\mathcal{D}_{train}$ defined as $1 - (U \cap \mathcal{D}_{train})/|U|$
We obtain the following results from our trained model:
Validity | Uniqueness | Novelty |
---|---|---|
62.33% | 76.40% | 66.71% |
Even the simple toy example can generate a substantial number of unique and novel valid molecules. This is without any graph matching in the loss function so the model can learn to predict nodes in arbritary order and is restricted by the order of nodes in the trainset. Also these results are obtained from a very small subset of data and is comparable to GraphAF SOTA method without validity checks during training (67% in Table 3).
The model predicts a lot of disconnected graphs and molecules with invalid chemical valency rules
Conclusion
The GraphVAE presents a simple model to learn a generative model over graphs and gives a reasonable performance in modelling the structure behind molecules. This gives us a simple but suprisingly effective baseline for structured generation. Several work has improved upon this baseline using reinforcement learning and flows for goal directed generation of molecules. Both of these methods use an oracle during autoregressive generation to prune chemically invalid additions to the sub-graph. Going forward, we will look at implementing the SOTA methods for improved baseline. We will also explore ways to develop algorithms for structured data without depending on an oracle but learn the underlying rules implicitly from the data.