Skip to content

Commit

Permalink
Hydrogen bonding (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
brittyscience authored Jan 23, 2025
1 parent 573259f commit 9375bf4
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .analysis_tools.distance_tools import ContactsTool, DistanceMatrixTool
from .analysis_tools.hydrogen_bonding_tools import HydrogenBondTool
from .analysis_tools.inertia import MomentOfInertia
from .analysis_tools.pca_tools import PCATool
from .analysis_tools.plot_tools import SimulationOutputFigures
Expand Down Expand Up @@ -67,6 +68,7 @@
"ComputeRMSF",
"ContactsTool",
"DistanceMatrixTool",
"HydrogenBondTool",
"ListRegistryPaths",
"MapPath2Name",
"MapProteinRepresentation",
Expand Down
2 changes: 2 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .distance_tools import ContactsTool, DistanceMatrixTool
from .hydrogen_bonding_tools import HydrogenBondTool
from .inertia import MomentOfInertia
from .pca_tools import PCATool
from .plot_tools import SimulationOutputFigures
Expand All @@ -14,6 +15,7 @@
"ComputeRMSF",
"ContactsTool",
"DistanceMatrixTool",
"HydrogenBondTool",
"MomentOfInertia",
"PCATool",
"PPIDistance",
Expand Down
128 changes: 128 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/hydrogen_bonding_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import matplotlib.pyplot as plt
import mdtraj as md
from langchain.tools import BaseTool

from mdagent.utils import FileType, PathRegistry, load_single_traj


class HydrogenBondTool(BaseTool):
    """Count hydrogen bonds per trajectory frame, then plot and save the counts.

    Note that this tool only uses the Baker-Hubbard method for identifying
    hydrogen bonds. Other methods (kabsch-sander, wernet-nilsson) can be
    implemented later, if desired.
    """

    name: str = "hydrogen_bond_tool"
    # NOTE(review): `_run` takes the topology file ID as the required argument
    # and the trajectory file ID as the optional one; the wording below says
    # the opposite -- confirm which phrasing the agent prompt should use.
    description: str = (
        "Identifies hydrogen bonds and plots the results from the "
        "provided trajectory data. "
        "Input the File ID for the trajectory file and optionally the topology file. "
        "The tool will output the file ID of the results and plot."
    )

    path_registry: PathRegistry | None = None
    # Kept equal to the __init__ default, which is the value actually in effect.
    freq: float = 0.1

    def __init__(self, path_registry, freq=0.1):
        """Store the path registry and the `freq` passed to `md.baker_hubbard`."""
        super().__init__()
        self.path_registry = path_registry
        self.freq = freq

    def compute_hbonds_traj(self, traj):
        """Return a list with the number of hydrogen bonds found in each frame.

        Each frame is analysed on its own, so `self.freq` is applied to a
        single-frame trajectory rather than across the whole trajectory.
        """
        return [
            len(md.baker_hubbard(traj[frame], freq=self.freq))
            for frame in range(traj.n_frames)
        ]

    def write_hbond_counts_to_file(self, hbond_counts, traj_id):
        """Write the per-frame counts to a CSV record and register it.

        Args:
            hbond_counts: Sequence of hydrogen bond counts, one per frame.
            traj_id: Identifier used in the file name and registry description.

        Returns:
            A message containing the registered file ID and the full path.
        """
        output_file = f"{traj_id}_hbond_counts"

        file_name = self.path_registry.write_file_name(
            type=FileType.RECORD, fig_analysis=output_file, file_format="csv"
        )
        # The file is a record stored under ckpt_records, so its ID is
        # requested as a RECORD as well (it was previously requested as a
        # FIGURE, which did not match the file's type).
        file_id = self.path_registry.get_fileid(
            file_name=file_name, type=FileType.RECORD
        )

        file_path = f"{self.path_registry.ckpt_records}/{file_name}"
        if not file_path.endswith(".csv"):
            file_path += ".csv"

        with open(file_path, "w") as f:
            f.write("Frame,Hydrogen Bonds\n")
            f.writelines(
                f"{frame},{count}\n" for frame, count in enumerate(hbond_counts)
            )
        self.path_registry.map_path(
            file_id,
            file_path,
            description=f"Hydrogen bond counts for {traj_id}",
        )
        return f"Data saved to: {file_id}, full path: {file_path}"

    def plot_hbonds_over_time(self, hbond_counts, traj, traj_id):
        """Plot the per-frame counts against frame index and register the PNG.

        Args:
            hbond_counts: Sequence of hydrogen bond counts, one per frame.
            traj: Trajectory whose `n_frames` defines the x axis.
            traj_id: Identifier used in the file name, title and description.

        Returns:
            A message containing the registered plot ID and the full path.
        """
        fig_analysis = f"hbonds_over_time_{traj_id}"
        plot_name = self.path_registry.write_file_name(
            type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png"
        )
        plot_id = self.path_registry.get_fileid(
            file_name=plot_name, type=FileType.FIGURE
        )
        plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
        if not plot_path.endswith(".png"):
            plot_path += ".png"

        # Draw on a dedicated figure and always close it, so nothing leaks into
        # (or is inherited from) pyplot's global state, even if saving fails.
        fig = plt.figure()
        try:
            plt.plot(range(traj.n_frames), hbond_counts, marker="o")
            plt.xlabel("Frame")
            plt.ylabel("Number of Hydrogen Bonds")
            plt.title(f"Hydrogen Bonds Over Time for traj {traj_id}")
            plt.grid(True)
            plt.savefig(f"{plot_path}")
        finally:
            plt.close(fig)

        self.path_registry.map_path(
            plot_id,
            plot_path,
            description=f"Plot of hydrogen bonds over time for {traj_id}",
        )
        return f"plot saved to: {plot_id}, full path: {plot_path}"

    def _run(
        self,
        top_file: str,
        traj_file: str | None = None,
    ) -> str:
        """Load the trajectory, count hydrogen bonds, and save a plot and a CSV.

        Args:
            top_file: File ID of the topology file.
            traj_file: Optional file ID of the trajectory file. It is ignored
                when it is identical to `top_file`.

        Returns:
            A status message. Failures are reported in the returned string
            instead of being raised.
        """
        try:
            # Passing the same ID twice is treated as "topology only".
            if traj_file is None or traj_file == top_file:
                traj_file = None
            traj = load_single_traj(
                path_registry=self.path_registry,
                top_fileid=top_file,
                traj_fileid=traj_file,
                traj_required=False,
            )
            if not traj:
                raise ValueError("Trajectory could not be loaded.")
        except Exception as e:
            return f"Error loading traj: {e}"

        try:
            hbond_counts = self.compute_hbonds_traj(traj)
            rtrn_msg = ""
            if all(count == 0 for count in hbond_counts):
                rtrn_msg += (
                    "No hydrogen bonds found in the trajectory. "
                    "Did you forget to add missing hydrogens? "
                )
            # Output files are labelled with the trajectory ID, falling back to
            # the topology ID when no separate trajectory was given.
            traj_file = top_file if not traj_file else traj_file
            plot_id = self.plot_hbonds_over_time(hbond_counts, traj, traj_file)
            data_id = self.write_hbond_counts_to_file(hbond_counts, traj_file)
            return f"Hydrogen bond analysis completed. {data_id}, {plot_id} {rtrn_msg}."
        except Exception as e:
            return f"Error during hydrogen bond analysis: {e}"

    async def _arun(
        self,
        top_file: str,
        traj_file: str | None = None,
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError
2 changes: 2 additions & 0 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
GetSubunitStructure,
GetTurnsBetaSheetsHelices,
GetUniprotID,
HydrogenBondTool,
ListRegistryPaths,
MapProteinRepresentation,
ModifyBaseSimulationScriptTool,
Expand Down Expand Up @@ -91,6 +92,7 @@ def make_all_tools(
ComputeRMSF(path_registry=path_instance),
ContactsTool(path_registry=path_instance),
DistanceMatrixTool(path_registry=path_instance),
HydrogenBondTool(path_registry=path_instance),
ListRegistryPaths(path_registry=path_instance),
MomentOfInertia(path_registry=path_instance),
PackMolTool(path_registry=path_instance),
Expand Down
2 changes: 2 additions & 0 deletions mdagent/utils/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def load_single_traj(
),
UserWarning,
)

return md.load(top_path)
else:
raise ValueError("Trajectory File ID is required, and it's not provided.")
Expand Down Expand Up @@ -88,6 +89,7 @@ def load_traj_with_ref(
ref_traj = load_single_traj(
path_registry, ref_top_id, ref_traj_id, traj_required, ignore_warnings
)

return traj, ref_traj


Expand Down
55 changes: 55 additions & 0 deletions tests/test_analysis/test_hydrogen_bonding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import mdtraj as md
import numpy as np
import pytest

from mdagent.tools.base_tools.analysis_tools.hydrogen_bonding_tools import (
HydrogenBondTool,
)


@pytest.fixture
def hydrogen_bond_tool(get_registry):
    """Provide a HydrogenBondTool bound to the registry from `get_registry`."""
    registry = get_registry("raw", True)
    tool = HydrogenBondTool(registry)
    return tool


@pytest.fixture
def dummy_traj():
    """Build a three-frame trajectory of one ALA residue with N, H and O atoms."""
    top = md.Topology()
    ala = top.add_residue("ALA", top.add_chain())
    nitrogen = top.add_atom("N", element=md.element.nitrogen, residue=ala)
    hydrogen = top.add_atom("H", element=md.element.hydrogen, residue=ala)
    oxygen = top.add_atom("O", element=md.element.oxygen, residue=ala)
    # N is bonded to both H and O, in that order.
    for partner in (hydrogen, oxygen):
        top.add_bond(nitrogen, partner)

    # One entry per frame; each entry holds the xyz position of N, H and O.
    frames = [
        [[0, 0, 0], [1, 0, 0], [0, 1, 0]],
        [[0, 0, 0], [1.1, 0, 0], [0, 1.1, 0]],
        [[0, 0, 0], [1.2, 0, 0], [0, 1.2, 0]],
    ]
    xyz = np.array(frames, dtype=float)

    return md.Trajectory(xyz, top)


def test_compute_hbonds_traj(hydrogen_bond_tool, dummy_traj):
    """The dummy trajectory is expected to yield zero hydrogen bonds per frame."""
    expected = [0, 0, 0]
    counts = hydrogen_bond_tool.compute_hbonds_traj(dummy_traj)
    assert counts == expected


def test_plot_hbonds_over_time(hydrogen_bond_tool, dummy_traj):
    """Plotting returns a message that names the saved PNG."""
    counts = hydrogen_bond_tool.compute_hbonds_traj(dummy_traj)
    message = hydrogen_bond_tool.plot_hbonds_over_time(counts, dummy_traj, "dummy")
    for fragment in ("plot saved to", ".png"):
        assert fragment in message


def test_write_hbond_counts_to_file(hydrogen_bond_tool, dummy_traj):
    """Writing the counts returns a message that names the saved CSV."""
    counts = hydrogen_bond_tool.compute_hbonds_traj(dummy_traj)
    message = hydrogen_bond_tool.write_hbond_counts_to_file(counts, "dummy")
    for fragment in ("Data saved to", ".csv"):
        assert fragment in message

0 comments on commit 9375bf4

Please sign in to comment.