Neural Network Training, Evaluation and Test
Neural Network
- class deeprank_gnn.NeuralNet.NeuralNet(database, Net, node_feature=['type', 'polarity', 'bsa'], edge_feature=['dist'], target='irmsd', lr=0.01, batch_size=32, percent=[1.0, 0.0], database_eval=None, index=None, class_weights=None, task=None, classes=[0, 1], threshold=None, pretrained_model=None, shuffle=True, outdir='./', cluster_nodes='mcl', transform_sigmoid=False)
Class from which the network is trained, evaluated and tested
- Parameters
database (str, required) – path(s) to hdf5 dataset(s). Unique hdf5 file or list of hdf5 files.
Net (function, required) – neural network function (e.g. GINet, FoutNet, etc.)
node_feature (list, optional) – node features to use, among: type, charge, polarity, bsa (buried surface area), pssm, cons (pssm conservation information), ic (pssm information content), depth, hse (half sphere exposure). Defaults to [‘type’, ‘polarity’, ‘bsa’].
edge_feature (list, optional) – edge features to use: dist (distance). Defaults to [‘dist’].
target (str, optional) – target value to predict, among: irmsd, lrmsd, fnat, capri_class, bin_class, dockQ. Defaults to ‘irmsd’.
lr (float, optional) – learning rate. Defaults to 0.01.
batch_size (int, optional) – defaults to 32.
percent (list, optional) – fractions of the input dataset assigned to the training set and to the evaluation set. Defaults to [1.0, 0.0].
database_eval ([type], optional) – independent evaluation set. Defaults to None.
index ([type], optional) – index of the molecules to consider. Defaults to None.
class_weights ([list or bool], optional) – weights provided to the cross entropy loss function. The user can either input a list of weights or let DeepRank-GNN (True) define weights based on the dataset content. Defaults to None.
task (str, optional) – ‘reg’ for regression or ‘class’ for classification. Defaults to None.
classes (list, optional) – define the dataset target classes in classification mode. Defaults to [0, 1].
threshold (int, optional) – threshold to compute binary classification metrics. Defaults to 4.0.
pretrained_model (str, optional) – path to pre-trained model. Defaults to None.
shuffle (bool, optional) – shuffle the training set. Defaults to True.
outdir (str, optional) – output directory. Defaults to ./
cluster_nodes (bool, optional) – perform node clustering (‘mcl’ or ‘louvain’ algorithm). Defaults to ‘mcl’.
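A minimal usage sketch of the constructor, assuming DeepRank-GNN is installed and that the GINet architecture can be imported from `deeprank_gnn.ginet` (an assumption about the module path). The HDF5 path, the 80/20 split and the helper name `build_model` are hypothetical values chosen for illustration. The imports sit inside the function so the snippet can be loaded without the library.

```python
# Keyword arguments taken from the parameter list above; the values are
# illustrative choices, not recommended settings.
NEURALNET_KWARGS = {
    "node_feature": ["type", "polarity", "bsa"],
    "edge_feature": ["dist"],
    "target": "irmsd",
    "task": "reg",            # regression on the chosen target
    "lr": 0.01,
    "batch_size": 32,
    "percent": [0.8, 0.2],    # 80% training, 20% evaluation
    "shuffle": True,
    "outdir": "./",
}


def build_model(database="./hdf5/graphs.hdf5", **overrides):
    """Instantiate NeuralNet on a (hypothetical) HDF5 graph file."""
    # Imported lazily: both imports require DeepRank-GNN to be installed.
    from deeprank_gnn.NeuralNet import NeuralNet
    from deeprank_gnn.ginet import GINet  # assumed module path for GINet

    kwargs = {**NEURALNET_KWARGS, **overrides}
    return NeuralNet(database, GINet, **kwargs)
```

Any keyword can be overridden at call time, for example `build_model("my_graphs.hdf5", batch_size=64)`.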
- load_pretrained_model(database, Net)
Loads pretrained model
- Parameters
database (str) – path to hdf5 file(s)
Net (function) – neural network
- load_model(database, Net, database_eval)
Loads model
- Parameters
- Raises
ValueError – Invalid node clustering method.
- put_model_to_device(dataset, Net)
Puts the model on the available device
- Parameters
dataset (str) – path to the hdf5 file(s)
Net (function) – neural network
- Raises
ValueError – Incorrect output shape
- set_loss()
Sets the loss function (MSE loss for regression, CrossEntropy loss for classification).
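To illustrate what these two losses compute, here is a plain-Python sketch. It is not DeepRank-GNN's implementation, and the function names are hypothetical. The weighted average shows one common way per-class weights such as `class_weights` enter the cross-entropy.

```python
import math


def mse_loss(predictions, targets):
    """Mean squared error, the usual regression loss."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def cross_entropy_loss(logits, targets, class_weights=None):
    """Weighted cross-entropy over raw class scores (logits).

    `targets` holds class indices; `class_weights` holds one weight per class.
    """
    n_classes = len(logits[0])
    weights = class_weights or [1.0] * n_classes
    total, weight_sum = 0.0, 0.0
    for scores, label in zip(logits, targets):
        # Numerically stable log-sum-exp: subtract the max before exponentiating.
        top = max(scores)
        log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
        # Negative log-probability of the true class, scaled by its weight.
        total += weights[label] * (log_norm - scores[label])
        weight_sum += weights[label]
    return total / weight_sum
```

With equal scores for two classes, the cross-entropy equals log 2, the loss of an uninformative binary classifier.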
- train(nepoch=1, validate=False, save_model='last', hdf5='train_data.hdf5', save_epoch='intermediate', save_every=5)
Trains the model
- Parameters
nepoch (int, optional) – number of epochs. Defaults to 1.
validate (bool, optional) – perform validation. Defaults to False.
save_model (last, best, optional) – save the model (‘last’ or ‘best’). Defaults to ‘last’.
hdf5 (str, optional) – hdf5 output file. Defaults to ‘train_data.hdf5’.
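save_epoch (all, intermediate, optional) – epochs for which data are saved: every epoch (‘all’) or every save_every epochs (‘intermediate’). Defaults to ‘intermediate’.
save_every (int, optional) – save data every n epochs if save_epoch == ‘intermediate’. Defaults to 5.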
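A short sketch of a training call using the parameters above. The wrapper `run_training` is a hypothetical helper, `model` is assumed to be an already constructed NeuralNet instance, and `validate=True` assumes an evaluation set was configured (through `percent` or `database_eval`).

```python
def run_training(model, nepoch=50):
    """Train `model` with the keyword arguments documented for train()."""
    model.train(
        nepoch=nepoch,
        validate=True,              # assumes an evaluation set is available
        save_model="best",
        hdf5="train_data.hdf5",
        save_epoch="intermediate",
        save_every=5,               # only used when save_epoch == "intermediate"
    )
    return model
```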
- eval(loader)
Evaluates the model
- Parameters
loader (DataLoader) – [description]
- Return type
(tuple)
- format_output(pred, target=None)
Format the network output depending on the task (classification/regression).
- plot_loss(name='')
Plots the loss of the model as a function of the epoch
- Parameters
name (str, optional) – name of the output file. Defaults to ‘’.
- plot_acc(name='')
Plots the accuracy of the model as a function of the epoch
- Parameters
name (str, optional) – name of the output file. Defaults to ‘’.
- plot_hit_rate(data='eval', threshold=4, mode='percentage', name='')
Plots the hit rate as a function of the models’ rank
- Parameters
data (str, optional) – which stage to consider train/eval/test. Defaults to ‘eval’.
threshold (int, optional) – defines the value to split into a hit (1) or a non-hit (0). Defaults to 4.
mode (str, optional) – displays the hit rate as a number of hits (‘count’) or as a percentage (‘percentage’). Defaults to ‘percentage’.
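The quantity being plotted can be sketched as follows, in plain Python. This is an illustration under stated assumptions, not the library's code: models are ranked by predicted score, a model counts as a hit when its true target lies below the threshold (a lower-is-better target such as irmsd), and the percentage is taken relative to the total number of hits. The function name and the `lower_is_better` flag are hypothetical.

```python
from itertools import accumulate


def hit_rate(predictions, targets, threshold=4.0, mode="percentage",
             lower_is_better=True):
    """Cumulative hits along the ranking induced by the predictions."""
    if mode not in ("count", "percentage"):
        raise ValueError("mode must be 'count' or 'percentage'")
    # Rank the models from best to worst predicted score.
    order = sorted(range(len(predictions)), key=lambda i: predictions[i],
                   reverse=not lower_is_better)
    # A model is a hit (1) when its true target is on the good side of the threshold.
    if lower_is_better:
        hits = [1 if targets[i] < threshold else 0 for i in order]
    else:
        hits = [1 if targets[i] > threshold else 0 for i in order]
    cumulative = list(accumulate(hits))
    if mode == "count":
        return cumulative
    total = sum(hits)
    return [c / total if total else 0.0 for c in cumulative]
```

For example, `hit_rate([2.0, 8.0, 3.0, 10.0], [1.5, 9.0, 6.0, 3.0], mode="count")` walks the ranking and counts how many hits have been recovered at each rank.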