Source code for deeprank_gnn.ResidueGraph

import os
import shutil
import warnings
from time import time

import networkx as nx
import numpy as np
import torch
from pdb2sql import interface

from .tools import BioWrappers, PSSM, BSA
from .Graph import Graph

class ResidueGraph(Graph):

    def __init__(self, pdb=None, pssm=None, contact_distance=8.5,
                 internal_contact_distance=3, pssm_align='res',
                 biopython=False):
        """Build a residue-level interface graph and compute its features.

        Nodes are residues taken from the contact pairs returned by
        pdb2sql's ``interface.get_contact_residues``. Edges are either
        'interface' contacts (from those pairs) or 'internal' contacts
        (residue pairs of the same chain closer than
        ``internal_contact_distance``). The graph, node features and edge
        indices are computed immediately in the constructor.

        Args:
            pdb (str, optional): path to pdb file. Defaults to None.
            pssm (str, optional): path to pssm file. Defaults to None.
            contact_distance (float, optional): cutoff distance for external
                (interface) edges. Defaults to 8.5.
            internal_contact_distance (int, optional): cutoff distance for
                internal edges. Defaults to 3.
            pssm_align (str, optional): ``style`` argument forwarded to
                ``PSSM.PSSM_aligned``. Defaults to 'res'.
            biopython (bool, optional): if True, also compute residue depth
                and half-sphere exposure via ``BioWrappers``.
                Defaults to False.
        """
        super().__init__()

        self.type = 'residue'
        # Keep the full path: it is handed to pdb2sql and to the BSA
        # calculator, both of which need to open the file.
        self.pdb = pdb
        # Bare identifier (no directory, no extension).
        self.name = os.path.splitext(os.path.basename(pdb))[0]

        if pssm is not None:
            self.pssm, self.ic = PSSM.PSSM_aligned(pssm, style=pssm_align)
        else:
            self.pssm, self.ic = None, None

        self.contact_distance = contact_distance
        self.internal_contact_distance = internal_contact_distance

        self.residue_charge = {
            'CYS': -0.64, 'HIS': -0.29, 'ASN': -1.22, 'GLN': -1.22,
            'SER': -0.80, 'THR': -0.80, 'TYR': -0.80, 'TRP': -0.79,
            'ALA': -0.37, 'PHE': -0.37, 'GLY': -0.37, 'ILE': -0.37,
            'VAL': -0.37, 'MET': -0.37, 'PRO': 0.0, 'LEU': -0.37,
            'GLU': -1.37, 'ASP': -1.37, 'LYS': -0.36, 'ARG': -1.65,
        }

        # residue name -> index used for the one-hot 'type' node feature
        self.residue_names = {
            'CYS': 0, 'HIS': 1, 'ASN': 2, 'GLN': 3, 'SER': 4, 'THR': 5,
            'TYR': 6, 'TRP': 7, 'ALA': 8, 'PHE': 9, 'GLY': 10, 'ILE': 11,
            'VAL': 12, 'MET': 13, 'PRO': 14, 'LEU': 15, 'GLU': 16,
            'ASP': 17, 'LYS': 18, 'ARG': 19,
        }

        # NOTE(review): LYS is tagged 'neg_charged' although lysine is
        # positively charged. Left unchanged because it alters the input
        # features seen by already-trained models -- confirm before fixing.
        self.residue_polarity = {
            'CYS': 'polar', 'HIS': 'polar', 'ASN': 'polar', 'GLN': 'polar',
            'SER': 'polar', 'THR': 'polar', 'TYR': 'polar', 'TRP': 'polar',
            'ALA': 'apolar', 'PHE': 'apolar', 'GLY': 'apolar',
            'ILE': 'apolar', 'VAL': 'apolar', 'MET': 'apolar',
            'PRO': 'apolar', 'LEU': 'apolar',
            'GLU': 'neg_charged', 'ASP': 'neg_charged',
            'LYS': 'neg_charged', 'ARG': 'pos_charged',
        }

        # residue name -> column of that residue in a PSSM row
        self.pssm_pos = {
            'CYS': 4, 'HIS': 8, 'ASN': 2, 'GLN': 5, 'SER': 15, 'THR': 16,
            'TYR': 18, 'TRP': 17, 'ALA': 0, 'PHE': 13, 'GLY': 7, 'ILE': 9,
            'VAL': 19, 'MET': 12, 'PRO': 14, 'LEU': 10, 'GLU': 6, 'ASP': 3,
            'LYS': 11, 'ARG': 1,
        }

        self.polarity_encoding = {
            'apolar': 0, 'polar': 1, 'neg_charged': 2, 'pos_charged': 3}

        # Unique integer per unordered pair of polarity indices; used by
        # _get_edge_polarity (was commented out, which made that method
        # always raise AttributeError).
        self.edge_polarity_encoding, iencod = {}, 0
        for v1 in self.polarity_encoding.values():
            for v2 in self.polarity_encoding.values():
                key = tuple(sorted((v1, v2)))
                if key not in self.edge_polarity_encoding:
                    self.edge_polarity_encoding[key] = iencod
                    iencod += 1

        self.biopython = biopython

        # check if external execs are installed
        self.check_execs()

        # create the sqldb
        db = interface(self.pdb)
        try:
            # get the graphs
            t0 = time()
            self.get_graph(db)
            print('Graph %f' % (time() - t0))

            # get the nodes/edge attributes
            t0 = time()
            self.get_node_features(db)
            print('Node %f' % (time() - t0))

            t0 = time()
            self.get_edge_features()
            print('Edge %f' % (time() - t0))
        finally:
            # always release the database, even if a step above fails
            db._close()

    def check_execs(self):
        """Check whether the required external executables are on PATH.

        Each missing executable is reported on stdout; no exception is
        raised, so construction continues even if one is absent.
        """
        execs = {'msms': ''}
        for e, inst in execs.items():
            if shutil.which(e) is None:
                print(e, ' is not installed see ', inst, ' for details')

    def get_graph(self, db):
        """Build ``self.nx`` with interface and internal edges.

        Interface edges come from ``db.get_contact_residues`` with
        ``self.contact_distance``; internal edges come from
        :meth:`get_internal_edges`. Both edge kinds carry ``dist`` (minimum
        atom-atom distance) and ``type`` (``b'interface'`` / ``b'internal'``)
        attributes. Residues rejected by :meth:`_get_all_valid_nodes` are
        skipped with a printed warning.

        Args:
            db: pdb2sql ``interface`` database.
        """
        self.nx = nx.Graph()
        self.nx.edge_index = []

        self.res_contact_pairs = db.get_contact_residues(
            cutoff=self.contact_distance, return_contact_pairs=True)

        # get all the nodes (this also drops invalid keys from the pairs)
        all_nodes = self._get_all_valid_nodes(
            self.res_contact_pairs, self.pssm)
        # set for O(1) membership tests inside the loops below
        valid_nodes = set(all_nodes)

        # create the interface edges
        for key, val in self.res_contact_pairs.items():
            if key not in valid_nodes:
                print('WARNING: {} is not a valid node'.format(key))
                continue
            for v in val:
                if v in valid_nodes:
                    d = self._get_edge_distance(key, v, db)
                    self.nx.add_edge(key, v, dist=d, type=bytes(
                        'interface', encoding='utf-8'))
                else:
                    print('WARNING: {} is not a valid node'.format(v))

        # get the internal edges
        edges, dist = self.get_internal_edges(db)
        for e, d in zip(edges, dist):
            e0, e1 = tuple(e[0]), tuple(e[1])
            if e0 in valid_nodes and e1 in valid_nodes:
                self.nx.add_edge(e0, e1, dist=d, type=bytes(
                    'internal', encoding='utf-8'))
            else:
                print('WARNING: {} {} is not a valid edge'.format(e0, e1))

    @staticmethod
    def _get_all_valid_nodes(res_contact_pairs, pssm, verbose=False):
        """Filter out invalid nodes.

        A residue is rejected when its name has no entry in the feature
        tables, or when a PSSM is given and the residue is absent from it.
        Rejected keys are removed from ``res_contact_pairs`` in place.

        Args:
            res_contact_pairs (dict): residue key -> list of contacting
                residue keys; keys are ``(chainID, resSeq, resName)`` tuples.
            pssm: aligned PSSM mapping (or None to skip the PSSM check).
            verbose (bool, optional): print rejected contact residues.
                Defaults to False.

        Returns:
            list: the remaining dict keys, followed by the sorted, unique
            valid residues found in the contact lists.
        """
        # Only residues for which get_node_features has type/charge/polarity
        # entries; ASX/SEC/GLX used to be accepted here and then crashed
        # get_node_features with a KeyError.
        valid_res = {
            'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS',
            'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP',
            'TYR', 'VAL'}

        keys_to_pop = []

        # tag the non residues
        for res in res_contact_pairs:
            if res[2] not in valid_res:
                keys_to_pop.append(res)
                warnings.warn('--> Residue {} not valid'.format(res))

        # tag the ones that are not in PSSM
        if pssm is not None:
            for res in res_contact_pairs:
                if res not in pssm:
                    keys_to_pop.append(res)
                    warnings.warn(
                        '--> Residue {} not found in PSSM file'.format(res))

        # Remove the residues (a key may have been tagged twice)
        for res in keys_to_pop:
            res_contact_pairs.pop(res, None)

        # collect contacting residues, dropping improper residues and
        # residues missing from the PSSM
        nodesB = set()
        for reslist in res_contact_pairs.values():
            for res in reslist:
                if res[2] not in valid_res:
                    continue
                if pssm is None or res in pssm:
                    nodesB.add(res)
                elif verbose:
                    print('removed node', res)

        # make a list of nodes
        return list(res_contact_pairs.keys()) + sorted(nodesB)

    def get_node_features(self, db):
        """Assign features to each node of ``self.nx``.

        Always sets 'chain', 'pos' (mean atom coordinate), 'type' (one-hot),
        'charge', 'polarity' (one-hot) and 'bsa'. Adds 'pssm', 'cons' and
        'ic' when a PSSM was given, and 'depth' and 'hse' when
        ``self.biopython`` is true.

        Args:
            db: pdb2sql ``interface`` database.
        """
        # get the BSA of the residues
        bsa_calc = BSA.BSA(self.pdb, db)
        bsa_calc.get_structure()
        bsa_calc.get_contact_residue_sasa(cutoff=self.contact_distance)
        bsa_data = bsa_calc.bsa_data

        # biopython data (the model is only needed for these features)
        if self.biopython:
            model = BioWrappers.get_bio_model(db.pdbfile)
            ResDepth = BioWrappers.get_depth_contact_res(
                model, self.nx.nodes)
            HSE = BioWrappers.get_hse(model)

        # loop over all the nodes
        for node_key in self.nx.nodes:
            chainID = node_key[0]
            resName = node_key[2]
            resSeq = int(node_key[1])
            node = self.nx.nodes[node_key]

            node['chain'] = {'A': 0, 'B': 1}[chainID]
            node['pos'] = np.mean(
                db.get('x,y,z', chainID=chainID, resSeq=resSeq), 0)
            node['type'] = self.onehot(
                self.residue_names[resName], len(self.residue_names))
            node['charge'] = self.residue_charge[resName]
            node['polarity'] = self.onehot(
                self.polarity_encoding[self.residue_polarity[resName]],
                len(self.polarity_encoding))
            node['bsa'] = bsa_data[node_key]

            if self.pssm is not None:
                data = PSSM.get_pssm_data(node_key, self.pssm)
                node['pssm'] = data
                # conservation = PSSM score of the residue's own amino acid
                node['cons'] = data[self.pssm_pos[resName]]
                node['ic'] = PSSM.get_ic_data(node_key, self.ic)

            if self.biopython:
                node['depth'] = ResDepth.get(node_key, 0) \
                    if node_key in ResDepth else 0
                bio_key = (chainID, resSeq)
                node['hse'] = HSE[bio_key] if bio_key in HSE else (0, 0, 0)

    def get_edge_features(self):
        """Append ``[i, j]`` node-index pairs for every edge to
        ``self.nx.edge_index``, where indices follow ``self.nx.nodes`` order.
        """
        # dict lookup instead of list.index per edge (was O(nodes * edges))
        node_index = {node: i for i, node in enumerate(self.nx.nodes)}
        for node1, node2 in self.nx.edges:
            self.nx.edge_index.append([node_index[node1], node_index[node2]])

    def get_internal_edges(self, db):
        """Get internal edges and their distances for chains A and B.

        Returns:
            tuple: (edges, distances) concatenated over chain A then chain B.
        """
        nodesA = [n for n in self.nx.nodes if n[0] == 'A']
        nodesB = [n for n in self.nx.nodes if n[0] == 'B']
        edgesA, distA = self._get_internal_edges_chain(
            nodesA, db, self.internal_contact_distance)
        edgesB, distB = self._get_internal_edges_chain(
            nodesB, db, self.internal_contact_distance)
        return edgesA + edgesB, distA + distB

    @staticmethod
    def _squared_distances(xyz1, xyz2):
        """Return the matrix of squared distances between two sets of points.

        Args:
            xyz1 (np.ndarray): (n1, 3) coordinates.
            xyz2 (np.ndarray): (n2, 3) coordinates.

        Returns:
            np.ndarray: (n1, n2) squared distances.
        """
        d2 = (-2 * np.dot(xyz1, xyz2.T)
              + np.sum(xyz1**2, axis=1)[:, None]
              + np.sum(xyz2**2, axis=1))
        # the expansion can dip slightly below 0 through round-off,
        # which would make np.sqrt return NaN
        return np.maximum(d2, 0.0)

    def _get_internal_edges_chain(self, nodes, db, cutoff):
        """Get internal edges and the minimum atom-atom distance per pair.

        Args:
            nodes (list): residue keys of a single chain
            db: pdb2sql database
            cutoff (float): internal edge cutoff

        Returns:
            tuple: (edges, distances) for every residue pair having at least
            one atom pair closer than ``cutoff``.
        """
        nodes = list(nodes)
        # fetch each residue's coordinates once instead of once per pair
        coords = [np.array(db.get('x,y,z', chainID=n[0], resSeq=n[1]))
                  for n in nodes]
        cutoff2 = cutoff**2

        edges, dist = [], []
        for i1, (node1, xyz1) in enumerate(zip(nodes, coords)):
            for node2, xyz2 in zip(nodes[i1 + 1:], coords[i1 + 1:]):
                d2 = self._squared_distances(xyz1, xyz2)
                if np.any(d2 < cutoff2):
                    edges.append((node1, node2))
                    dist.append(np.sqrt(np.min(d2)))
        return edges, dist

    def _get_internal_edges_chain_mean(self, nodes, db, cutoff):
        """Get internal edges based on the distance between residue centers.

        Args:
            nodes (list): residue keys of a single chain
            db: pdb2sql database
            cutoff (float): internal edge cutoff

        Returns:
            tuple: (edges, distances) for each unordered residue pair whose
            mean-coordinate centers are closer than ``cutoff``.
        """
        nodes = list(nodes)
        if not nodes:
            return [], []
        centers = np.array([
            np.mean(db.get('x,y,z', chainID=n[0], resSeq=n[1]), 0)
            for n in nodes])
        d2 = self._squared_distances(centers, centers)

        edges, dist = [], []
        # compare squared distances to cutoff**2 (was 2*cutoff), report the
        # actual distance, and keep each unordered pair once
        for i, j in np.argwhere(d2 < cutoff**2):
            if i < j:
                edges.append((nodes[i], nodes[j]))
                dist.append(np.sqrt(d2[i, j]))
        return edges, dist

    def _get_edge_polarity(self, node1, node2):
        """Encode the polarity pair of two residues as a single integer.

        Args:
            node1 (tuple): (chainID, resSeq, resName)
            node2 (tuple): (chainID, resSeq, resName)

        Returns:
            int: value from ``self.edge_polarity_encoding``; symmetric in
            its arguments.
        """
        v1 = self.polarity_encoding[self.residue_polarity[node1[2]]]
        v2 = self.polarity_encoding[self.residue_polarity[node2[2]]]
        return self.edge_polarity_encoding[tuple(sorted((v1, v2)))]

    def _get_edge_distance(self, node1, node2, db):
        """Get the minimum atom-atom distance between two residues.

        Args:
            node1 (tuple): (chainID, resSeq, ...)
            node2 (tuple): (chainID, resSeq, ...)
            db: pdb2sql database

        Returns:
            float: minimum distance between the residues' atoms
        """
        xyz1 = np.array(db.get('x,y,z', chainID=node1[0], resSeq=node1[1]))
        xyz2 = np.array(db.get('x,y,z', chainID=node2[0], resSeq=node2[1]))
        return np.sqrt(np.min(self._squared_distances(xyz1, xyz2)))

    def onehot(self, idx, size):
        """Return a float32 one-hot vector of length ``size`` with 1 at
        ``idx``."""
        onehot = np.zeros(size, dtype=np.float32)
        onehot[idx] = 1
        return onehot