Source code for WatCon.find_conserved_networks

'''
Cluster water coordinates and analyze the conservation of water networks relative to the resulting clusters
'''

import os
import numpy as np
import networkx as nx
from sklearn.cluster import OPTICS, DBSCAN, HDBSCAN
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd
import pickle


def combine_graphs(list_of_graphs):
    """
    Merge several NetworkX graphs into one graph.

    Parameters
    ----------
    list_of_graphs : list of networkx.Graph
        Graphs to be merged.

    Returns
    -------
    networkx.Graph
        The result of ``networkx.disjoint_union_all`` applied to the inputs,
        i.e. a single graph holding the nodes and edges of every input graph.
    """
    return nx.disjoint_union_all(list_of_graphs)
def collect_coordinates(pkl_list):
    """
    Collect coordinates into one array from .pkl files

    Every pickle is expected to hold an iterable of dict-like entries, each
    with a ``'coordinates'`` array. Entries whose array is not of shape
    (n, 3) are ignored. All remaining arrays are stacked row-wise, so frames
    with different numbers of atoms can be combined.

    Parameters
    ----------
    pkl_list : list
        List of pkl_files (full paths)

    Returns
    -------
    np.ndarray
        Array of combined coordinates with shape (N, 3); shape (0, 3) when
        nothing usable was found.
    """
    collected = []
    for file in pkl_list:
        # NOTE: pickle.load can execute arbitrary code -- only use on trusted files.
        with open(file, 'rb') as FILE:
            entries = pickle.load(FILE)

        try:
            frames = [np.asarray(entry['coordinates'], dtype=float) for entry in entries]
        except (KeyError, IndexError, TypeError, ValueError):
            # Best effort: report the unusable file and carry on with the others.
            print('Could not find ["coordinates"], check your inputs!')
            continue

        collected.extend(frame for frame in frames
                         if frame.ndim == 2 and frame.shape[1] == 3)

    if not collected:
        return np.empty((0, 3))

    # Row-wise concatenation works even when frames differ in length, whereas
    # np.array() on the nested lists fails for such ragged input.
    return np.concatenate(collected, axis=0)
def get_coordinates_from_topology(pdb_file, atom_selection='all'):
    """
    Read atom positions from an MDAnalysis-readable topology file.

    Parameters
    ----------
    pdb_file : str
        Full path to MDAnalysis-readable topology file
    atom_selection : str
        Selection string handed to ``Universe.select_atoms``. Default 'all'.

    Returns
    -------
    np.ndarray
        Positions of the selected atoms, reshaped to (-1, 3)
    """
    import MDAnalysis as mda

    universe = mda.Universe(pdb_file)
    selected_atoms = universe.select_atoms(atom_selection)
    return np.array(selected_atoms.positions).reshape(-1, 3)
def get_coordinates_from_pdb(pdb_file):
    """
    Returns coordinates for all "ATOM" or "HETATM" lines -- useful for getting
    coordinates from cluster PDBS.

    Only ATOM/HETATM records contribute a row. Other lines (REMARK, TER, END,
    ...) are skipped instead of leaving an all-zero row in the output.

    Parameters
    ----------
    pdb_file : str
        Full path to PDB file.

    Returns
    -------
    np.ndarray
        Array of coordinates with shape (n_records, 3); shape (0, 3) when the
        file holds no ATOM/HETATM record.
    """
    with open(pdb_file, 'r') as FILE:
        atom_lines = [line for line in FILE if line.startswith(("ATOM", "HETATM"))]

    centers = np.zeros((len(atom_lines), 3))
    for i, line in enumerate(atom_lines):
        # Fixed-width slices: x, y and z are read from characters 30-38,
        # 38-46 and 46-54 of the record.
        centers[i, :] = float(line[30:38]), float(line[38:46]), float(line[46:54])
    return centers
def cluster_nodes(combined_graph, cluster='hdbscan', min_samples=10):
    """
    Cluster node positions from a combined NetworkX graph.

    Parameters
    ----------
    combined_graph : networkx.Graph
        A combined graph used for clustering. Positions are read from the
        'pos' node attribute.
    cluster : str
        Clustering method, can be 'optics', 'dbscan', or 'hdbscan'.
    min_samples : int
        Minimum number of samples required for a cluster (passed as
        ``min_samples`` to OPTICS/DBSCAN and as ``min_cluster_size`` to
        HDBSCAN).

    Returns
    -------
    tuple
        - Cluster labels (array-like)
        - Cluster centers (dict mapping each non-noise label to the mean
          position of its members)

    Raises
    ------
    ValueError
        If `cluster` is not one of the supported methods.
    """
    # Validate first: an unknown method used to surface only later as an
    # UnboundLocalError on `clustering`.
    if cluster not in ('optics', 'dbscan', 'hdbscan'):
        raise ValueError(
            f"Unknown clustering method {cluster!r}; choose 'optics', 'dbscan', or 'hdbscan'."
        )

    node_positions = nx.get_node_attributes(combined_graph, 'pos')
    positions = np.asarray(list(node_positions.values()), dtype=float)

    if cluster == 'optics':
        print('Using OPTICS clustering')
        clustering = OPTICS(max_eps=1.0, metric='euclidean', min_samples=min_samples).fit(positions)
    elif cluster == 'dbscan':
        print('Using DBSCAN clustering')
        clustering = DBSCAN(eps=0.1, min_samples=min_samples).fit(positions)
    else:
        print('Using HDBSCAN Clustering')
        clustering = HDBSCAN(min_cluster_size=min_samples, cluster_selection_epsilon=0.0,
                             algorithm='kd_tree').fit(positions)

    cluster_labels = clustering.labels_

    # Label -1 marks noise points and therefore gets no center.
    cluster_centers = {}
    for label in np.unique(cluster_labels):
        if label != -1:
            cluster_centers[label] = positions[cluster_labels == label].mean(axis=0)

    print(len(cluster_centers))
    return (cluster_labels, cluster_centers)
def _coerce_to_xyz(coordinate_list):
    """Return `coordinate_list` as a float array of shape (N, 3), or raise ValueError."""
    try:
        return np.asarray(coordinate_list, dtype=float).reshape(-1, 3)
    except (ValueError, TypeError):
        pass

    # Fallback for ragged input such as a list of per-frame (n_i, 3) arrays.
    try:
        return np.concatenate(
            [np.asarray(chunk, dtype=float).reshape(-1, 3) for chunk in coordinate_list],
            axis=0,
        )
    except (ValueError, TypeError) as err:
        raise ValueError("Couldn't reshape coordinates correctly, check your inputs.") from err


def cluster_coordinates_only(coordinate_list, cluster='hdbscan', min_samples=10, eps=0.0, n_jobs=1):
    """
    Cluster a set of coordinates.

    Parameters
    ----------
    coordinate_list : list of array-like
        A combined list of all coordinates to be clustered. Anything that can
        be reshaped to (N, 3) is accepted, including a list of per-frame
        (n_i, 3) arrays of different lengths.
    cluster : str
        Clustering method, can be 'optics', 'dbscan', or 'hdbscan'.
    min_samples : int
        Minimum number of samples required for a cluster (passed as
        ``min_samples`` to OPTICS/DBSCAN and as ``min_cluster_size`` to
        HDBSCAN).
    eps : float
        Passed as ``eps`` to OPTICS/DBSCAN and as
        ``cluster_selection_epsilon`` to HDBSCAN.
        NOTE(review): the default of 0.0 looks unusable for 'dbscan' --
        confirm against the scikit-learn parameter constraints.
    n_jobs : int
        Number of parallel jobs passed to the clustering estimator.

    Returns
    -------
    tuple
        - Cluster labels (array-like)
        - Cluster centers (dict mapping each non-noise label to the mean
          coordinate of its members)

    Raises
    ------
    ValueError
        If `cluster` is not a supported method or the coordinates cannot be
        arranged as an (N, 3) array.
    """
    # Validate first: an unknown method used to surface only later as an
    # UnboundLocalError on `clustering`.
    if cluster not in ('optics', 'dbscan', 'hdbscan'):
        raise ValueError(
            f"Unknown clustering method {cluster!r}; choose 'optics', 'dbscan', or 'hdbscan'."
        )

    # Fail loudly here instead of printing a warning and crashing further
    # down with a less helpful error.
    coordinate_list = _coerce_to_xyz(coordinate_list)

    if cluster == 'optics':
        print('Using OPTICS clustering')
        clustering = OPTICS(min_samples=min_samples, eps=eps, n_jobs=n_jobs).fit(coordinate_list)
    elif cluster == 'dbscan':
        print('Using DBSCAN clustering')
        clustering = DBSCAN(min_samples=min_samples, eps=eps, n_jobs=n_jobs).fit(coordinate_list)
    else:
        print('Using HDBSCAN clustering')
        print(min_samples, eps, coordinate_list.shape)
        clustering = HDBSCAN(min_cluster_size=min_samples, cluster_selection_epsilon=eps,
                             algorithm='kd_tree', n_jobs=n_jobs).fit(coordinate_list)

    cluster_labels = clustering.labels_

    # Label -1 marks noise points and therefore gets no center.
    cluster_centers = {}
    for label in np.unique(cluster_labels):
        if label != -1:
            cluster_centers[label] = coordinate_list[cluster_labels == label].mean(axis=0)

    print(len(cluster_centers))
    return (cluster_labels, cluster_centers)
def find_commonality(networks, centers, names, dist_cutoff=1.5, local_dens_radius=6):
    """
    Find the commonality of a list of networks relative to a summary network
    created from clustering.

    A water counts as conserved when its oxygen lies closer than
    `dist_cutoff` to at least one center. Each conserved water contributes
    ``1 / (n_local + 1)``, where ``n_local`` is the number of waters of the
    same network whose oxygen is farther than 2 but closer than
    `local_dens_radius` from it. The score of a network is the sum of these
    contributions divided by the number of centers.

    Parameters
    ----------
    networks : list of WaterNetwork
        List of WaterNetwork objects to be analyzed. Only
        ``network.water_molecules`` and each water's ``O.coordinates`` are
        read.
    centers : array-like
        Locations of clustered centers, shape (n_centers, 3).
    names : list[str]
        List of IDs, one per network (``names[i]`` labels ``networks[i]``).
    dist_cutoff : float
        Distance cutoff for classifying conserved water
    local_dens_radius : float
        Radius cutoff for calculating local water density

    Returns
    -------
    dict
        A dictionary containing calculated commonalities for each network.
        Every score is 0.0 when `centers` is empty.
    """
    center_xyz = np.asarray(centers, dtype=float).reshape(-1, 3)
    n_centers = len(center_xyz)
    # Waters nearer than this (including the water itself) are not counted in
    # the local density.
    min_neighbor_dist = 2

    commonality_dict = {}
    for i, network in enumerate(networks):
        waters = np.array(
            [[wat.O.coordinates[0], wat.O.coordinates[1], wat.O.coordinates[2]]
             for wat in network.water_molecules],
            dtype=float,
        ).reshape(-1, 3)

        conserved = 0.0
        if n_centers:
            for water in waters:
                near_center = np.linalg.norm(center_xyz - water, axis=1) < dist_cutoff
                if not near_center.any():
                    continue
                separations = np.linalg.norm(waters - water, axis=1)
                in_shell = (separations > min_neighbor_dist) & (separations < local_dens_radius)
                conserved += 1.0 / (int(np.count_nonzero(in_shell)) + 1)

        # Without centers there is nothing to be conserved against; report 0.0
        # rather than dividing by zero.
        commonality_dict[names[i]] = float(conserved / n_centers) if n_centers else 0.0

    return commonality_dict
def identify_conserved_water_clusters(networks, centers, dist_cutoff=1.0, filename_base='CLUSTERS'):
    """
    Count how many waters fall on each cluster center and write the clusters
    out through ``project_clusters``.

    Parameters
    ----------
    networks : list of WaterNetwork
        List of WaterNetwork objects to analyze.
    centers : array-like
        List of XYZ coordinates representing cluster centers.
    dist_cutoff : float
        Distance cutoff (inclusive) to classify a water molecule as part of a
        cluster.
    filename_base : str
        Base filename for saving projected clusters.

    Returns
    -------
    dict
        Maps the index of each center (as a string) to the number of water
        oxygens, summed over all networks, found within `dist_cutoff` of it.
    """
    from sklearn.preprocessing import MinMaxScaler
    from WatCon.visualize_structures import project_clusters

    # TODO: add protein atoms eventually

    def _separation(ax, ay, az, bx, by, bz):
        return np.sqrt((ax-bx)**2 + (ay-by)**2 + (az-bz)**2)

    center_dict = {}
    for network in networks:
        for j, center in enumerate(centers):
            key = str(j)
            center_dict.setdefault(key, 0)
            for wat in network.water_molecules:
                wx = wat.O.coordinates[0]
                wy = wat.O.coordinates[1]
                wz = wat.O.coordinates[2]
                cx, cy, cz = center
                if _separation(wx, wy, wz, cx, cy, cz) <= dist_cutoff:
                    center_dict[key] += 1

    # Min-max scale the counts; the scaled values are handed to
    # project_clusters as b_factors.
    counts = np.array(list(center_dict.values())).reshape(-1, 1)
    scaler = MinMaxScaler()
    scaler.fit(counts)
    b_factors = scaler.transform(counts).flatten()

    project_clusters(centers, filename_base=filename_base, b_factors=b_factors)
    return center_dict
def create_clustered_network(clusters, max_connection_distance, create_graph=True):
    """
    Build a WaterNetwork whose waters are the cluster centers.

    Parameters
    ----------
    clusters : array-like
        List of XYZ coordinates of cluster centers.
    max_connection_distance : float
        Maximum allowed distance between two clusters to form an interaction
        (passed as ``dist_cutoff`` to ``WaterNetwork.find_connections``).
    create_graph : bool, optional
        Whether to create a NetworkX graph from the clustered WaterNetwork.
        Default is True.

    Returns
    -------
    WaterNetwork
        A WaterNetwork object representing the clustered water network. When
        `create_graph` is True its ``graph`` attribute holds the NetworkX
        graph.
    """
    from WatCon.generate_static_networks import WaterNetwork, WaterAtom, WaterMolecule

    clustered_network = WaterNetwork()

    # One oxygen-only water per cluster center; the position in `clusters`
    # is used for the atom index, the molecule index and the residue number.
    for idx, xyz in enumerate(clusters):
        oxygen = WaterAtom(idx, 'O', idx, *xyz)
        clustered_network.water_molecules.append(
            WaterMolecule(idx, oxygen, H1=None, H2=None, residue_number=idx)
        )

    clustered_network.connections = clustered_network.find_connections(
        dist_cutoff=max_connection_distance, water_only=True
    )

    if not create_graph:
        return clustered_network

    graph = nx.Graph()
    # Nodes sit on every oxygen.
    for molecule in clustered_network.water_molecules:
        graph.add_node(molecule.O.index, pos=molecule.O.coordinates, atom_category='WAT', MSA=None)
    # Elements 0 and 1 of a connection are the edge endpoints; elements 3 and
    # 4 are stored as the connection_type and active_region edge attributes.
    for link in clustered_network.connections:
        graph.add_edge(link[0], link[1], connection_type=link[3], active_region=link[4])

    clustered_network.graph = graph
    return clustered_network
def identify_conserved_water_interactions_clustering(networks, clusters, max_connection_distance=2.0, dist_cutoff=1.0, filename_base='CLUSTER', out_dir='cluster_pdbs'):
    """
    Rank water-water interactions in relation to clustering.

    A clustered network is built from `clusters`. For every connection in it,
    the number of input networks that have a water within `dist_cutoff` of
    BOTH connected cluster centers is counted. A PyMOL script
    (``{out_dir}/{filename_base}.pml``) is written that draws the
    interactions, colored from the start to the end of the 'bwr' colormap in
    order of increasing count.

    Parameters
    ----------
    networks : list of WaterNetwork
        List of WaterNetwork objects to be analyzed.
    clusters : array-like
        List of XYZ coordinates of cluster centers.
    max_connection_distance : float
        Maximum allowed distance between two clusters to form an interaction.
    dist_cutoff : float
        Distance cutoff (inclusive) to classify a water molecule as part of a
        cluster.
    filename_base : str
        Base filename for the written .pml file.
    out_dir : str
        Directory for the .pml file; created if it does not exist.

    Returns
    -------
    dict
        Maps ``"i, j"`` (the two endpoints of a cluster connection) to the
        number of networks in which both clusters are occupied.
    """
    from sklearn.preprocessing import MinMaxScaler
    import matplotlib.pyplot as plt
    import os

    os.makedirs(out_dir, exist_ok=True)

    clustered_network = create_clustered_network(clusters, max_connection_distance, create_graph=True)
    cluster_xyz = np.asarray(clusters, dtype=float).reshape(-1, 3)

    interaction_dict = {}
    for network in networks:
        waters = np.array(
            [[wat.O.coordinates[0], wat.O.coordinates[1], wat.O.coordinates[2]]
             for wat in network.water_molecules],
            dtype=float,
        ).reshape(-1, 3)

        # Work out once per network which clusters hold a water, instead of
        # rescanning every water for every connection.
        occupied = np.array(
            [bool((np.linalg.norm(waters - center, axis=1) <= dist_cutoff).any())
             for center in cluster_xyz],
            dtype=bool,
        )

        for connection in clustered_network.connections:
            name = f"{connection[0]}, {connection[1]}"
            interaction_dict.setdefault(name, 0)
            if occupied[connection[0]] and occupied[connection[1]]:
                interaction_dict[name] += 1

    pairs = list(interaction_dict.keys())
    num_interactions = len(pairs)

    if num_interactions:
        values = np.array(list(interaction_dict.values())).reshape(-1, 1)
        scaler = MinMaxScaler()
        scaler.fit(values)
        b_factors = scaler.transform(values).flatten()
    else:
        # MinMaxScaler refuses empty input; nothing to rank in that case.
        b_factors = np.array([])

    # Sort pairs by increasing conservation (stable, so ties keep their order).
    order = np.argsort(b_factors, kind='stable')
    pairs = [pairs[k] for k in order]

    # The gradient is generated in ascending order and assigned directly to
    # the sorted pairs. It must NOT be re-sorted with the unsorted b_factors:
    # doing so tied each color to the insertion order of the pair rather than
    # to its conservation.
    colors = []
    if num_interactions:
        cmap = plt.get_cmap('bwr')
        colors = [(float(r), float(g), float(b))
                  for r, g, b, _ in cmap(np.linspace(0, 1, num_interactions))]

    with open(f'{out_dir}/{filename_base}.pml', 'w') as f:
        for i, (pair, color) in enumerate(zip(pairs, colors)):
            first, second = pair.split(',')[0], pair.split(',')[1]
            f.write(f"distance interaction{i}, resid {first}, resid {second}\n")
            f.write(f"set dash_color, [{color[0]},{color[1]},{color[2]}], interaction{i}\n")
            f.write(f"show spheres, resid {first}\n")
            f.write(f"show spheres, resid {second}\n")
            f.write("hide labels, all\n")
        f.write('set dash_radius, 0.15, interaction*\n')
        f.write('set dash_gap, 0.0, interaction*\n')
        f.write('hide labels, interaction*\n')
        f.write('set sphere_scale, 0.2\n')
        f.write('bg white\n')
        f.write('group ClusterNetwork, interaction*\n')

    return interaction_dict
def identify_clustered_angles(classification_file):
    """
    Cluster two-angle calculations and count frequencies of clusters. Finds
    the closest point to the center of the cluster.

    Rows are grouped by their 'MSA_Resid' value. For every group with more
    than three rows, the (Angle_1, Angle_2) pairs are clustered with HDBSCAN
    (``min_cluster_size=3``). Noise points (label -1) are ignored.

    Parameters
    ----------
    classification_file : str
        Path to angle classification file (.csv). The columns 'MSA_Resid',
        'Angle_1', 'Angle_2', 'Protein_Coords' and 'Water_Coords' are read;
        the two coordinate columns hold whitespace-separated numbers.

    Returns
    -------
    cluster_conservation_dict : dict
        For each residue (key ``str(MSA_Resid)``) a dict keyed by
        ``str(cluster label)`` with entries

        - 'counts': number of points assigned to the cluster
        - 'center': mean (Angle_1, Angle_2) of the cluster
        - 'closest_coord': protein coordinates of the row whose angle pair is
          nearest to the center
        - 'wat_coord': water coordinates of that same row

        Residues with three or fewer rows map to an empty dict.
    """
    df = pd.read_csv(classification_file, delimiter=',')

    angles_by_resid = {}
    protein_by_resid = {}
    water_by_resid = {}
    for _, row in df.iterrows():
        resid = str(row['MSA_Resid'])
        angles_by_resid.setdefault(resid, []).append(
            (float(row['Angle_1']), float(row['Angle_2'])))
        protein_by_resid.setdefault(resid, []).append(
            np.array([float(f) for f in row['Protein_Coords'].split()]))
        water_by_resid.setdefault(resid, []).append(
            np.array([float(f) for f in row['Water_Coords'].split()]))

    cluster_conservation_dict = {}
    for msa_resid, angle_pairs in angles_by_resid.items():
        residue_clusters = cluster_conservation_dict.setdefault(msa_resid, {})

        values = np.array(angle_pairs, dtype=float).reshape(-1, 2)
        if values.shape[0] <= 3:
            continue

        cluster_labels = HDBSCAN(min_cluster_size=3).fit(values).labels_

        for label in np.unique(cluster_labels):
            if label == -1:
                continue
            members = cluster_labels == label
            center = values[members].mean(axis=0)

            # Row nearest to the cluster center in angle space. The former
            # running-minimum loop never updated its minimum and therefore
            # always ended up with the last row.
            closest = int(np.argmin(np.linalg.norm(values - center, axis=1)))

            # Counting the members directly avoids the old membership test,
            # which compared the raw label against str(label) keys and so
            # reset the counter on every point.
            residue_clusters[str(label)] = {
                'counts': int(np.count_nonzero(members)),
                'center': center,
                'closest_coord': protein_by_resid[msa_resid][closest],
                'wat_coord': water_by_resid[msa_resid][closest],
            }

    return cluster_conservation_dict
def find_clusters_from_densities(density_file, output_name=None, threshold=1.5):
    """
    Find clusters from densities.

    Hotspots are grid points that are a local maximum within their 3x3x3
    neighborhood and whose value exceeds `threshold`. Their coordinates are
    written to ``{output_name}.pdb`` as one water oxygen per hotspot.

    Parameters
    ----------
    density_file : str
        .dx file containing density information
    output_name : str
        Base name for output. Defaults to `density_file` up to its first
        '.dx'.
    threshold : float
        Threshold for cutoff; only grid values strictly above it qualify.

    Returns
    -------
    np.ndarray
        Array of density hotspot coordinate locations, shape (n_hotspots, 3)
    """
    from gridData import Grid
    import scipy.ndimage as ndimage

    if output_name is None:
        output_name = f"{density_file.split('.dx')[0]}"

    grid = Grid(density_file)
    data = grid.grid
    origin = np.array(grid.origin)
    delta = np.array(grid.delta)

    neighborhood = np.ones((3, 3, 3))
    local_max = (data == ndimage.maximum_filter(data, footprint=neighborhood))
    hotspot_indices = np.argwhere(local_max & (data > threshold))

    # Grid index -> coordinate. Broadcasting keeps the (0, 3) shape when no
    # hotspot is found.
    hotspot_coords = (origin + hotspot_indices * delta).reshape(-1, 3)

    with open(f"{output_name}.pdb", 'w') as FILE:
        for i, (x, y, z) in enumerate(hotspot_coords, start=1):
            # Fixed-column ATOM record: x, y, z occupy columns 31-54, which is
            # where this module's PDB readers slice them (line[30:54]).
            # The serial wraps so that more than 99999 records cannot shift
            # the columns.
            FILE.write(
                "ATOM  "
                f"{i % 100000:5d}"
                "  O   "
                "HOH"
                "  "
                f"{1:4d}"
                "    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}"
                f"{1.0:6.2f}{0.0:6.2f}"
                f"{'O':>12}"
                "\n"
            )

    return hotspot_coords
def plot_commonality(files=None, input_directory=None, cluster_pdb=None, commonality_dict=None, plot_type='bar', output='commonality', out_dir='images'):
    """
    Plot commonality to cluster centers.

    When `commonality_dict` is not given it is computed with
    ``find_commonality`` from the networks stored in `files` and the cluster
    centers in `cluster_pdb`.

    Parameters
    ----------
    files : list
        List of outputted .pkl files from WatCon. Element 1 of each pickle is
        read as the list of networks and, if the pickle has at least four
        elements, element 3 as their names. Whether names are present is
        decided from the first file.
    input_directory : str
        Directory containing files. If None, `files` are used as given.
    cluster_pdb : str
        PDB of clusters
    commonality_dict : dict
        Commonality dict (name -> score). Skips loading when provided.
    plot_type : {'bar', 'hist'}
        Type of plot
    output : str
        Base filename for saving images
    out_dir : str
        Directory the image is written to; created if it does not exist.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `plot_type` is invalid, or if `commonality_dict` is None and
        `files` or `cluster_pdb` is missing.
    """
    # Reject a bad plot type before doing any loading or computing.
    if plot_type not in ('bar', 'hist'):
        raise ValueError('Select a valid plot type. Currently only "bar" or "hist".')

    import matplotlib.pyplot as plt
    import os

    os.makedirs(out_dir, exist_ok=True)

    if commonality_dict is None:
        if not files or cluster_pdb is None:
            raise ValueError('Provide either commonality_dict or both files and cluster_pdb.')

        base_dir = input_directory or ''
        name_list = []
        network_list = []
        has_names = None
        for i, file in enumerate(files):
            # NOTE: pickle.load can execute arbitrary code -- only use on trusted files.
            with open(os.path.join(base_dir, file), 'rb') as FILE:
                e = pickle.load(FILE)

            # Files have names (from crystal structures) or do not; the first
            # file decides for all of them.
            if has_names is None:
                has_names = len(e) >= 4

            if has_names:
                name_list.extend(e[3])
            else:
                # One generated name per network (enumerate the networks in
                # e[1], not the elements of the pickle itself).
                name_list.extend(f"{i}-{j}" for j, _ in enumerate(e[1]))
            network_list.extend(e[1])

        # Only ATOM/HETATM records are centers; other lines must not add
        # all-zero rows.
        with open(cluster_pdb, 'r') as FILE:
            atom_lines = [line for line in FILE if line.startswith(("ATOM", "HETATM"))]
        centers = np.zeros((len(atom_lines), 3))
        for i, line in enumerate(atom_lines):
            centers[i, :] = float(line[30:38]), float(line[38:46]), float(line[46:54])

        commonality_dict = find_commonality(network_list, centers, name_list)

    if plot_type == 'bar':
        fig, ax = plt.subplots(1, figsize=(5, 3), tight_layout=True)

        labels = list(commonality_dict.keys())
        x = np.arange(len(labels))
        width = 0.5
        # One bar per entry, each with its own score.
        for position, label in zip(x, labels):
            ax.bar(position, commonality_dict[label], width, color='gray', hatch='//',
                   edgecolor='k', label=label)

        ax.set_xticks(x)
        ax.set_xticklabels(labels, fontsize=12)
        ax.set_ylabel('Conservation score', fontsize=15)
        ax.tick_params(axis='y', labelsize=12)
        ax.tick_params(axis='x', rotation=60)
        plt.savefig(f"{out_dir}/{output}_bar.png", dpi=200)

    else:
        fig, ax = plt.subplots(1, figsize=(3, 2), tight_layout=True)
        vals = np.array(list(commonality_dict.values()))
        hist, xedges = np.histogram(vals, density=True, bins=15)
        xcenters = (xedges[1:] + xedges[:-1]) / 2
        ax.plot(xcenters, hist)
        ax.set_xlabel('Commonality score')
        ax.set_ylabel('Density')
        plt.savefig(f"{out_dir}/{output}_hist.png", dpi=200)