Source code for deeprank_gnn.ginet

import torch
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn as nn

from torch_scatter import scatter_mean
from torch_scatter import scatter_sum

from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

# torch_geometric import
from torch_geometric.nn.inits import uniform
from torch_geometric.nn import max_pool_x
from torch_geometric.data import DataLoader

# deeprank_gnn import
from deeprank_gnn.community_pooling import get_preloaded_cluster, community_pooling
from deeprank_gnn.NeuralNet import NeuralNet
from deeprank_gnn.DataSet import HDF5DataSet, PreCluster


[docs]class GINetConvLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False): super(GINetConvLayer, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.fc = nn.Linear( self.in_channels, self.out_channels, bias=bias) self.fc_edge_attr = nn.Linear( number_edge_features, number_edge_features, bias=bias) self.fc_attention = nn.Linear( 2 * self.out_channels + number_edge_features, 1, bias=bias) self.reset_parameters()
[docs] def reset_parameters(self): size = self.in_channels uniform(size, self.fc.weight) uniform(size, self.fc_attention.weight) uniform(size, self.fc_edge_attr.weight)
[docs] def forward(self, x, edge_index, edge_attr): row, col = edge_index num_node = len(x) edge_attr = edge_attr.unsqueeze( -1) if edge_attr.dim() == 1 else edge_attr xcol = self.fc(x[col]) xrow = self.fc(x[row]) ed = self.fc_edge_attr(edge_attr) # create edge feature by concatenating node feature alpha = torch.cat([xrow, xcol, ed], dim=1) alpha = self.fc_attention(alpha) alpha = F.leaky_relu(alpha) alpha = F.softmax(alpha, dim=1) h = alpha * xcol out = torch.zeros( num_node, self.out_channels).to(alpha.device) z = scatter_sum(h, row, dim=0, out=out) return z
def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels)
[docs]class GINet(torch.nn.Module): # input_shape -> number of node input features # output_shape -> number of output value per graph # input_shape_edge -> number of edge input features def __init__(self, input_shape, output_shape=1, input_shape_edge=1): super(GINet, self).__init__() self.conv1 = GINetConvLayer(input_shape, 16, input_shape_edge) self.conv2 = GINetConvLayer(16, 32, input_shape_edge) self.conv1_ext = GINetConvLayer( input_shape, 16, input_shape_edge) self.conv2_ext = GINetConvLayer(16, 32, input_shape_edge) self.fc1 = nn.Linear(2*32, 128) self.fc2 = nn.Linear(128, output_shape) self.clustering = 'mcl' self.dropout = 0.4
[docs] def forward(self, data): act = F.relu data_ext = data.clone() # EXTERNAL INTERACTION GRAPH # first conv block data.x = act(self.conv1( data.x, data.edge_index, data.edge_attr)) cluster = get_preloaded_cluster(data.cluster0, data.batch) data = community_pooling(cluster, data) # second conv block data.x = act(self.conv2( data.x, data.edge_index, data.edge_attr)) cluster = get_preloaded_cluster(data.cluster1, data.batch) x, batch = max_pool_x(cluster, data.x, data.batch) # INTERNAL INTERACTION GRAPH # first conv block data_ext.x = act(self.conv1_ext( data_ext.x, data_ext.edge_index, data_ext.edge_attr)) cluster = get_preloaded_cluster( data_ext.cluster0, data_ext.batch) data_ext = community_pooling(cluster, data_ext) # second conv block data_ext.x = act(self.conv2_ext( data_ext.x, data_ext.edge_index, data_ext.edge_attr)) cluster = get_preloaded_cluster( data_ext.cluster1, data_ext.batch) x_ext, batch_ext = max_pool_x( cluster, data_ext.x, data_ext.batch) # FC x = scatter_mean(x, batch, dim=0) x_ext = scatter_mean(x_ext, batch_ext, dim=0) x = torch.cat([x, x_ext], dim=1) x = act(self.fc1(x)) x = F.dropout(x, self.dropout, training=self.training) x = self.fc2(x) return x