Skip to content

Commit

Permalink
salt bridge code update
Browse files Browse the repository at this point in the history
  • Loading branch information
brittyscience committed Jan 23, 2025
1 parent 6841574 commit ed756af
Showing 1 changed file with 105 additions and 29 deletions.
134 changes: 105 additions & 29 deletions mdagent/tools/base_tools/analysis_tools/salt_bridge_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional

import mdtraj as md
import numpy as np
import pandas as pd
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry, load_single_traj


class SaltBridgeFunction:
Expand All @@ -14,39 +16,97 @@ def __init__(self, path_registry):
self.traj = None

def find_salt_bridges(
self, traj_file, top_file=None, threshold_distance=0.4, residue_pairs=None
self,
traj,
threshold_distance: float = 0.4,
residue_pairs=None,
):
# add two files here in similar format as line 14 above
traj_file_path = self.path_registry.get_mapped_path(traj_file)
ending = traj_file_path.split(".")[-1]
if ending in ["dcd", "xtc", "xyz"] and top_file is not None:
top_file_path = self.path_registry.get_mapped_path(top_file)
self.traj = md.load(traj_file_path, top=top_file_path)
else:
self.traj = md.load(traj_file_path)
"""Find Salt Bridge in molecular dynamics simulation trajectory.
Description:
traj: MDtraj trajectory
thresold_distance: maximum distance between residues for salt bridge formation
residue_pairs: list of pairs of residues
"""
if traj is None:
raise Exception("Trajectory is None")

self.traj = traj
if residue_pairs is None:
residue_pairs = [
("ARG", "ASP"),
("ARG", "GLU"),
("LYS", "ASP"),
("LYS", "GLU"),
]
donor_acceptor_pairs = []

for pair in residue_pairs:
print(f"Looking for salt bridges between {pair[0]} and {pair[1]} pairs...")

donor_residues = self.traj.topology.select(f'residue_name == "{pair[0]}"')
acceptor_residues = self.traj.topology.select(
f'residue_name == "{pair[1]}"'
)
# generate all possible donor-acceptor pairs
pairs = np.array(np.meshgrid(donor_residues, acceptor_residues)).T.reshape(
-1, 2
)
donor_acceptor_pairs.append(pairs)

# Combines all rsidue pairs
donor_acceptor_pairs = np.vstack(donor_acceptor_pairs)
# filter by threshold distance

all_distance = md.compute_distances(self.traj, donor_acceptor_pairs)

mini_distances = np.min(all_distance, axis=0)
within_threshold = mini_distances <= threshold_distance
filtered_pairs = donor_acceptor_pairs[within_threshold]

self.salt_bridges = [tuple(pair) for pair in filtered_pairs]
file_id = self.save_results_to_file()

return file_id

for donor_idx in donor_residues:
for acceptor_idx in acceptor_residues:
distances = md.compute_distances(
self.traj, [[donor_idx, acceptor_idx]]
)
if any(d <= threshold_distance for d in distances):
self.salt_bridges.append((donor_idx, acceptor_idx))
return self.salt_bridges
def save_results_to_file(self):
if self.traj is None:
raise Exception("Trajectory is None")

salt_bridge_data = []

for bridge in self.salt_bridges:
donor_atom = self.traj.topology.atom(bridge[0])
acceptor_atom = self.traj.topology.atom(bridge[1])
salt_bridge_data.append(
{
"Donor Atom Index": bridge[0],
"Donor Residue": (
f"{donor_atom.residue.index +1} "
f" ({donor_atom.residue.name})"
),
"Acceptor Atom Index": bridge[1],
"Acceptor Atom Residue": (
f"{acceptor_atom.residue.index + 1} "
f"({acceptor_atom.residue.name})"
),
}
)

df = pd.DataFrame(salt_bridge_data)

file_name = self.path_registry.write_file_name(
FileType.RECORD,
record_type="salt_bridges",
file_format="csv",
)
file_id = self.path_registry.get_fileid(file_name, FileType.RECORD)
file_path = f"{self.path_registry.ckpt_records}/{file_name}"
df.to_csv(file_path, index=False)
self.path_registry.map_path(
file_id, file_path, description="salt bridge analysis"
)
return file_id

def get_results_string(self):
msg = "Salt bridges found: "
Expand All @@ -55,17 +115,18 @@ def get_results_string(self):
f"Residue {self.traj.topology.atom(bridge[0]).residue.index + 1} "
f"({self.traj.topology.atom(bridge[0]).residue.name}) - "
f"Residue {self.traj.topology.atom(bridge[1]).residue.index + 1} "
f"({self.traj.topology.atom(bridge[1]).residue.name})"
f"({self.traj.topology.atom(bridge[1]).residue.name})\n"
)
return msg


class SaltBridgeToolInput(BaseModel):
trajectory_fileid: str = Field(
None, description="Trajectory file ID. Either dcd, hdf5, xtc, or xyz"
traj_id: str = Field(
None,
description="Trajectory file ID. Either dcd, hdf5, xtc, or xyz",
)

topology_fileid: Optional[str] = Field(None, description="Topology file ID")
top_id: Optional[str] = Field(None, description="Topology file ID")

threshold_distance: Optional[float] = Field(
0.4,
Expand All @@ -74,31 +135,46 @@ class SaltBridgeToolInput(BaseModel):
),
)

residue_pairs: Optional[dict] = Field(
None, description=("Identifies the amino acid residues for salt bridge")
residue_pairs: Optional[str] = Field(
None,
description=("Identifies the amino acid residues for salt bridge"),
)


class SaltBridgeTool(BaseTool):
name = "SaltBridgeTool"
description = "A tool to find salt bridge in a protein trajectory"
args_schema = SaltBridgeToolInput
path_registry: Optional[PathRegistry]
path_registry: PathRegistry | None = None

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(
self, traj_file, top_file=None, threshold_distance=0.4, residue_pairs=None
self,
traj_file: str,
top_file: str | None = None,
threshold_distance=0.4,
residue_pairs=None,
):
try:
if self.path_registry is None:
return "Path registry is not set"
traj = load_single_traj(self.path_registry, traj_file, top_file)
if not traj:
return "Trajectory Failed to load ."

# Load trajectory using MDTraj

traj = load_single_traj(traj_file, top_file)
# calls the salt bridge function
salt_bridge_function = SaltBridgeFunction(self.path_registry)
salt_bridge_function.find_salt_bridges(
traj_file, top_file, threshold_distance, residue_pairs
results_file_id = salt_bridge_function.find_salt_bridges(
traj, threshold_distance, residue_pairs
)
message = salt_bridge_function.get_results_string()
message += f"Saved to results file with fle id: {results_file_id}"
return "Succeeded. " + message
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return "Succeeded. " + message

0 comments on commit ed756af

Please sign in to comment.