Source code for WatCon.residue_analysis

'''
Per-residue water analysis
'''

import os
import numpy as np
import MDAnalysis as mda
from MDAnalysis.analysis import distances
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd


def get_interaction_counts(network, selection='all'):
    """
    Tally graph edges as 'water-water' or 'water-protein' interactions.

    Parameters
    ----------
    network : WaterNetwork object
        Object whose ``graph.edges(data=True)`` yields ``(u, v, attrs)``
        triples with per-edge attribute dictionaries.
    selection : {'all', 'active_region', 'not_active_region'}
        Subset of the graph to analyze. ``'all'`` counts every edge; any
        other value keeps only edges whose ``'active_region'`` attribute
        equals it.

    Returns
    -------
    dict
        Number of interactions under the keys 'water-water' and
        'water-protein'.
    """
    n_water_water = 0
    n_water_protein = 0
    count_everything = selection == 'all'

    for _src, _dst, attrs in network.graph.edges(data=True):
        # The region attribute is only consulted when a subset was requested.
        if not count_everything and not (attrs['active_region'] == selection):
            continue
        # Anything that is not tagged WAT-PROT is treated as water-water.
        if attrs['connection_type'] == 'WAT-PROT':
            n_water_protein += 1
        else:
            n_water_water += 1

    return {'water-water': n_water_water, 'water-protein': n_water_protein}
def get_per_residue_interactions(network, selection='all', msa=False):
    """
    Count water-protein interactions per protein residue.

    Every graph edge tagged ``'WAT-PROT'`` (optionally restricted to one
    region) contributes one count to the residue that owns its protein
    endpoint. The protein endpoint is identified by atom index, so a protein
    residue whose number happens to equal a water residue number is still
    attributed correctly.

    Parameters
    ----------
    network : WaterNetwork object
        Must expose ``graph.edges(data=True)`` and ``protein_atoms`` (objects
        with ``index``, ``resid`` and, when ``msa`` is set, ``msa_resid``).
    selection : {'all', 'active_region', 'not_active_region'}
        Specifies which subset of the graph to analyze. ``'all'`` keeps every
        edge; otherwise only edges whose ``'active_region'`` attribute equals
        ``selection`` are kept.
    msa : bool, optional
        Use the MSA common residue index (``msa_resid``) instead of ``resid``
        as the residue identifier. Default is False.

    Returns
    -------
    dict
        Maps ``str(residue id)`` to the number of water-protein edges
        touching that residue.
    """
    # Build the atom-index -> residue lookup once instead of rescanning the
    # atom list for every edge. setdefault keeps the first atom listed for an
    # index, matching a first-match linear search.
    protein_resid_by_index = {}
    for atom in network.protein_atoms:
        protein_resid_by_index.setdefault(
            atom.index, atom.msa_resid if msa else atom.resid
        )

    residue_dict = {}
    for node_a, node_b, data in network.graph.edges(data=True):
        if data['connection_type'] != 'WAT-PROT':
            continue
        if selection != 'all' and not (data['active_region'] == selection):
            continue

        # The residue to credit is whichever endpoint is a protein atom.
        for node in (node_a, node_b):
            if node in protein_resid_by_index:
                target_res = protein_resid_by_index[node]
                break
        else:
            # Neither endpoint is a known protein atom: nothing to credit.
            continue

        # Falsy identifiers (e.g. a missing msa_resid of None) are skipped,
        # as before.
        if target_res:
            residue_key = str(target_res)
            residue_dict[residue_key] = residue_dict.get(residue_key, 0) + 1

    return residue_dict
def get_all_water_distances(network_group, box, selection='No-active-site', msa=False, offset=0):
    """
    Collect all distances for each protein-interacting water

    Parameters
    ----------
    network_group : list[WaterNetwork]
        List of WaterNetwork objects
    box : array-like
        Dimensions of unit-cell
    selection : {'No-active-site', 'active-site', 'all'}
        Analysis selection
    msa : bool, optional
        Indicate whether to use MSA indexing or standard residue indexing.
        Default is False
    offset : int, optional
        Residue offset from desired numbering that can be given to match
        standard residue indexing. Default is 0. It is accepted for backward
        compatibility only: the keys of the returned distances are not
        shifted by it (it was previously applied only to a discarded list).

    Returns
    -------
    tuple
        ``(residue_dict, interaction_data)`` as returned by
        ``get_per_residue_interactions``, with an extra
        ``residue_dict['water_distances']`` mapping ``str(residue id)`` to a
        list of minimum protein-atom/water-atom distances.
    """
    # NOTE(review): the get_per_residue_interactions visible in this module
    # takes a single network and returns one dict, while this call passes a
    # group and unpacks a (residue_dict, interaction_data) pair. Confirm
    # which version of that function this code is meant to run against.
    residue_dict, interaction_data = get_per_residue_interactions(network_group, selection=selection, msa=msa)
    residue_dict['water_distances'] = {}

    per_network_connections = interaction_data[selection]['Water-Protein'][1]
    for network, connections in zip(network_group, per_network_connections):
        # NOTE(review): one list is shared by every residue seen in this
        # network (and replaces lists stored for earlier networks), so each
        # key ends up holding all distances of the network. Kept as-is;
        # confirm whether per-residue accumulation was intended.
        dist_arr = []
        for connection in connections:
            endpoints = (connection[0], connection[1])

            residue_atom = next(
                (atom for atom in network.protein_atoms if atom.index in endpoints),
                None,
            )
            if residue_atom is None:
                # No protein atom matches this connection; report and skip.
                print(connection[0], connection[1])
                continue

            # msa=True selects the MSA index, otherwise the plain residue id
            # (the two branches were previously swapped).
            res = residue_atom.msa_resid if msa else residue_atom.resid

            water_mol = None
            for mol in network.water_molecules:
                atoms = [atom for atom in (mol.O, mol.H1, mol.H2) if atom is not None]
                if any(atom.index in endpoints for atom in atoms):
                    water_mol = mol
                    break
            if water_mol is None:
                # Report the unmatched connection and skip it instead of
                # continuing with an undefined or stale water molecule.
                print(connection[0], connection[1])
                continue

            water_points = np.array(
                [
                    np.array(atom.coordinates)
                    for atom in (water_mol.O, water_mol.H1, water_mol.H2)
                    if atom is not None
                ]
            ).reshape(-1, 3)
            # Shortest distance from the protein atom to any atom of the water.
            dist = np.min(
                distances.distance_array(
                    np.array(residue_atom.coordinates), water_points, box=box
                )
            )
            dist_arr.append(dist)
            residue_dict['water_distances'][str(res)] = dist_arr

    return residue_dict, interaction_data
def classify_waters(network, ref1_coords, ref2_coords):
    """
    Classify all water-protein interactions based on two reference angles.

    For every connection tagged ``'WAT-PROT'`` the angle
    protein-atom / water-oxygen / reference-point is computed for each of the
    two reference points, with the water oxygen as the vertex.

    Parameters
    ----------
    network : WaterNetwork
        The water network object containing interaction data. Uses
        ``connections`` (indexable records with the two atom indices at
        positions 0 and 1 and the connection type at position 3),
        ``water_molecules`` and ``protein_atoms``.
    ref1_coords : array-like
        Coordinates of the first reference point (e.g., an atom or centroid),
        wrapped in a one-element sequence. ``None`` or ``[None]`` selects the
        default ``(0, 10, 0)``.
    ref2_coords : array-like
        Coordinates of the second reference point, same layout. ``None`` or
        ``[None]`` selects the default ``(10, 0, 10)``.

    Returns
    -------
    dict
        Maps a comma-separated descriptor of each interaction (residue ids,
        connection fields, protein and water coordinates) to
        ``[angle1, angle2]`` in degrees.

    Raises
    ------
    ValueError
        If both reference points are identical.
    """
    #Maybe extend to implement ML model to optimize angles

    def _with_default(ref, default):
        # A reference counts as missing when it is None or its first element
        # is None; objects that cannot be indexed fall back to the None test.
        try:
            missing = ref[0] is None
        except (TypeError, IndexError, KeyError):
            missing = ref is None
        return default if missing else ref

    def _single_point(ref):
        # Careful: a tuple input is assumed to carry one extra nesting level
        # compared with a list/array input.
        if isinstance(ref, tuple):
            return tuple(ref[0][0])
        return tuple(ref[0])

    ref1_coords = _single_point(_with_default(ref1_coords, [(0, 10, 0)]))
    ref2_coords = _single_point(_with_default(ref2_coords, [(10, 0, 10)]))

    if np.all(ref1_coords == ref2_coords):
        print('Reference coordinates are the same, input two separate coordinates.')
        raise ValueError('Reference coordinates are the same, input two separate coordinates.')

    def get_angles(wat_coords, prot_coords, ref_coords):
        # Angle (degrees) at the water between the protein atom and the
        # reference point.
        v1 = np.array([prot_coords[0]-wat_coords[0], prot_coords[1]-wat_coords[1], prot_coords[2]-wat_coords[2]])
        v2 = np.array([ref_coords[0]-wat_coords[0], ref_coords[1]-wat_coords[1], ref_coords[2]-wat_coords[2]])
        mag_v1 = np.sqrt((v1[0]**2+v1[1]**2+v1[2]**2))
        mag_v2 = np.sqrt((v2[0]**2+v2[1]**2+v2[2]**2))
        # Rounding can push the cosine marginally outside [-1, 1] for
        # (anti)parallel vectors, which would make arccos return NaN.
        cosine = np.clip(np.dot(v1, v2)/(mag_v1*mag_v2), -1.0, 1.0)
        return (180/np.pi) * np.arccos(cosine)

    def _touches(mol, endpoints):
        # True when the oxygen or an available hydrogen of this water is one
        # of the two connection endpoints.
        if mol.O.index in endpoints:
            return True
        for hydrogen in (mol.H1, mol.H2):
            if hydrogen is not None and hydrogen.index in endpoints:
                return True
        return False

    classification_dict = {}
    for connection in network.connections:
        if connection[3] != 'WAT-PROT':
            continue
        endpoints = (connection[0], connection[1])

        water = next((mol for mol in network.water_molecules if _touches(mol, endpoints)), None)
        protein_atom = next((atom for atom in network.protein_atoms if atom.index in endpoints), None)
        # Skip connections whose atoms cannot be resolved instead of failing
        # with an IndexError.
        if water is None or protein_atom is None:
            continue

        wat_coords = water.O.coordinates
        prot_coords = protein_atom.coordinates
        if len(prot_coords) > 0:
            angle1 = get_angles(wat_coords, prot_coords, ref_coords=ref1_coords)
            angle2 = get_angles(wat_coords, prot_coords, ref_coords=ref2_coords)
            prot_name = f"{protein_atom.resid},{protein_atom.msa_resid},{connection[0]},{connection[1]},{connection[2]},{connection[5]},{protein_atom.coordinates[0]} {protein_atom.coordinates[1]} {protein_atom.coordinates[2]},{wat_coords[0]} {wat_coords[1]} {wat_coords[2]}"
            classification_dict[prot_name] = [angle1, angle2] #Consider combining into one value

    return classification_dict
def plot_interactions_from_angles(csvs, input_dir='msa_classification',output_dir='MSA_images', name1='DYNAMIC', name2='STATIC'):
    """
    Plot water classifications from 2-angle analysis

    One image is written per common (MSA) residue. Data sets whose name
    starts with ``name1`` are pooled and drawn as a log-density contour;
    the data set named exactly ``name2`` is drawn as labelled scatter points.

    Parameters
    ----------
    csvs: list
        List of .csv files outputted from classify_waters. Each file needs
        the columns 'MSA_Resid', 'Angle_1', 'Angle_2', 'Classification' and
        'PDB ID'.
    input_dir: str
        Directory containing the .csv files
    output_dir: str
        Directory to write images
    name1: str
        Prefix of the data-set names drawn as density. A data-set name is
        the first two '_'-separated parts of the file name, cut at the first
        '.'.
    name2: str
        Name of the data set drawn as scatter points

    Returns
    ---------
    None
    """
    dfs = {}
    for csv in csvs:
        df = pd.read_csv(os.path.join(input_dir,csv), delimiter=',')
        name = '_'.join(csv.split('_')[0:2]).split('.')[0]
        dfs[name] = df

    # Per data set: MSA residue -> angle pairs / classifications / PDB ids,
    # kept in row order so the three lists stay index-aligned.
    scatters = {}
    classifications = {}
    pdb_names = {}
    for name, df in dfs.items():
        angle_pairs = scatters[name] = {}
        labels = classifications[name] = {}
        ids = pdb_names[name] = {}
        for _, row in df.iterrows():
            msa_resid = row['MSA_Resid']
            angle_pairs.setdefault(msa_resid, []).append((row['Angle_1'], row['Angle_2']))
            labels.setdefault(msa_resid, []).append(row['Classification'])
            ids.setdefault(msa_resid, []).append(row['PDB ID'])

    # Collect all unique MSAs across all scatters
    all_MSAs = sorted({MSA for data in scatters.values() for MSA in data.keys()})

    # Calculate global x and y limits
    all_x = []
    all_y = []
    for data in scatters.values():
        for coord_list in data.values():
            for coords in coord_list:
                x, y = map(float, coords)
                all_x.append(x)
                all_y.append(y)

    if not all_x:
        # Nothing to plot; avoids min()/max() failing on empty input.
        print('No angle data found, no images written.')
        return

    x_min, x_max = min(all_x), max(all_x)
    y_min, y_max = min(all_y), max(all_y)

    # Add padding for better visualization
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_min, x_max = x_min - 0.1 * x_range, x_max + 0.1 * x_range
    y_min, y_max = y_min - 0.1 * y_range, y_max + 0.1 * y_range

    names = list(scatters.keys())
    print(names)

    colors = {name1: 'gray', name2: {'backbone': 'dodgerblue', 'sidechain': 'mediumorchid'}}  # Adjust colors

    os.makedirs(output_dir, exist_ok=True)

    # Generate and save a separate plot for each MSA
    for MSA in all_MSAs:
        plt.figure(figsize=(3, 2.5), tight_layout=True)

        # Combine the angle pairs of every data set whose name starts with name1
        md_x, md_y = [], []
        for name, data in scatters.items():
            if MSA in data and name.startswith(name1):
                x_vals, y_vals = zip(*data[MSA])  # Extract coordinates
                md_x.extend(map(float, x_vals))
                md_y.extend(map(float, y_vals))

        # Plot the combined data as a surface
        if md_x and md_y:
            hist, x_edges, y_edges = np.histogram2d(md_x, md_y, bins=50)
            X, Y = np.meshgrid(x_edges[:-1], y_edges[:-1])
            # Empty bins give log(0) = -inf; silence the warning only.
            with np.errstate(divide='ignore'):
                log_density = np.log(hist.T)
            mesh = plt.contourf(X, Y, log_density, cmap="gray")
            cbar = plt.colorbar(mesh)
            cbar.set_label('log(Density)')

        # Plot the name2 data set as scatter points (exact name match)
        for name, data in scatters.items():
            if MSA in data and name == name2:
                for i, coords in enumerate(data[MSA]):
                    x, y = map(float, coords)
                    classification = classifications[name][MSA][i]
                    color = colors[name2]['backbone'] if 'backbone' in classification else colors[name2]['sidechain']
                    name_new = pdb_names[name][MSA][i]
                    # Hollow markers for entries whose id mentions "open"
                    if 'open' in name_new or 'Open' in name_new:
                        facecolor='none'
                    else:
                        facecolor=color
                    plt.scatter(x, y, edgecolor=color, facecolor=facecolor, s=10)
                    plt.text(x+0.1,y+0.1, name_new, fontsize=4)

        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        plt.xticks(np.arange(0,181, 50))
        plt.yticks(np.arange(0,181,50))
        plt.xlabel('Angle 1')
        plt.ylabel('Angle 2')
        plt.title(f"Common residue {int(MSA)}", fontsize=11)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"MSA_{int(MSA)}.png"), dpi=600)
        plt.close()
def histogram_metrics(all_files, input_directory, concatenate, output_dir='images'):
    """
    Plot histograms for calculated metrics

    For each metric one figure is written to
    ``<output_dir>/<metric>_comparehists.png`` showing the pooled
    ("Concatenated") distribution and one curve per remaining file.

    Parameters
    ----------
    all_files: list
        List of all files
    input_directory: str
        Directory which contains .pkl files
    concatenate: list
        List of files to concatenate. A single file name is also accepted.
    output_dir: str, optional
        Output directory. Default is 'images'

    Returns
    ---------
    None
    """
    import pickle

    if not isinstance(concatenate, list):
        concatenate = [concatenate]

    os.makedirs(output_dir, exist_ok=True)

    #Metrics read directly from each timestep dictionary
    metrics = ['density', 'characteristic_path_length', 'entropy']
    #Counts read from the nested 'interaction_counts' dictionary
    count_keys = ['water-water', 'water-protein']
    metrics_plot = metrics + count_keys

    #Formatted titles for plotting, aligned with metrics_plot
    plotting_names = ['Graph Density', 'CPL', 'Graph Entropy', 'Water-Water', 'Water-Protein']

    def _new_store():
        #One empty list per plotted metric
        return {key: [] for key in metrics_plot}

    def _load_into(store, file):
        #Append every timestep of one WatCon output file to the store
        watcon_file = os.path.join(input_directory, file)
        #NOTE: pickle.load can execute arbitrary code; only use trusted files
        with open(watcon_file, 'rb') as FILE:
            e = pickle.load(FILE)
        for ts_dict in e[0]:
            for key in count_keys:
                store[key].append(ts_dict['interaction_counts'][key])
            for metric in metrics:
                value = ts_dict[metric]
                #Scalars (including ints and NumPy scalars) are appended,
                #anything else is treated as a sequence of values
                if isinstance(value, (float, int, np.number)):
                    store[metric].append(value)
                else:
                    store[metric].extend(value)

    def _plot_hist(ax, values, label):
        #Draw a normalized 15-bin histogram as a line through bin centers
        hist, xedges = np.histogram(np.array(values), bins=15, density=True)
        xcenters = (xedges[1:]+xedges[:-1])/2
        ax.plot(xcenters, hist, label=label)

    #Combine data in all concatenated files
    concatenated = _new_store()
    for file in concatenate:
        _load_into(concatenated, file)

    #Every other file gets its own store
    samples = []
    for file in all_files:
        if file in concatenate:
            continue
        sample = _new_store()
        _load_into(sample, file)
        samples.append(sample)

    #Begin plotting. Distinct loop names keep the pooled data and the axis
    #label from being overwritten by the per-sample loop.
    for metric, axis_label in zip(metrics_plot, plotting_names):
        fig, ax = plt.subplots(1,figsize=(3,2), tight_layout=True)
        fig.subplots_adjust(left=0, right=0.85)

        _plot_hist(ax, concatenated[metric], 'Concatenated')
        for sample_idx, sample in enumerate(samples):
            _plot_hist(ax, sample[metric], f'Sample {sample_idx}')

        ax.legend(fontsize=8, frameon=False)
        ax.set_xlabel(axis_label)
        ax.set_ylabel('Density')
        fig.savefig(os.path.join(output_dir,f"{metric}_comparehists.png"), dpi=200, bbox_inches='tight')
        #Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
def plot_residue_interactions(topology_file, cutoff=0.0, watcon_directory='watcon_output', output_dir='images'):
    """
    Plot water-protein interactions by residue and color by average number of simultaneous interactions

    The bar height is the residue's total interaction count divided by the
    number of analyzed metric dictionaries; the bar color is the mean of its
    non-zero counts. The figure is saved as
    ``<output_dir>/Interaction_counts_bar.png``.

    Parameters
    ----------
    topology_file : str
        Full path to an MDAnalysis-readable topology
    cutoff : float, optional
        Cutoff to show residue interactions; only residues whose normalized
        count is strictly greater are shown. Default is 0.0.
    watcon_directory : str, optional
        Directory containing WatCon output files. Default is 'watcon_output'
    output_dir : str, optional
        Directory to save resulting image. Default is 'images'

    Returns
    -------
    None
    """
    import pickle
    from MDAnalysis.lib.util import convert_aa_code
    import matplotlib.cm as cm
    import matplotlib.colors as mcolors

    #Initialize dictionary for interaction counts
    interaction_counts = {}

    #Track how many waters interact per residue
    water_count_distribution = {}

    #Store number of valid dictionaries analyzed
    num_dicts = 0

    for watcon_file in os.listdir(watcon_directory):
        full_path = os.path.join(watcon_directory, watcon_file)
        #Sub-directories cannot be opened as output files
        if not os.path.isfile(full_path):
            continue
        #NOTE: pickle.load can execute arbitrary code; only use trusted files
        with open(full_path, 'rb') as FILE:
            e = pickle.load(FILE)

        #Increment num_dicts by the size of the calculated metrics dict
        num_dicts += len(e[0])

        for metrics_dict in e[0]:
            #Isolate the per_residue_interaction dict
            per_residue_interaction = metrics_dict['per_residue_interaction']
            for res, count in per_residue_interaction.items():
                interaction_counts[res] = interaction_counts.get(res, 0) + count
                #Record simultaneous waters; setdefault also covers residues
                #whose first recorded count was zero
                if count > 0:
                    water_count_distribution.setdefault(res, []).append(count)

    #Normalize interaction counts by total number of frames
    if num_dicts > 0:
        for res in interaction_counts:
            interaction_counts[res] /= num_dicts

    #Take average of simultaneous water interactions
    mean_water_counts = {res: np.mean(water_list) for res, water_list in water_count_distribution.items()}

    #Sort residues numerically and remove those under cutoff
    sorted_residues = sorted([key for key in interaction_counts.keys() if interaction_counts[key] > cutoff], key=int)

    if not sorted_residues:
        #Nothing passes the cutoff; color normalization needs at least one value
        print('No residue interactions above cutoff, no image written.')
        return

    #Take counts from sorted_residues
    normalized_counts = [interaction_counts[res] for res in sorted_residues]

    #Initialize universe with given reference topology file
    u = mda.Universe(topology_file)

    #Known nonstandard residue names mapped to one-letter codes
    nonstandard = {
        'CYM': 'C', 'CSP': 'C',
        'ASH': 'D', 'AS4': 'D',
        'GLH': 'E', 'GL4': 'E',
        'LYN': 'K',
        'ARN': 'R',
        'SEP': 'S',
    }

    #Residue one-letter codes (for labelling)
    resnames = []
    for val in sorted_residues:
        residue = u.select_atoms(f"resid {val}")[0].resname
        try:
            one_letter = convert_aa_code(residue)
        except (ValueError, KeyError):
            if residue == 'CYM' or residue == 'CSP':
                one_letter = 'C'
            elif residue.startswith('H'):
                #Any remaining name starting with H is labelled as histidine
                one_letter = 'H'
            elif residue in nonstandard:
                one_letter = nonstandard[residue]
            else:
                #Default to X
                print(f"{residue} has no one-letter code, using X")
                one_letter = 'X'
        resnames.append(one_letter)

    #Assign colors based on means of simultaneous waters
    color_reference_values = np.array([mean_water_counts.get(res, 0.0) for res in sorted_residues])
    norm = mcolors.Normalize(vmin=min(color_reference_values), vmax=max(color_reference_values))
    cmap = cm.PuBu
    colors = cmap(norm(color_reference_values))

    #Format residue labels
    #NOTE(review): labels show the residue number plus one while the topology
    #lookup above uses the unshifted number - confirm the intended numbering
    sorted_residues = np.array([str(resnames[i] + str(int(f) + 1)) for i, f in enumerate(sorted_residues)])

    #Create bar plot
    fig, ax = plt.subplots(figsize=(7.5,2))
    fig.subplots_adjust(right=0.75, top=0.90, wspace=0.30)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.bar(sorted_residues, normalized_counts, color=colors, edgecolor='k')

    #Add colorbar for bars
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, aspect=30, ax=ax, pad=0.010)
    cbar.set_label("Average Simultaneous\nWaters",fontsize=10)

    ax.set_ylabel('Interaction Score', fontsize=12)
    ax.tick_params(axis='x', labelrotation=90)

    #Make sure the target directory exists before saving
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, 'Interaction_counts_bar.png'), dpi=200, bbox_inches='tight')
    plt.close(fig)