import torch
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn as nn
from torch_scatter import scatter_add, scatter_mean
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import uniform
from torch_geometric.nn import max_pool_x
from .community_pooling import get_preloaded_cluster, community_pooling
[docs]class FoutLayer(torch.nn.Module):
"""
This layer is described by eq. (1) of
Protein Interface Predition using Graph Convolutional Network
by Alex Fout et al. NIPS 2018
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
"""
def __init__(self,
in_channels,
out_channels,
bias=True):
super(FoutLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# Wc and Wn are the center and neighbor weight matrix
self.Wc = Parameter(torch.Tensor(in_channels, out_channels))
self.Wn = Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
[docs] def reset_parameters(self):
size = self.in_channels
uniform(size, self.Wc)
uniform(size, self.Wn)
uniform(size, self.bias)
[docs] def forward(self, x, edge_index):
row, col = edge_index
num_node = len(x)
# alpha = x * Wc
alpha = torch.mm(x, self.Wc)
# beta = x * Wn
beta = torch.mm(x, self.Wn)
# gamma_i = 1/Ni Sum_j x_j * Wn
# there might be a better way than looping over the nodes
gamma = torch.zeros(
num_node, self.out_channels).to(alpha.device)
for n in range(num_node):
index = edge_index[:, edge_index[0, :] == n][1, :]
gamma[n, :] = torch.mean(beta[index, :], dim=0)
# alpha = alpha + gamma
alpha = alpha + gamma
# add the bias
if self.bias is not None:
alpha = alpha + self.bias
return alpha
def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__,
self.in_channels,
self.out_channels)
[docs]class FoutNet(torch.nn.Module):
def __init__(self, input_shape, output_shape=1, input_shape_edge=None):
super(FoutNet, self).__init__()
self.conv1 = FoutLayer(input_shape, 16)
self.conv2 = FoutLayer(16, 32)
self.fc1 = torch.nn.Linear(32, 64)
self.fc2 = torch.nn.Linear(64, output_shape)
self.clustering = 'mcl'
[docs] def forward(self, data):
act = nn.Tanhshrink()
act = F.relu
#act = nn.LeakyReLU(0.25)
# first conv block
data.x = act(self.conv1(data.x, data.edge_index))
cluster = get_preloaded_cluster(data.cluster0, data.batch)
data = community_pooling(cluster, data)
# second conv block
data.x = act(self.conv2(data.x, data.edge_index))
cluster = get_preloaded_cluster(data.cluster1, data.batch)
x, batch = max_pool_x(cluster, data.x, data.batch)
# FC
x = scatter_mean(x, batch, dim=0)
x = act(self.fc1(x))
x = self.fc2(x)
#x = F.dropout(x, training=self.training)
return x
# return F.relu(x)