Despite some of their technical issues, graph neural networks (GNNs) are quickly being adopted as one of the state-of-the-art methods for molecular property prediction. The differentiable extraction of molecular features from low-level molecular graphs has become a viable (although not always superior) alternative to classical molecular representation techniques such as Morgan fingerprints and molecular descriptor vectors.
But molecular data usually comes in the sequential form of labeled SMILES strings, and it is not obvious to beginners how best to transform a SMILES string into a structured molecular graph object that can be used as input for a GNN. In this post, we show how to convert a SMILES string into a molecular graph object that can subsequently be used for graph-based machine learning. We do so within the framework of Pytorch Geometric, which is currently one of the most capable and widely used Python-based GNN libraries.
We divide our task into three high-level steps:
1. We define a function that maps an RDKit atom object to a suitable atom feature vector.
2. We define a function that maps an RDKit bond object to a suitable bond feature vector.
3. We define a function that takes as its input a list of SMILES strings and associated labels and then uses the functions from steps 1 and 2 to create a list of labeled Pytorch Geometric graph objects as its output.
Step 0: Import Packages
As always, we first import the necessary Python packages for our endeavour:
```python
# import packages

# general tools
import numpy as np

# RDkit
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

# Pytorch and Pytorch Geometric
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
```
Step 1: Atom Featurisation
We start by defining an auxiliary function which transforms a value x into a one-hot encoding based on a list of permitted values for x:
```python
def one_hot_encoding(x, permitted_list):
    """
    Maps input elements x which are not in the permitted list to the last element of the permitted list.
    """
    if x not in permitted_list:
        x = permitted_list[-1]

    binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]

    return binary_encoding
```
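As a quick illustration of its behaviour (the short permitted list below is purely an example, not one used by the featurisation functions later on):

```python
# values in the permitted list are mapped to the corresponding one-hot position
print(one_hot_encoding("N", ["C", "N", "O", "Unknown"]))   # [0, 1, 0, 0]

# values outside the permitted list are mapped to its last element ("Unknown")
print(one_hot_encoding("Zn", ["C", "N", "O", "Unknown"]))  # [0, 0, 0, 1]
```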
Now we use this auxiliary function to define the actual atom featurisation function:
```python
def get_atom_features(atom, use_chirality = True, hydrogens_implicit = True):
    """
    Takes an RDKit atom object as input and gives a 1d-numpy array of atom features as output.
    """
    # define list of permitted atoms
    permitted_list_of_atoms = ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I',
                               'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn',
                               'Li','Ge','Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown']

    if hydrogens_implicit == False:
        permitted_list_of_atoms = ['H'] + permitted_list_of_atoms

    # compute atom features
    atom_type_enc = one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)

    n_heavy_neighbors_enc = one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])

    formal_charge_enc = one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])

    hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])

    is_in_a_ring_enc = [int(atom.IsInRing())]

    is_aromatic_enc = [int(atom.GetIsAromatic())]

    atomic_mass_scaled = [float((atom.GetMass() - 10.812)/116.092)]

    vdw_radius_scaled = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5)/0.6)]

    covalent_radius_scaled = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64)/0.76)]

    atom_feature_vector = atom_type_enc + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_scaled + covalent_radius_scaled

    if use_chirality == True:
        chirality_type_enc = one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
        atom_feature_vector += chirality_type_enc

    if hydrogens_implicit == True:
        n_hydrogens_enc = one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])
        atom_feature_vector += n_hydrogens_enc

    return np.array(atom_feature_vector)
```
To encapsulate as much information as possible within the molecular graph, we include a rich set of atomic features: atom type, number of heavy-atom neighbours, formal charge, hybridisation type, whether the atom is in a ring, whether the atom is aromatic, atomic mass, van der Waals radius, and covalent radius. The last three features are numerical rather than categorical, so instead of one-hot encoding them the function shifts and scales them to a comparable numerical range using empirically estimated constants. The user can explicitly specify whether to include chirality as a stereochemical feature and whether to treat hydrogen atoms implicitly or explicitly.
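As a quick sanity check, we can featurise a single atom of a small example molecule (ethanol, chosen purely for illustration) and inspect the dimensionality of the resulting feature vector:

```python
# featurise the first atom of an example molecule (ethanol)
example_mol = Chem.MolFromSmiles("CCO")
example_atom = example_mol.GetAtomWithIdx(0)

example_atom_features = get_atom_features(example_atom)
print(example_atom_features.shape)  # (79,) with the default settings: 43 + 6 + 8 + 7 + 1 + 1 + 3 + 4 + 6
```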
Step 2: Bond Featurisation
Now that we have constructed a function to conveniently turn RDKit atom objects into feature vectors, we define an analogous function for RDKit bond objects:
```python
def get_bond_features(bond, use_stereochemistry = True):
    """
    Takes an RDKit bond object as input and gives a 1d-numpy array of bond features as output.
    """
    permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]

    bond_type_enc = one_hot_encoding(bond.GetBondType(), permitted_list_of_bond_types)

    bond_is_conj_enc = [int(bond.GetIsConjugated())]

    bond_is_in_ring_enc = [int(bond.IsInRing())]

    bond_feature_vector = bond_type_enc + bond_is_conj_enc + bond_is_in_ring_enc

    if use_stereochemistry == True:
        stereo_type_enc = one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
        bond_feature_vector += stereo_type_enc

    return np.array(bond_feature_vector)
```
The bond features we consider in the above function are: bond type, whether the bond is conjugated, and whether the bond is in a ring. As an additional option, the user can specify whether to include E-Z stereochemical features around double bonds.
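Analogously to the atom case, we can run the bond featuriser on a single bond of a small example molecule (again ethanol, purely for illustration):

```python
# featurise the bond between the two carbon atoms of ethanol
example_mol = Chem.MolFromSmiles("CCO")
example_bond = example_mol.GetBondBetweenAtoms(0, 1)

example_bond_features = get_bond_features(example_bond)
print(example_bond_features.shape)  # (10,) with the default settings: 4 + 1 + 1 + 4
```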
Step 3: Generating Labeled Pytorch Geometric Graph Objects
Equipped with suitable functions to turn RDKit atom objects and RDKit bond objects into informative feature vectors, we swiftly move on to define a function which turns a list of SMILES strings and an associated list of labels (such as pKi values) into a list of Pytorch Geometric graph objects:
```python
def create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y):
    """
    Inputs:

    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings
    y = [y_1, y_2, ...] ... a list of numerical labels for the SMILES strings (such as associated pKi values)

    Outputs:

    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent labeled molecular graphs that can readily be used for machine learning
    """
    data_list = []

    for (smiles, y_val) in zip(x_smiles, y):

        # convert SMILES to RDKit mol object
        mol = Chem.MolFromSmiles(smiles)

        # get feature dimensions
        n_nodes = mol.GetNumAtoms()
        n_edges = 2*mol.GetNumBonds()

        unrelated_smiles = "O=O"
        unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
        n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
        n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((n_nodes, n_node_features))

        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)

        X = torch.tensor(X, dtype = torch.float)

        # construct edge index array E of shape (2, n_edges)
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
        torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
        torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
        E = torch.stack([torch_rows, torch_cols], dim = 0)

        # construct edge feature array EF of shape (n_edges, n_edge_features)
        EF = np.zeros((n_edges, n_edge_features))

        for (k, (i,j)) in enumerate(zip(rows, cols)):
            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))

        EF = torch.tensor(EF, dtype = torch.float)

        # construct label tensor
        y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)

        # construct Pytorch Geometric data object and append to data list
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

    return data_list
```
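Before moving on to training, it is worth briefly checking the output of this function on a handful of molecules; the SMILES strings and labels below are purely illustrative toy data:

```python
# toy example: three molecules with made-up labels
x_smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
y = [5.2, 6.1, 4.8]

data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y)

first_graph = data_list[0]
print(first_graph.num_nodes)        # number of heavy atoms in the first molecule (3 for ethanol)
print(first_graph.num_edges)        # number of directed edges (twice the number of bonds, i.e. 4)
print(first_graph.x.shape)          # node feature matrix of shape (n_nodes, n_node_features)
print(first_graph.edge_attr.shape)  # edge feature matrix of shape (n_edges, n_edge_features)
```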
Training Loop and Summary
In this post, we learned how to turn RDKit atoms and RDKit bonds into meaningful feature vectors and how to use these feature vectors to create molecular graph objects that can be used as inputs to Pytorch Geometric GNNs. A canonical GNN training loop could now look like this:
```python
# canonical training loop for a Pytorch Geometric GNN model gnn_model

# create list of molecular graph objects from list of SMILES x_smiles and list of labels y
data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y)

# create dataloader for training (the Pytorch Geometric DataLoader handles the batching of graph objects)
dataloader = DataLoader(dataset = data_list, batch_size = 2**7)

# define loss function
loss_function = nn.MSELoss()

# define optimiser
optimiser = torch.optim.Adam(gnn_model.parameters(), lr = 1e-3)

# loop over 10 training epochs
for epoch in range(10):

    # set model to training mode
    gnn_model.train()

    # loop over minibatches for training
    for (k, batch) in enumerate(dataloader):

        # compute current value of loss function via forward pass
        output = gnn_model(batch)
        loss_function_value = loss_function(output[:,0], batch.y.float())

        # set past gradient to zero
        optimiser.zero_grad()

        # compute current gradient via backward pass
        loss_function_value.backward()

        # update model weights using gradient and optimisation method
        optimiser.step()
```
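Note that the loop above assumes a graph-level regression model gnn_model has already been defined. For completeness, here is one minimal, purely illustrative way such a model could look, using a simple GCN with mean pooling; the class name ExampleGNN is our own invention, and since GCNConv ignores edge features, a layer type such as GINEConv or NNConv would be needed to actually exploit edge_attr:

```python
from torch_geometric.nn import GCNConv, global_mean_pool

class ExampleGNN(nn.Module):
    """Minimal illustrative GNN for graph-level regression."""

    def __init__(self, n_node_features, hidden_dim = 64):
        super().__init__()
        self.conv1 = GCNConv(n_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.readout = nn.Linear(hidden_dim, 1)

    def forward(self, batch):
        # two rounds of message passing over the molecular graph
        h = self.conv1(batch.x, batch.edge_index).relu()
        h = self.conv2(h, batch.edge_index).relu()

        # pool node embeddings into one embedding per molecule
        h = global_mean_pool(h, batch.batch)

        # map each graph embedding to a scalar prediction of shape (batch_size, 1)
        return self.readout(h)

# 79 node features correspond to the default atom featurisation settings above
gnn_model = ExampleGNN(n_node_features = 79)
```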
Happy training!