import numpy as np
from time import time
import h5py
import os
# torch import
import torch
import torch.nn as nn
from torch.nn import MSELoss
import torch.nn.functional as F
from torch_geometric.data import DataLoader
# deeprank_gnn import
from .DataSet import HDF5DataSet, DivideDataSet, PreCluster
from .Metrics import Metrics
class NeuralNet(object):
    """Train, evaluate and test a graph neural network on HDF5 graph datasets."""

    # Targets for which the task can be inferred when `task` is not given.
    _REGRESSION_TARGETS = ('irmsd', 'lrmsd', 'fnat', 'dockQ')
    # 'capri_class' is the documented spelling; 'capri_classes' is kept so
    # that code relying on the previous check keeps working.
    _CLASSIFICATION_TARGETS = ('bin_class', 'capri_class', 'capri_classes')
    # Node clustering methods accepted for `cluster_nodes` (besides None).
    _CLUSTERING_METHODS = ('mcl', 'louvain')

    def __init__(self, database, Net,
                 node_feature=['type', 'polarity', 'bsa'],
                 edge_feature=['dist'], target='irmsd', lr=0.01,
                 batch_size=32, percent=[1.0, 0.0],
                 database_eval=None, index=None, class_weights=None, task=None,
                 classes=[0, 1], threshold=None,
                 pretrained_model=None, shuffle=True, outdir='./',
                 cluster_nodes='mcl', transform_sigmoid=False):
        """Class from which the network is trained, evaluated and tested.

        Args:
            database (str or list, required): path(s) to hdf5 dataset(s). Unique hdf5 file
                or list of hdf5 files.
            Net (class, required): neural network class (ex. GINet, Foutnet etc.). It is
                instantiated as Net(num_node_features, output_shape, num_edge_features).
            node_feature (list, optional): type, charge, polarity, bsa (buried surface area),
                pssm, cons (pssm conservation information), ic (pssm information content),
                depth, hse (half sphere exposure). Defaults to ['type', 'polarity', 'bsa'].
            edge_feature (list, optional): dist (distance). Defaults to ['dist'].
            target (str, optional): irmsd, lrmsd, fnat, capri_class, bin_class, dockQ.
                Defaults to 'irmsd'.
            lr (float, optional): learning rate. Defaults to 0.01.
            batch_size (int, optional): defaults to 32.
            percent (list, optional): divides the input dataset into a training and an
                evaluation set. Defaults to [1.0, 0.0].
            database_eval (str or list, optional): independent evaluation set. Defaults to None.
            index (list, optional): index of the molecules to consider. Defaults to None.
            class_weights (list or bool, optional): weights provided to the cross entropy
                loss function. Either one weight per class, or True to let DeepRank-GNN
                derive the weights from the training-set class frequencies. Defaults to None.
            task (str, optional): 'reg' for regression or 'class' for classification.
                Defaults to None, in which case it is inferred from `target`.
            classes (list, optional): define the dataset target classes in classification
                mode. Defaults to [0, 1].
            threshold (float, optional): threshold used to compute binary classification
                metrics. Defaults to None, which selects classes[1] for classification and
                0.3 for regression.
            pretrained_model (str, optional): path to pre-trained model. When given, every
                hyper-parameter except `outdir` is read from the checkpoint. Defaults to None.
            shuffle (bool, optional): shuffle the training (and validation) loaders.
                Defaults to True.
            outdir (str, optional): output directory for the hdf5 result files. Defaults to './'.
            cluster_nodes (str or None, optional): node clustering algorithm, 'mcl' or
                'louvain', or None to disable clustering. Defaults to 'mcl'.
            transform_sigmoid (bool, optional): stored and saved with the model; no method in
                this class reads it directly (presumably format_output does — NOTE(review):
                confirm). Defaults to False.

        Raises:
            ValueError: the task cannot be inferred from a custom target, or
                `cluster_nodes` is not a supported method.
        """
        # NOTE: the list defaults are shared between calls; this class never mutates them.
        if pretrained_model is None:
            # Store each named argument as a member variable (self.node_feature = node_feature, ...).
            # locals() is snapshotted before the loop variables exist.
            for name, value in dict(locals()).items():
                if name not in ('self', 'database', 'Net', 'database_eval'):
                    setattr(self, name, value)

            if self.task is None:
                if self.target in self._REGRESSION_TARGETS:
                    self.task = 'reg'
                elif self.target in self._CLASSIFICATION_TARGETS:
                    self.task = 'class'
                else:
                    raise ValueError(
                        f"User target detected -> The task argument is required ('class' or 'reg'). \n\t"
                        f"Example: \n\t"
                        f""
                        f"model = NeuralNet(database, GINet,"
                        f" target='physiological_assembly',"
                        f" task='class',"
                        f" shuffle=True,"
                        f" percent=[0.8, 0.2])")

            if self.threshold is None:
                if self.task == 'class':
                    print('the threshold for accuracy computation is set to {}'.format(self.classes[1]))
                    self.threshold = self.classes[1]
                elif self.task == 'reg':
                    print('the threshold for accuracy computation is set to 0.3')
                    self.threshold = 0.3

            self.load_model(database, Net, database_eval)
        else:
            self.load_params(pretrained_model)
            self.outdir = outdir
            self.load_pretrained_model(database, Net)

    def _precluster(self, dataset, message=None):
        """Pre-compute node clusters for `dataset` using `self.cluster_nodes`.

        Does nothing when `cluster_nodes` is None.

        Args:
            dataset (HDF5DataSet): dataset to cluster.
            message (str, optional): printed before clustering starts.

        Raises:
            ValueError: `cluster_nodes` is neither None nor a supported method.
        """
        if self.cluster_nodes is None:
            return
        if self.cluster_nodes not in self._CLUSTERING_METHODS:
            raise ValueError(
                "Invalid node clustering method. \n\t"
                "Please set cluster_nodes to 'mcl', 'louvain' or None. Default to 'mcl' \n\t")
        if message is not None:
            print(message)
        PreCluster(dataset, method=self.cluster_nodes)

    def load_pretrained_model(self, database, Net):
        """Load the test set and restore a pretrained model and its optimizer.

        Expects load_params() to have run first: it sets the hyper-parameters
        and the saved model/optimizer state dictionaries used here.

        Args:
            database (str): path to hdf5 file(s)
            Net (class): neural network
        """
        test_dataset = HDF5DataSet(root='./', database=database,
                                   node_feature=self.node_feature, edge_feature=self.edge_feature,
                                   target=self.target, clustering_method=self.cluster_nodes)
        self.test_loader = DataLoader(test_dataset)
        # Cluster only when a method is configured, as load_model does.
        self._precluster(test_dataset)
        print('Test set loaded')

        self.put_model_to_device(test_dataset, Net)
        self.set_loss()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # Restore the saved optimizer and model states.
        self.optimizer.load_state_dict(self.opt_loaded_state_dict)
        self.model.load_state_dict(self.model_load_state_dict)

    def load_model(self, database, Net, database_eval):
        """Load the training/evaluation data and build a fresh model.

        Args:
            database (str): path to the hdf5 file(s) of the training set
            Net (class): neural network
            database_eval (str): path to the hdf5 file(s) of an independent evaluation
                set; when given it replaces the split-off evaluation set.

        Raises:
            ValueError: Invalid node clustering method.
        """
        dataset = HDF5DataSet(root='./', database=database, index=self.index,
                              node_feature=self.node_feature, edge_feature=self.edge_feature,
                              target=self.target, clustering_method=self.cluster_nodes)
        self._precluster(dataset, message="Loading clusters")

        # Split into training / evaluation sets according to self.percent.
        train_dataset, valid_dataset = DivideDataSet(dataset, percent=self.percent)

        self.train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=self.shuffle)
        print('Training set loaded')

        if self.percent[1] > 0.0:
            self.valid_loader = DataLoader(
                valid_dataset, batch_size=self.batch_size, shuffle=self.shuffle)
            print('Evaluation set loaded')

        if database_eval is not None:
            print('Loading independent evaluation dataset...')
            valid_dataset = HDF5DataSet(root='./', database=database_eval, index=self.index,
                                        node_feature=self.node_feature, edge_feature=self.edge_feature,
                                        target=self.target, clustering_method=self.cluster_nodes)
            self._precluster(valid_dataset, message='Loading clusters for the evaluation set.')
            self.valid_loader = DataLoader(
                valid_dataset, batch_size=self.batch_size, shuffle=self.shuffle)
            print('Independent validation set loaded !')
        else:
            print('No independent validation set loaded')

        self.put_model_to_device(dataset, Net)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.set_loss()

        # Per-epoch histories filled by train().
        self.train_acc = []
        self.train_loss = []
        self.valid_acc = []
        self.valid_loss = []

    def put_model_to_device(self, dataset, Net):
        """Instantiate `Net` and move it to the available device (CUDA if present).

        Args:
            dataset (HDF5DataSet): dataset whose first graph gives the number of node features
            Net (class): neural network, called as Net(num_node_features, output_shape,
                num_edge_features)

        Raises:
            ValueError: `Net` does not accept the output shape argument, or the task is
                neither 'reg' nor 'class'.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('device set to :', self.device)
        if self.device.type == 'cuda':
            print(torch.cuda.get_device_name(0))

        self.num_edge_features = len(self.edge_feature)
        num_node_features = dataset.get(0).num_features

        if self.task == 'reg':
            # Regression: a single output value per graph.
            self.model = Net(num_node_features, 1, self.num_edge_features).to(self.device)
        elif self.task == 'class':
            # Classification: one output per class; targets are handled as class indices.
            self.classes_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
            self.idx_to_classes = {idx: cls for idx, cls in enumerate(self.classes)}
            self.output_shape = len(self.classes)
            try:
                self.model = Net(num_node_features, self.output_shape,
                                 self.num_edge_features).to(self.device)
            except TypeError as err:
                # Raised when Net's constructor does not take an output_shape argument.
                raise ValueError(
                    f"The loaded model does not accept output_shape = {self.output_shape} argument \n\t"
                    f"Check your input or adapt the model\n\t"
                    f"Example :\n\t"
                    f"def __init__(self, input_shape): --> def __init__(self, input_shape, output_shape) \n\t"
                    f"self.fc2 = torch.nn.Linear(64, 1) --> self.fc2 = torch.nn.Linear(64, output_shape) \n\t"
                ) from err
        else:
            raise ValueError(
                "Unknown task {!r}: expected 'reg' or 'class'".format(self.task))

    def set_loss(self):
        """Set self.loss: MSE loss for regression, CrossEntropy loss for classification.

        In classification mode the cross-entropy weights come from `self.class_weights`:
          * a list/tuple/array/tensor: used as given, one weight per class;
          * True: computed from the training set by compute_class_weights();
          * anything else (None, False): unweighted loss.
        The weights are moved to self.device so they live where the predictions do.

        Raises:
            ValueError: explicit weights whose length differs from len(self.classes).
        """
        if self.task == 'reg':
            self.loss = MSELoss()
        elif self.task == 'class':
            self.weights = self._class_weight_tensor()
            self.loss = nn.CrossEntropyLoss(weight=self.weights, reduction='mean')

    def _class_weight_tensor(self):
        """Resolve self.class_weights into a float32 tensor on self.device, or None."""
        requested = self.class_weights
        if isinstance(requested, (list, tuple, np.ndarray, torch.Tensor)):
            weights = torch.as_tensor(requested, dtype=torch.float32)
            if weights.numel() != len(self.classes):
                raise ValueError(
                    'class_weights has {} entries but there are {} classes'.format(
                        weights.numel(), len(self.classes)))
        elif requested == True:  # noqa: E712 -- same truthiness rule as before
            if getattr(self, 'train_loader', None) is None:
                # e.g. pretrained-model mode: no training set to count classes on.
                print('WARNING: no training set available to compute class weights; '
                      'using an unweighted loss.')
                return None
            weights = self.compute_class_weights()
        else:
            return None
        return weights.to(self.device)

    def _checkpoint_name(self, nepoch, epoch=None):
        """Build the checkpoint file name; the epoch suffix is used for 'best' models."""
        name = 't{}_y{}_b{}_e{}_lr{}'.format(
            self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr))
        if epoch is not None:
            name += '_{}'.format(epoch)
        return name + '.pth.tar'

    def train(self, nepoch=1, validate=False, save_model='last', hdf5='train_data.hdf5',
              save_epoch='intermediate', save_every=5):
        """Train the model.

        Args:
            nepoch (int, optional): number of epochs. Defaults to 1.
            validate (bool, optional): perform validation on self.valid_loader after
                each epoch. Defaults to False.
            save_model (str, optional): 'last' saves the model after the final epoch,
                'best' saves it whenever the validation loss (or the training loss when
                not validating) reaches a new minimum. Defaults to 'last'.
            hdf5 (str, optional): hdf5 output file, created in self.outdir without
                overwriting an existing file.
            save_epoch (str, optional): 'all' exports every epoch, 'intermediate' every
                `save_every` epochs; the last epoch is always exported.
            save_every (int, optional): save data every n epoch if
                save_epoch == 'intermediate'. Defaults to 5.

        Raises:
            ValueError: validate is True but no evaluation set was loaded.
        """
        # Fail before creating the output file rather than after the first epoch.
        if validate is True and getattr(self, 'valid_loader', None) is None:
            raise ValueError(
                "validate=True requires an evaluation set: use percent[1] > 0 "
                "or provide database_eval")

        fname = self.update_name(hdf5, self.outdir)

        with h5py.File(fname, 'w') as self.f5:
            self.nepoch = nepoch
            self.data = {}

            for epoch in range(1, nepoch + 1):
                # Training pass.
                self.model.train()
                t0 = time()
                _out, _y, _loss, self.data['train'] = self._epoch(epoch)
                elapsed = time() - t0
                self.train_loss.append(_loss)
                self.train_out = _out
                self.train_y = _y
                _acc = self.get_metrics('train', self.threshold).accuracy
                self.train_acc.append(_acc)
                self.print_epoch_data('train', epoch, _loss, _acc, elapsed)

                if validate is True:
                    t0 = time()
                    _out, _y, _val_loss, self.data['eval'] = self.eval(self.valid_loader)
                    elapsed = time() - t0
                    self.valid_loss.append(_val_loss)
                    self.valid_out = _out
                    self.valid_y = _y
                    _val_acc = self.get_metrics('eval', self.threshold).accuracy
                    self.valid_acc.append(_val_acc)
                    self.print_epoch_data('valid', epoch, _val_loss, _val_acc, elapsed)

                    # Best model = lowest loss on the validation data so far.
                    if save_model == 'best' and min(self.valid_loss) == _val_loss:
                        self.save_model(filename=self._checkpoint_name(nepoch, epoch))

                elif save_model == 'best' and min(self.train_loss) == _loss:
                    # No validation set: select on the training loss instead.
                    print('WARNING: The training set is used both for learning and model selection.')
                    print('this may lead to training set data overfitting.')
                    print('We advice you to use an external validation set.')
                    self.save_model(filename=self._checkpoint_name(nepoch, epoch))

                # Export epoch data.
                if save_epoch == 'all' or epoch == nepoch:
                    self._export_epoch_hdf5(epoch, self.data)
                elif save_epoch == 'intermediate' and epoch % save_every == 0:
                    self._export_epoch_hdf5(epoch, self.data)

            if save_model == 'last':
                self.save_model(filename=self._checkpoint_name(nepoch))

    def test(self, database_test=None, threshold=4, hdf5='test_data.hdf5'):
        """Test the model and export the predictions to an hdf5 file.

        Args:
            database_test (str, optional): test database. When omitted, the loader set
                up by a pretrained model (or a previous test call) is reused.
            threshold (int, optional): threshold used to transform data into binary
                values. Defaults to 4.
            hdf5 (str, optional): output hdf5 file. Defaults to 'test_data.hdf5'.

        Raises:
            ValueError: no test dataset is given and none has been loaded before.
        """
        fname = self.update_name(hdf5, self.outdir)

        if database_test is not None:
            test_dataset = HDF5DataSet(root='./', database=database_test,
                                       node_feature=self.node_feature, edge_feature=self.edge_feature,
                                       target=self.target, clustering_method=self.cluster_nodes)
            print('Test set loaded')
            # Use the same clustering method as the model rather than always 'mcl'.
            self._precluster(test_dataset)
            self.test_loader = DataLoader(test_dataset)
        elif getattr(self, 'test_loader', None) is None:
            raise ValueError(
                "You need to upload a test dataset \n\t"
                "\n\t"
                ">> model.test(test_dataset)\n\t"
                "if a pretrained network is loaded, you can directly test the model on the loaded dataset :\n\t"
                ">> model = NeuralNet(database_test, gnn, pretrained_model = model_saved, target=None)\n\t"
                ">> model.test()\n\t")

        with h5py.File(fname, 'w') as self.f5:
            self.data = {}
            _out, _y, _test_loss, self.data['test'] = self.eval(self.test_loader)
            self.test_out = _out

            # Without ground-truth targets no accuracy can be computed.
            if len(_y) == 0:
                self.test_y = None
                self.test_acc = None
            else:
                self.test_y = _y
                self.test_acc = self.get_metrics('test', threshold).accuracy

            self.test_loss = _test_loss
            self._export_epoch_hdf5(0, self.data)

    def _format_predictions(self, pred):
        """Convert a batch of network outputs into exportable Python lists.

        Returns:
            tuple: (raw_outputs, outputs). For classification: softmax probabilities
            and arg-max class indices. For regression: the flattened predictions,
            in both positions.
        """
        pred = pred.detach().cpu()
        if self.task == 'class':
            probs = F.softmax(pred, dim=1)
            return probs.tolist(), probs.argmax(dim=1).tolist()
        flat = pred.reshape(-1)
        return flat.tolist(), flat.tolist()

    def _pack_epoch_data(self, mols, y, out, raw_outputs):
        """Assemble the per-pass export dictionary (class indices mapped back to labels)."""
        if self.task == 'class':
            targets = [self.idx_to_classes[x] for x in y]
            outputs = [self.idx_to_classes[x] for x in out]
        else:
            targets, outputs = list(y), list(out)
        return {'outputs': outputs, 'raw_outputs': raw_outputs,
                'targets': targets, 'mol': mols}

    def eval(self, loader):
        """Evaluate the model on `loader` without tracking gradients.

        `self.format_output` is not defined in this view (NOTE(review): confirm where it
        lives); it appears to return (pred, targets) with class labels converted to
        indices in classification mode, and targets None when absent.

        Args:
            loader (DataLoader): data to evaluate.

        Returns:
            tuple: (outputs, targets, summed batch loss, export data dict). The loss
            and targets only cover batches that carry target values.
        """
        self.model.eval()
        loss_val = 0
        out, raw_outputs, y, mols = [], [], [], []

        with torch.no_grad():
            for data_batch in loader:
                data_batch = data_batch.to(self.device)
                pred = self.model(data_batch)
                pred, data_batch.y = self.format_output(pred, data_batch.y)
                pred = pred.to(self.device)

                # Targets are optional here (prediction on unlabelled data).
                if data_batch.y is not None:
                    data_batch.y = data_batch.y.to(self.device)
                    loss_val += self.loss(pred, data_batch.y).item()
                    y += data_batch.y.tolist()

                batch_raw, batch_out = self._format_predictions(pred)
                raw_outputs += batch_raw
                out += batch_out
                mols += data_batch['mol']

        data = self._pack_epoch_data(mols, y, out, raw_outputs)
        return out, y, loss_val, data

    def _epoch(self, epoch):
        """Run a single training epoch over self.train_loader.

        Returns:
            tuple: prediction, ground truth, running loss, export data dict

        Raises:
            ValueError: a training batch has no target values.
        """
        running_loss = 0
        out, raw_outputs, y, mols = [], [], [], []

        for data_batch in self.train_loader:
            data_batch = data_batch.to(self.device)
            self.optimizer.zero_grad()
            pred = self.model(data_batch)
            pred, data_batch.y = self.format_output(pred, data_batch.y)
            if data_batch.y is None:
                raise ValueError("You must provide target values (y) for the training set")
            pred = pred.to(self.device)
            data_batch.y = data_batch.y.to(self.device)

            loss = self.loss(pred, data_batch.y)
            running_loss += loss.detach().item()
            loss.backward()
            self.optimizer.step()

            y += data_batch.y.tolist()
            batch_raw, batch_out = self._format_predictions(pred)
            raw_outputs += batch_raw
            out += batch_out
            mols += data_batch['mol']

        data = self._pack_epoch_data(mols, y, out, raw_outputs)
        return out, y, running_loss, data

    def get_metrics(self, data='eval', threshold=4.0, binary=True):
        """Compute the metrics for the stored outputs of one stage.

        Args:
            data (str, optional): 'eval', 'train' or 'test'. Defaults to 'eval'.
            threshold (float, optional): threshold used to transform data into binary
                values; in classification mode it must be one of self.classes.
                Defaults to 4.0.
            binary (bool, optional): Transform data into binary data. Defaults to True.

        Returns:
            Metrics: metrics computed on the selected outputs/targets.

        Raises:
            ValueError: unknown `data`, no outputs for that stage, or no test targets.
        """
        if self.task == 'class':
            threshold = self.classes_to_idx[threshold]

        if data == 'eval':
            pred, y = self.valid_out, self.valid_y
            if len(pred) == 0:
                raise ValueError('No evaluation set has been provided')
        elif data == 'train':
            pred, y = self.train_out, self.train_y
            if len(pred) == 0:
                raise ValueError('No training set has been provided')
        elif data == 'test':
            pred, y = self.test_out, self.test_y
            if len(pred) == 0:
                raise ValueError('No test set has been provided')
            if y is None:
                raise ValueError('You must provide ground truth target values to compute the metrics')
        else:
            raise ValueError("data must be 'eval', 'train' or 'test', got {!r}".format(data))

        return Metrics(pred, y, self.target, threshold, binary)

    def compute_class_weights(self):
        """Compute inverse-frequency class weights from the training set.

        Counts how many training targets (batch.y as yielded by self.train_loader)
        equal each entry of self.classes, inverts the counts and normalises them to
        sum to 1. NOTE: a class absent from the training set gives NaN weights.

        Returns:
            torch.Tensor: float32 weights, one per class, in self.classes order.
        """
        targets_all = torch.cat([batch.y for batch in self.train_loader])
        # reshape(-1) instead of squeeze(): a single target would otherwise become a
        # 0-d tensor whose tolist() is a scalar rather than a list.
        targets_all = targets_all.reshape(-1).tolist()
        weights = torch.tensor([targets_all.count(cls) for cls in self.classes],
                               dtype=torch.float32)
        print('class occurences: {}'.format(weights))
        weights = 1.0 / weights
        weights = weights / weights.sum()
        print('class weights: {}'.format(weights))
        return weights

    @staticmethod
    def print_epoch_data(stage, epoch, loss, acc, time):
        """Print the loss, accuracy and timing of one epoch.

        Args:
            stage (str): 'train' or 'valid'
            epoch (int): epoch number
            loss (float): loss during that epoch
            acc (float or None): accuracy
            time (float): timing of the epoch in seconds
        """
        acc_str = 'None' if acc is None else '%1.4e' % acc
        print('Epoch [%04d] : %s loss %e | accuracy %s | time %1.2e sec.' % (
            epoch, stage, loss, acc_str, time))

    @staticmethod
    def update_name(hdf5, outdir):
        """Return a path in `outdir` for `hdf5` that does not overwrite an existing file.

        If `outdir/hdf5` exists, '<stem>_001.hdf5', '<stem>_002.hdf5', ... are tried in
        turn, where <stem> is `hdf5` without its final extension.

        Args:
            hdf5 (str): hdf5 file name
            outdir (str): output directory

        Returns:
            str: updated hdf5 path
        """
        fname = os.path.join(outdir, hdf5)
        # splitext keeps dots in directory parts and multi-dot names intact.
        stem = os.path.splitext(hdf5)[0]
        count = 0
        while os.path.exists(fname):
            count += 1
            fname = os.path.join(outdir, '{}_{:03d}.hdf5'.format(stem, count))
        return fname

    @staticmethod
    def _plot_curves(train_values, valid_values, title, ylabel, filename):
        """Plot per-epoch training/validation curves and save the figure to `filename`."""
        import matplotlib.pyplot as plt

        # Curves with a single point are not drawn. The x-axis follows each history's
        # own length, which stays valid when train() is called several times.
        if len(valid_values) > 1:
            plt.plot(range(1, len(valid_values) + 1), valid_values, c='red', label='valid')
        if len(train_values) > 1:
            plt.plot(range(1, len(train_values) + 1), train_values, c='blue', label='train')
        plt.title(title)
        plt.xlabel("Number of epoch")
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(filename)
        plt.close()

    def plot_loss(self, name=''):
        """Plot the loss of the model as a function of the epoch.

        Args:
            name (str, optional): suffix of the output file 'loss_epoch<name>.png'.
                Defaults to ''.
        """
        self._plot_curves(self.train_loss, self.valid_loss,
                          "Loss/ epoch", "Total loss", 'loss_epoch{}.png'.format(name))

    def plot_acc(self, name=''):
        """Plot the accuracy of the model as a function of the epoch.

        Args:
            name (str, optional): suffix of the output file 'acc_epoch<name>.png'.
                Defaults to ''.
        """
        self._plot_curves(self.train_acc, self.valid_acc,
                          "Accuracy/ epoch", "Accuracy", 'acc_epoch{}.png'.format(name))

    def plot_hit_rate(self, data='eval', threshold=4, mode='percentage', name=''):
        """Plot the hit rate as a function of the models' rank.

        Failures (e.g. a task without hit rate) are reported and do not raise.

        Args:
            data (str, optional): which stage to consider train/eval/test. Defaults to 'eval'.
            threshold (int, optional): defines the value to split into a hit (1) or a
                non-hit (0). Defaults to 4.
            mode (str, optional): displays the hit rate as a number of hits ('count') or
                as a percentage ('percentage'). Defaults to 'percentage'.
            name (str, optional): suffix of the output file 'hitrate<name>.png'.
        """
        import matplotlib.pyplot as plt

        try:
            hitrate = self.get_metrics(data, threshold).hitrate()
            X = range(1, len(hitrate) + 1)
            if mode == 'percentage':
                # Not in-place: works for integer hit counts too.
                hitrate = hitrate / hitrate.sum()
            plt.plot(X, hitrate, c='blue', label=data)
            plt.title("Hit rate")
            plt.xlabel("Number of models")
            plt.ylabel("Hit Rate")
            plt.legend()
            plt.savefig('hitrate{}.png'.format(name))
        except Exception:
            print('No hit rate plot could be generated for you {} task'.format(self.task))
        finally:
            # Do not leak a half-drawn figure into the next plot.
            plt.close()

    def plot_scatter(self):
        """Scatter plot of predictions vs. targets (training in blue, validation in red)."""
        import matplotlib.pyplot as plt

        self.model.eval()
        loaders = {'train': self.train_loader,
                   'valid': getattr(self, 'valid_loader', None)}
        pred, truth = {'train': [], 'valid': []}, {'train': [], 'valid': []}

        with torch.no_grad():
            for stage, loader in loaders.items():
                if loader is None:
                    continue
                for data in loader:
                    data = data.to(self.device)
                    truth[stage] += data.y.tolist()
                    pred[stage] += self.model(data).reshape(-1).tolist()

        plt.scatter(truth['train'], pred['train'], c='blue')
        plt.scatter(truth['valid'], pred['valid'], c='red')
        plt.show()

    def save_model(self, filename='model.pth.tar'):
        """Save the model, optimizer state and hyper-parameters to a file.

        Args:
            filename (str, optional): name of the file. Defaults to 'model.pth.tar'.
        """
        state = {'model': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict(),
                 'node': self.node_feature,
                 'edge': self.edge_feature,
                 'target': self.target,
                 'task': self.task,
                 'classes': self.classes,
                 'class_weight': self.class_weights,
                 'batch_size': self.batch_size,
                 'percent': self.percent,
                 'lr': self.lr,
                 'index': self.index,
                 'shuffle': self.shuffle,
                 'threshold': self.threshold,
                 'cluster_nodes': self.cluster_nodes,
                 'transform_sigmoid': self.transform_sigmoid}
        torch.save(state, filename)

    def load_params(self, filename):
        """Load the hyper-parameters and saved states of a pretrained model.

        The model/optimizer state dictionaries are kept on the instance and applied by
        load_pretrained_model().

        Args:
            filename (str): path to a checkpoint written by save_model().
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(filename, map_location=self.device)

        self.node_feature = state['node']
        self.edge_feature = state['edge']
        self.target = state['target']
        self.batch_size = state['batch_size']
        self.percent = state['percent']
        self.lr = state['lr']
        self.index = state['index']
        self.class_weights = state['class_weight']
        self.task = state['task']
        self.classes = state['classes']
        self.threshold = state['threshold']
        self.shuffle = state['shuffle']
        self.cluster_nodes = state['cluster_nodes']
        self.transform_sigmoid = state['transform_sigmoid']

        self.opt_loaded_state_dict = state['optimizer']
        self.model_load_state_dict = state['model']

    def _export_epoch_hdf5(self, epoch, data):
        """Export the data of one epoch to the open hdf5 file (self.f5).

        Creates group 'epoch_<epoch>' with one subgroup per pass (train/eval/test),
        each holding the predicted values (outputs), raw outputs, ground truth
        (targets) and molecule names (mol).

        Args:
            epoch (int): index of the epoch
            data (dict): data of the epoch, keyed by pass name

        Raises:
            ValueError: a value could not be written (chained from h5py's TypeError).
        """
        grp = self.f5.create_group('epoch_%04d' % epoch)
        grp.attrs['task'] = self.task
        grp.attrs['target'] = self.target
        grp.attrs['batch_size'] = self.batch_size

        for pass_type, pass_data in data.items():
            try:
                sg = grp.create_group(pass_type)
                for data_name, data_value in pass_data.items():
                    if data_name == 'mol':
                        # Molecule names are strings and need a string dtype.
                        # np.string_ was removed in NumPy 2.0; np.bytes_ is the same type.
                        string_dt = h5py.special_dtype(vlen=str)
                        sg.create_dataset(data_name, data=np.bytes_(data_value), dtype=string_dt)
                    else:
                        sg.create_dataset(data_name, data=data_value)
            except TypeError as err:
                raise ValueError("Error in export epoch to hdf5") from err