Source code for deeprank_gnn.Graph

import os
from pdb2sql import StructureSimilarity
import numpy as np
import networkx as nx
from collections import OrderedDict
import time
from deeprank_gnn.tools.embedding import manifold_embedding
import community
import markov_clustering as mc
import h5py


class Graph(object):
    """Graph-level operations on a protein interface graph.

    The class wraps a networkx graph and provides:

    - scores for the graph (lrmsd, irmsd, fnat, capri_class, bin_class
      and dockQ)
    - conversion of the networkx object (graph) to hdf5 format
    - conversion of the networkx object (graph) from hdf5 format
    - 2D and 3D plotly views of the graph

    Attributes:
        name: name of the graph, used as hdf5 group name by ``nx2h5``.
        nx: the networkx graph.
        pdb: structure passed to ``StructureSimilarity`` by ``get_score``
            and shown in the plot titles.
        score (dict): the scores, ``None`` until computed or loaded.
    """

    # marker colors of the two node traces, indexed by the 'chain' attribute
    _CHAIN_COLORS = ('rgb(227,28,28)', 'rgb(0,102,255)')

    def __init__(self):
        self.name = None
        self.nx = None
        # defined here so that methods reading it never hit AttributeError
        self.pdb = None
        self.score = {'irmsd': None, 'lrmsd': None, 'capri_class': None,
                      'fnat': None, 'dockQ': None, 'bin_class': None}

    @staticmethod
    def _as_text(value):
        """Return ``value`` decoded from utf-8 when it is bytes, else as is."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    @staticmethod
    def _marker_size(n_edges):
        """Return the marker size of a node touched by ``n_edges`` edges."""
        return 5 + 15 * np.tanh(n_edges / 5)

    def get_score(self, ref):
        """Assigns scores (lrmsd, irmsd, fnat, dockQ, bin_class, capri_class)
        to a protein graph

        Args:
            ref (path): path to the reference structure required to
                compute the different score

        Raises:
            ValueError: if ``self.pdb`` has not been set.
        """
        if self.pdb is None:
            raise ValueError(
                'Graph.pdb must be set before calling get_score')

        ref_name = os.path.splitext(os.path.basename(ref))[0]
        sim = StructureSimilarity(self.pdb, ref)

        # Input pre-computed zone files
        if os.path.exists(ref_name + '.lzone'):
            self.score['lrmsd'] = sim.compute_lrmsd_fast(
                method='svd', lzone=ref_name + '.lzone')
            self.score['irmsd'] = sim.compute_irmsd_fast(
                method='svd', izone=ref_name + '.izone')

        # Compute zone files
        else:
            self.score['lrmsd'] = sim.compute_lrmsd_fast(method='svd')
            self.score['irmsd'] = sim.compute_irmsd_fast(method='svd')

        self.score['fnat'] = sim.compute_fnat_fast()

        self.score['dockQ'] = sim.compute_DockQScore(
            self.score['fnat'], self.score['lrmsd'], self.score['irmsd'])

        self.score['bin_class'] = self.score['irmsd'] < 4.0

        # thresholds are visited in decreasing order, so the class kept is
        # the one of the tightest threshold satisfied (5 when irmsd >= 6.0)
        self.score['capri_class'] = 5
        for thr, val in zip([6.0, 4.0, 2.0, 1.0], [4, 3, 2, 1]):
            if self.score['irmsd'] < thr:
                self.score['capri_class'] = val

    def nx2h5(self, f5):
        """Converts Networkx object to hdf5 format

        A group named ``self.name`` is created in ``f5``. It holds the
        nodes, the node features, the 'interface' and 'internal' edges
        with their node indices and features, and the scores that are
        not ``None``. Edges of any other type are not stored.

        Args:
            f5 ([type]): hdf5 file
        """
        # group for the graph
        grp = f5.create_group(self.name)

        # store the nodes
        node_key = list(self.nx.nodes)
        grp.create_dataset('nodes', data=np.array(node_key).astype('S'))

        # store node features; the feature names are read from the first node
        node_feat_grp = grp.create_group('node_data')
        node_attrs = list(self.nx.nodes.data())
        feature_names = list(node_attrs[0][1].keys()) if node_attrs else []
        for feat in feature_names:
            data = list(nx.get_node_attributes(self.nx, feat).values())
            node_feat_grp.create_dataset(feat, data=data)

        # the feature names are read from the first edge
        edge_attrs = list(self.nx.edges.data())
        edge_feat = list(edge_attrs[0][2].keys()) if edge_attrs else []

        # one lookup table instead of a linear list search per edge
        node_index = {node: ind for ind, node in enumerate(node_key)}

        buckets = {
            edge_type: {'edges': [], 'index': [],
                        'data': {feat: [] for feat in edge_feat}}
            for edge_type in ('interface', 'internal')}

        for e in self.nx.edges:
            attrs = self.nx.edges[e]
            bucket = buckets.get(self._as_text(attrs['type']))
            if bucket is None:
                continue
            bucket['edges'].append(e)
            bucket['index'].append([node_index[e[0]], node_index[e[1]]])
            for feat in edge_feat:
                bucket['data'][feat].append(attrs[feat])

        interface, internal = buckets['interface'], buckets['internal']

        # store the edges
        grp.create_dataset(
            'edges', data=np.array(interface['edges']).astype('S'))
        grp.create_dataset(
            'internal_edges', data=np.array(internal['edges']).astype('S'))

        # store the edge index
        grp.create_dataset('edge_index', data=interface['index'])
        grp.create_dataset('internal_edge_index', data=internal['index'])

        # store the edge attributes
        edge_feat_grp = grp.create_group('edge_data')
        internal_edge_feat_grp = grp.create_group('internal_edge_data')
        for feat in edge_feat:
            edge_feat_grp.create_dataset(feat, data=interface['data'][feat])
            internal_edge_feat_grp.create_dataset(
                feat, data=internal['data'][feat])

        # store the score
        score_grp = grp.create_group('score')
        for k, v in self.score.items():
            if v is not None:
                score_grp.create_dataset(k, data=v)

    def h52nx(self, f5name, mol, molgrp=None):
        """Converts Hdf5 file to Networkx object

        When ``molgrp`` is not given, the file ``f5name`` is opened
        read-only, the group ``mol`` is read and the file is closed again.

        Args:
            f5name (str): hdf5 file
            mol (str): molecule name
            molgrp ([type], optional): hdf5[molecule]. Defaults to None.
        """
        if molgrp is not None:
            self.name = molgrp.name
            self.pdb = self.name + '.pdb'
            self._read_group(molgrp)
            return

        self.name = mol
        self.pdb = mol + '.pdb'
        # everything is copied in memory while reading, so the file can be
        # closed as soon as the group has been read
        with h5py.File(f5name, 'r') as f5:
            self._read_group(f5[mol])

    def _read_group(self, molgrp):
        """Fill the graph, the scores and the clusters from a hdf5 group."""
        # creates the graph
        self.nx = nx.Graph()

        # get nodes and node features
        nodes = [tuple(n)
                 for n in molgrp['nodes'][()].astype('U').tolist()]
        node_feat = self._read_features(molgrp['node_data'])

        # add nodes
        for i_node, node in enumerate(nodes):
            self.nx.add_node(node)
            for key, values in node_feat.items():
                self.nx.nodes[node][key] = self._row(values, i_node)

        # add interface edges, then internal edges
        self._read_edges(molgrp, 'edges', 'edge_data')
        self._read_edges(molgrp, 'internal_edges', 'internal_edge_data')

        # add score
        self.score = {key: molgrp['score/' + key][()]
                      for key in list(molgrp['score'].keys())}

        # add cluster
        if 'clustering' in molgrp:
            self.clusters = {
                method: molgrp['clustering/' + method + '/depth_0'][()]
                for method in list(molgrp['clustering'].keys())}

    @staticmethod
    def _read_features(group):
        """Read all datasets of ``group``, 1D ones reshaped as one column."""
        features = {}
        for key in list(group.keys()):
            values = group[key][()]
            features[key] = (
                values.reshape(-1, 1) if values.ndim == 1 else values)
        return features

    @staticmethod
    def _row(values, index):
        """Return row ``index`` of ``values``, unwrapped if it has length 1."""
        row = values[index, :]
        return row[0] if len(row) == 1 else row

    def _read_edges(self, molgrp, edge_name, data_name):
        """Add the edges stored in ``edge_name`` with their features."""
        edges = [(tuple(e[0]), tuple(e[1]))
                 for e in molgrp[edge_name][()].astype('U').tolist()]
        edge_feat = self._read_features(molgrp[data_name])

        for i_edge, (node1, node2) in enumerate(edges):
            self.nx.add_edge(node1, node2)
            for key, values in edge_feat.items():
                self.nx.edges[node1, node2][key] = self._row(values, i_edge)

    def _cluster_nodes(self, method):
        """Return a dict mapping each node to its cluster index.

        The clusters stored in ``self.clusters[method]`` are used when
        they are available, otherwise they are computed with ``method``.
        """
        node_key = list(self.nx.nodes.keys())

        try:
            raw_cluster = self.clusters[method]
            return {node_key[inode]: index_cluster
                    for inode, index_cluster in enumerate(raw_cluster)}
        except (AttributeError, KeyError, TypeError, IndexError):
            print('No cluster in the group. '
                  'Computing cluster now with method : %s' % method)

        # remove interface edges for clustering
        gtmp = self.nx.copy()
        gtmp.remove_edges_from(
            [e for e in self.nx.edges
             if self._as_text(self.nx.edges[e]['type']) == 'interface'])

        if method == 'louvain':
            return community.best_partition(gtmp)

        if method == 'mcl':
            matrix = nx.to_scipy_sparse_matrix(gtmp)

            # run MCL with default parameters
            result = mc.run_mcl(matrix)

            # get clusters
            cluster = {}
            for ic, c in enumerate(mc.get_clusters(result)):
                for node in c:
                    cluster[node_key[node]] = ic
            return cluster

        raise ValueError(
            "Unknown clustering method %r, use 'mcl' or 'louvain'" % method)

    def _edge_traces(self, go, three_d):
        """Build one line trace per edge and count the edges of each node.

        Returns:
            tuple: (interface traces, internal traces, dict giving the
            number of edges touching each node). Edges that are neither
            'interface' nor 'internal' are counted but not drawn.
        """
        if three_d:
            scatter, line = go.Scatter3d, go.scatter3d.Line
            pos_key, axes = 'pos', 'xyz'
            style = {'internal': ('rgb(110,110,110)', 5),
                     'interface': ('rgb(210,210,210)', 2)}
        else:
            scatter, line = go.Scatter, go.scatter.Line
            pos_key, axes = 'pos2D', 'xy'
            style = {'internal': ('rgb(110,110,110)', 3),
                     'interface': ('rgb(210,210,210)', 1)}

        traces = {'internal': [], 'interface': []}
        node_connect = {}

        for edge in self.nx.edges:
            for node in (edge[0], edge[1]):
                node_connect[node] = node_connect.get(node, 0) + 1

            edge_type = self._as_text(
                self.nx.edges[edge[0], edge[1]]['type'])
            if edge_type not in style:
                continue

            color, width = style[edge_type]
            start = self.nx.nodes[edge[0]][pos_key]
            end = self.nx.nodes[edge[1]][pos_key]
            coords = {axis: [c0, c1, None]
                      for axis, c0, c1 in zip(axes, start, end)}

            traces[edge_type].append(
                scatter(text=[], mode='lines', hoverinfo=None,
                        showlegend=False,
                        line=line(color=color, width=width), **coords))

        return traces['interface'], traces['internal'], node_connect

    @staticmethod
    def _figure(go, data, title):
        """Assemble the traces in a figure with the layout of both plots."""
        hidden_axis = dict(
            showgrid=False, zeroline=False, showticklabels=False)
        return go.Figure(
            data=data,
            layout=go.Layout(
                title=title,
                titlefont=dict(size=16),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[dict(
                    text="", showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002)],
                xaxis=dict(hidden_axis),
                yaxis=dict(hidden_axis)))

    @staticmethod
    def _show(py, fig, out, iplot, disable_plot):
        """Display the figure unless plotting is disabled."""
        if disable_plot:
            return
        if iplot:
            py.iplot(fig, filename=out)
        else:
            py.plot(fig)

    def plotly_2d(self, out=None, offline=False, iplot=True,
                  disable_plot=False, method='louvain'):
        """Plots the interface graph in 2D

        The node positions are embedded in 2D and stored in the 'pos2D'
        node attribute. The marker outline of a node is colored by its
        cluster and its size grows with its number of edges.

        Args:
            out ([type], optional): output name. Defaults to None.
            offline (bool, optional): use plotly.offline instead of
                chart_studio.plotly. Defaults to False.
            iplot (bool, optional): call iplot instead of plot.
                Defaults to True.
            disable_plot (bool, optional): build the figure without
                displaying it. Defaults to False.
            method (str, optional): 'mcl' or 'louvain'.
                Defaults to 'louvain'.

        Raises:
            ValueError: if no cluster is stored for ``method`` and
                ``method`` is neither 'mcl' nor 'louvain'.
        """
        if offline:
            import plotly.offline as py
        else:
            import chart_studio.plotly as py

        import plotly.graph_objs as go

        pos = np.array(
            [v.tolist()
             for v in nx.get_node_attributes(self.nx, 'pos').values()])
        pos2D = manifold_embedding(pos)
        nx.set_node_attributes(
            self.nx, dict(zip(self.nx.nodes, pos2D)), 'pos2D')

        cluster = self._cluster_nodes(method)

        # colorscale of the cluster outline
        cmap = 'plasma'

        edge_trace_list, internal_edge_trace_list, node_connect = \
            self._edge_traces(go, three_d=False)

        # one set of marker data per chain
        chains = [{'x': [], 'y': [], 'text': [], 'size': [], 'cluster': []}
                  for _ in self._CHAIN_COLORS]

        for node in self.nx.nodes:
            attrs = self.nx.nodes[node]
            chain = chains[attrs['chain']]
            pos = attrs['pos2D']
            chain['x'].append(pos[0])
            chain['y'].append(pos[1])
            chain['text'].append(
                '[Clst:' + str(cluster[node]) + '] ' + ' '.join(node))
            # a node without any edge is absent from node_connect
            chain['size'].append(
                self._marker_size(node_connect.get(node, 0)))
            chain['cluster'].append(cluster[node])

        node_trace = [
            go.Scatter(x=chain['x'], y=chain['y'], text=chain['text'],
                       mode='markers', hoverinfo='text',
                       marker=dict(color=color, size=chain['size'],
                                   line=dict(color=chain['cluster'],
                                             width=4, colorscale=cmap)))
            for chain, color in zip(chains, self._CHAIN_COLORS)]

        fig = self._figure(
            go,
            [*internal_edge_trace_list, *edge_trace_list, *node_trace],
            '<br>tSNE connection graph for %s' % self.pdb)

        self._show(py, fig, out, iplot, disable_plot)

    def plotly_3d(self, out=None, offline=False, iplot=True,
                  disable_plot=False):
        """Plots interface graph in 3D

        The nodes are placed at their 'pos' attribute and the size of
        their marker grows with their number of edges.

        Args:
            out ([type], optional): output name. Defaults to None.
            offline (bool, optional): use plotly.offline instead of
                chart_studio.plotly. Defaults to False.
            iplot (bool, optional): call iplot instead of plot.
                Defaults to True.
            disable_plot (bool, optional): build the figure without
                displaying it. Defaults to False.
        """
        if offline:
            import plotly.offline as py
        else:
            import chart_studio.plotly as py

        import plotly.graph_objs as go

        edge_trace_list, internal_edge_trace_list, node_connect = \
            self._edge_traces(go, three_d=True)

        # one set of marker data per chain
        chains = [{'x': [], 'y': [], 'z': [], 'text': [], 'size': []}
                  for _ in self._CHAIN_COLORS]

        for node in self.nx.nodes:
            attrs = self.nx.nodes[node]
            chain = chains[attrs['chain']]
            pos = attrs['pos']
            chain['x'].append(pos[0])
            chain['y'].append(pos[1])
            chain['z'].append(pos[2])
            chain['text'].append(' '.join(node))
            # a node without any edge is absent from node_connect
            chain['size'].append(
                self._marker_size(node_connect.get(node, 0)))

        node_trace = [
            go.Scatter3d(x=chain['x'], y=chain['y'], z=chain['z'],
                         text=chain['text'],
                         mode='markers', hoverinfo='text',
                         marker=dict(color=color, size=chain['size'],
                                     symbol='circle',
                                     line=dict(color='rgb(50,50,50)',
                                               width=2)))
            for chain, color in zip(chains, self._CHAIN_COLORS)]

        fig = self._figure(
            go,
            [*node_trace, *internal_edge_trace_list, *edge_trace_list],
            '<br>Connection graph for %s' % self.pdb)

        self._show(py, fig, out, iplot, disable_plot)