Neural Network Training, Evaluation and Test

Neural Network

class deeprank_gnn.NeuralNet.NeuralNet(database, Net, node_feature=['type', 'polarity', 'bsa'], edge_feature=['dist'], target='irmsd', lr=0.01, batch_size=32, percent=[1.0, 0.0], database_eval=None, index=None, class_weights=None, task=None, classes=[0, 1], threshold=None, pretrained_model=None, shuffle=True, outdir='./', cluster_nodes='mcl', transform_sigmoid=False)[source]

Class from which the network is trained, evaluated and tested

Parameters
  • database (str, required) – path(s) to hdf5 dataset(s). Unique hdf5 file or list of hdf5 files.

  • Net (function, required) – neural network function (e.g. GINet, FoutNet etc.)

  • node_feature (list, optional) – type, charge, polarity, bsa (buried surface area), pssm, cons (pssm conservation information), ic (pssm information content), depth, hse (half sphere exposure). Defaults to [‘type’, ‘polarity’, ‘bsa’].

  • edge_feature (list, optional) – dist (distance). Defaults to [‘dist’].

  • target (str, optional) – irmsd, lrmsd, fnat, capri_class, bin_class, dockQ. Defaults to ‘irmsd’.

  • lr (float, optional) – learning rate. Defaults to 0.01.

  • batch_size (int, optional) – defaults to 32.

  • percent (list, optional) – divides the input dataset into a training and an evaluation set. Defaults to [1.0, 0.0].

  • database_eval ([type], optional) – independent evaluation set. Defaults to None.

  • index ([type], optional) – index of the molecules to consider. Defaults to None.

  • class_weights ([list or bool], optional) – weights provided to the cross entropy loss function. The user can either input a list of weights or let DeepRank-GNN define the weights based on the dataset content (True). Defaults to None.

  • task (str, optional) – ‘reg’ for regression or ‘class’ for classification. Defaults to None.

  • classes (list, optional) – define the dataset target classes in classification mode. Defaults to [0, 1].

  • threshold (float, optional) – threshold used to compute binary classification metrics. The documented default is 4.0, although the signature above lists None.

  • pretrained_model (str, optional) – path to pre-trained model. Defaults to None.

  • shuffle (bool, optional) – shuffle the training set. Defaults to True.

  • outdir (str, optional) – output directory. Defaults to ‘./’.

  • cluster_nodes (str, optional) – node clustering algorithm to use (‘mcl’ or ‘louvain’). Defaults to ‘mcl’.
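As a rough illustration of how these parameters fit together, the sketch below collects a classification setup in a dictionary and passes it to the constructor. The file names `train.hdf5` and `valid.hdf5`, the output directory `./results` and the helper `build_model` are hypothetical, and the import path used for `GINet` is an assumption that may differ between versions of the package.

```python
# Keyword arguments taken from the parameter list above.
# All paths are hypothetical placeholders.
settings = {
    "database": "train.hdf5",
    "database_eval": "valid.hdf5",
    "node_feature": ["type", "polarity", "bsa"],
    "edge_feature": ["dist"],
    "target": "bin_class",
    "task": "class",
    "classes": [0, 1],
    "lr": 0.01,
    "batch_size": 32,
    "shuffle": True,
    "outdir": "./results",
    "cluster_nodes": "mcl",
}


def build_model(kwargs):
    """Instantiate NeuralNet with the given keyword arguments (hypothetical helper)."""
    # Imported lazily so the settings can be inspected without the package installed.
    from deeprank_gnn.NeuralNet import NeuralNet
    from deeprank_gnn.ginet import GINet  # assumed location of GINet

    return NeuralNet(Net=GINet, **kwargs)
```

Calling `build_model(settings)` requires the package and the hdf5 files to be available; the dictionary alone can be checked or logged beforehand.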

load_pretrained_model(database, Net)[source]

Loads pretrained model

Parameters
  • database (str) – path to hdf5 file(s)

  • Net (function) – neural network

load_model(database, Net, database_eval)[source]

Loads model

Parameters
  • database (str) – path to the hdf5 file(s) of the training set

  • Net (function) – neural network

  • database_eval (str) – path to the hdf5 file(s) of the evaluation set

Raises

ValueError – Invalid node clustering method.

put_model_to_device(dataset, Net)[source]

Puts the model on the available device

Parameters
  • dataset (str) – path to the hdf5 file(s)

  • Net (function) – neural network

Raises

ValueError – Incorrect output shape

set_loss()[source]

Sets the loss function (MSE loss for regression/ CrossEntropy loss for classification).
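For readers unfamiliar with the two losses named here, the following is a minimal pure-Python sketch of what each one computes. It is not the library's code: the function names are hypothetical, the cross entropy is shown for a single sample, and the optional weight simply scales that sample's loss without any batch normalisation.

```python
import math


def mse_loss(pred, target):
    """Mean squared error between two equal-length sequences of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cross_entropy_loss(logits, label, weights=None):
    """Cross entropy of one sample given raw class scores and the true class index."""
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    loss = log_norm - logits[label]
    # Optional per-class weight (illustrative; batch reduction is not shown).
    if weights is not None:
        loss *= weights[label]
    return loss
```

A regression target such as irmsd would pair with the first function, and a class target such as bin_class with the second.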

train(nepoch=1, validate=False, save_model='last', hdf5='train_data.hdf5', save_epoch='intermediate', save_every=5)[source]

Trains the model

Parameters
  • nepoch (int, optional) – number of epochs. Defaults to 1.

  • validate (bool, optional) – perform validation. Defaults to False.

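  • save_model (last, best, optional) – which model to save: ‘last’ or ‘best’. Defaults to ‘last’.

  • hdf5 (str, optional) – hdf5 output file. Defaults to ‘train_data.hdf5’.

  • save_epoch (all, intermediate, optional) – which epochs’ data to save: every epoch (‘all’) or every save_every epochs (‘intermediate’). Defaults to ‘intermediate’.

  • save_every (int, optional) – save data every n epochs if save_epoch == ‘intermediate’. Defaults to 5.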
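The interaction between save_epoch and save_every can be sketched as below. This is only one plausible reading: the helper `epochs_to_save` is hypothetical, and it assumes epochs are counted from 1 and that ‘intermediate’ keeps exactly the multiples of save_every. Whether the library counts from 0 or additionally saves the first or last epoch is not stated here.

```python
def epochs_to_save(nepoch, save_epoch="intermediate", save_every=5):
    """Return the 1-based epoch numbers whose data would be written."""
    if save_epoch == "all":
        return list(range(1, nepoch + 1))
    if save_epoch == "intermediate":
        # Assumption: every save_every-th epoch, counted from 1.
        return [e for e in range(1, nepoch + 1) if e % save_every == 0]
    raise ValueError(f"unknown save_epoch value: {save_epoch!r}")
```

Under these assumptions, a 12-epoch run with the defaults would write data for epochs 5 and 10 only.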

test(database_test=None, threshold=4, hdf5='test_data.hdf5')[source]

Tests the model

Parameters
  • database_test ([type], optional) – test database. Defaults to None.

  • threshold (int, optional) – threshold used to transform data into binary values. Defaults to 4.

  • hdf5 (str, optional) – output hdf5 file. Defaults to ‘test_data.hdf5’.

eval(loader)[source]

Evaluates the model

Parameters

loader (DataLoader) – data loader providing the batches to evaluate

Return type

(tuple)

get_metrics(data='eval', threshold=4.0, binary=True)[source]

Computes the metrics needed

Parameters
  • data (str, optional) – ‘eval’, ‘train’ or ‘test’. Defaults to ‘eval’.

  • threshold (float, optional) – threshold used to transform data into binary values. Defaults to 4.0.

  • binary (bool, optional) – Transform data into binary data. Defaults to True.
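The thresholding described above can be illustrated with a small pure-Python sketch. The helper names and the `lower_is_better` flag are hypothetical, and the direction of the comparison is an assumption: for a distance-like target such as irmsd a value below the threshold is treated as a hit, while for a score-like target the comparison would be reversed.

```python
def binarize(values, threshold=4.0, lower_is_better=True):
    """Map continuous values to 1 (hit) or 0 (non-hit)."""
    if lower_is_better:
        return [1 if v < threshold else 0 for v in values]
    return [1 if v > threshold else 0 for v in values]


def binary_metrics(pred, target):
    """Accuracy, sensitivity and specificity from two lists of 0/1 labels."""
    pairs = list(zip(pred, target))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    return {
        "accuracy": (tp + tn) / len(pairs),
        "sensitivity": tp / (tp + fn) if tp + fn else float("nan"),
        "specificity": tn / (tn + fp) if tn + fp else float("nan"),
    }


# Made-up predicted and observed values, for illustration only.
predicted = binarize([1.2, 5.0, 3.9, 8.1])
observed = binarize([2.0, 6.0, 4.5, 3.0])
metrics = binary_metrics(predicted, observed)
```

The set of metrics actually reported by get_metrics may be larger; only three common ones are shown.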

compute_class_weights()[source]

static print_epoch_data(stage, epoch, loss, acc, time)[source]

Prints the data of each epoch

Parameters
  • stage (str) – train or valid

  • epoch (int) – epoch number

  • loss (float) – loss during that epoch

  • acc (float or None) – accuracy

  • time (float) – timing of the epoch

format_output(pred, target=None)[source]

Format the network output depending on the task (classification/regression).

static update_name(hdf5, outdir)[source]

Checks whether the file already exists and, if so, updates the name

Parameters
  • hdf5 (str) – hdf5 file

  • outdir (str) – output directory

Returns

updated hdf5 name

Return type

str
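One way to implement this kind of collision check is sketched below. The function `unique_name` is hypothetical and the naming scheme, a zero-padded numeric suffix before the extension, is an assumption; the scheme update_name actually uses is not described here.

```python
import os


def unique_name(hdf5, outdir):
    """Return a path in outdir that does not collide with an existing file."""
    candidate = os.path.join(outdir, hdf5)
    stem, ext = os.path.splitext(hdf5)
    count = 0
    # Assumption: collisions are resolved with a zero-padded numeric suffix.
    while os.path.exists(candidate):
        count += 1
        candidate = os.path.join(outdir, f"{stem}_{count:03d}{ext}")
    return candidate
```

With this sketch, a second run writing `train_data.hdf5` into the same directory would get `train_data_001.hdf5` instead of overwriting the first file.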

plot_loss(name='')[source]

Plots the loss of the model as a function of the epoch

Parameters

name (str, optional) – name of the output file. Defaults to ‘’.

plot_acc(name='')[source]

Plots the accuracy of the model as a function of the epoch

Parameters

name (str, optional) – name of the output file. Defaults to ‘’.

plot_hit_rate(data='eval', threshold=4, mode='percentage', name='')[source]

Plots the hitrate as a function of the models’ rank

Parameters
  • data (str, optional) – which stage to consider train/eval/test. Defaults to ‘eval’.

  • threshold (int, optional) – defines the value to split into a hit (1) or a non-hit (0). Defaults to 4.

  • mode (str, optional) – displays the hitrate as a number of hits (‘count’) or as a percentage (‘percentage’). Defaults to ‘percentage’.
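The quantity being plotted can be sketched in pure Python as a cumulative count of hits along the ranking. The function `hit_rate` and its `lower_is_better` flag are hypothetical, and two points are assumptions: models are ranked by predicted score with lower values first, and the percentage is taken relative to the total number of hits in the set.

```python
from itertools import accumulate


def hit_rate(scores, truth, threshold=4, mode="percentage", lower_is_better=True):
    """Cumulative hits along a ranking of models ordered by predicted score."""

    def is_hit(value):
        # A model is a hit when its true value passes the threshold.
        return value < threshold if lower_is_better else value > threshold

    # Rank model indices from best to worst predicted score.
    order = sorted(range(len(scores)), key=lambda i: scores[i],
                   reverse=not lower_is_better)
    hits = [1 if is_hit(truth[i]) else 0 for i in order]
    cumulative = list(accumulate(hits))

    if mode == "count":
        return cumulative
    if mode == "percentage":
        total = cumulative[-1] if cumulative else 0
        # Assumption: percentages are relative to the total number of hits.
        return [100.0 * c / total if total else 0.0 for c in cumulative]
    raise ValueError(f"unknown mode: {mode!r}")
```

Plotting the returned list against rank 1, 2, 3, ... would give a curve of the kind this method describes.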

plot_scatter()[source]

Scatter plot of the results.

save_model(filename='model.pth.tar')[source]

Saves the model to a file

Parameters

filename (str, optional) – name of the file. Defaults to ‘model.pth.tar’.

load_params(filename)[source]

Loads the parameters of a pretrained model

Parameters

filename (str) – name of the file containing the saved model

Returns

[description]

Return type

[type]