Skip to content

Commit

Permalink
radius of gyration tools and notebook examples for 10 proteins (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 authored Feb 28, 2024
1 parent 759b34e commit ddf0614
Show file tree
Hide file tree
Showing 19 changed files with 3,072 additions and 17 deletions.
8 changes: 8 additions & 0 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rgy import (
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
)
from .analysis_tools.rmsd_tools import RMSDCalculator
from .analysis_tools.vis_tools import VisFunctions, VisualizeProtein
from .preprocess_tools.clean_tools import (
Expand Down Expand Up @@ -33,6 +38,9 @@
"VisualizeProtein",
"RMSDCalculator",
"RemoveWaterCleaningTool",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"Scholar2ResultLLM",
"SerpGitTool",
"SetUpAndRunTool",
Expand Down
4 changes: 4 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rmsd_tools import RMSDCalculator
from .vis_tools import VisFunctions, VisualizeProtein

__all__ = [
"PPIDistance",
"RMSDCalculator",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"SimulationOutputFigures",
"VisualizeProtein",
"VisFunctions",
"RadiusofGyrationAverage",
]
166 changes: 166 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from typing import Optional

import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry


class RadiusofGyration:
def __init__(self, path_registry):
self.path_registry = path_registry
self.includes_top = [".h5", ".lh5", ".pdb"]

def _grab_files(self, pdb_id: str) -> None:
if "_" in pdb_id:
pdb_id = pdb_id.split("_")[0]
self.pdb_id = pdb_id
all_names = self.path_registry._list_all_paths()
try:
self.pdb_path = [
name
for name in all_names
if pdb_id in name and ".pdb" in name and "records" in name
][0]
except IndexError:
raise ValueError(f"No pdb file found for {pdb_id}")
try:
self.dcd_path = [
name
for name in all_names
if pdb_id in name and ".dcd" in name and "records" in name
][0]
except IndexError:
self.dcd_path = None
pass
return None

def _load_traj(self, pdb_id: str) -> None:
self._grab_files(pdb_id)
if self.dcd_path:
self.traj = md.load(self.dcd_path, top=self.pdb_path)
else:
self.traj = md.load(self.pdb_path)
return None

def rad_gyration_per_frame(self, pdb_id: str) -> str:
self._load_traj(pdb_id)
rg_per_frame = md.compute_rg(self.traj)

self.rgy_file = f"files/radii_of_gyration_{self.pdb_id}.csv"

np.savetxt(
self.rgy_file, rg_per_frame, delimiter=",", header="Radius of Gyration (nm)"
)
self.path_registry.map_path(
f"radii_of_gyration_{self.pdb_id}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.pdb_id}",
)
return f"Radii of gyration saved to {self.rgy_file}"

def rad_gyration_average(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rad_gyration(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
plot_name = f"{self.pdb_id}_rgy.png"

plt.plot(rg_per_frame)
plt.xlabel("Frame")
plt.ylabel("Radius of Gyration (nm)")
plt.title(f"{pdb_id} - Radius of Gyration Over Time")

plt.savefig(plot_name)
self.path_registry.map_path(
f"{self.pdb_id}_radii_of_gyration_plot",
plot_name,
description=f"Plot of radii of gyration over time for {self.pdb_id}",
)
return "Plot saved as: " + f"{plot_name}.png"


class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for the given trajectory file. Give this tool the
protein ID (PDB ID) only. The tool will automatically find the necessary files."""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
"""use the tool."""
try:
RGY = RadiusofGyration(self.path_registry)
return RGY.rad_gyration_average(pdb_id)
except ValueError as e:
return str(e)

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file. Give this tool the
protein ID (PDB ID) only. The tool will automatically find the necessary files.
The tool will save the radii of gyration to a csv file and
map it to the registry."""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
"""use the tool."""
try:
RGY = RadiusofGyration(self.path_registry)
return RGY.rad_gyration_per_frame(pdb_id)
except ValueError as e:
return str(e)

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool the protein ID (PDB ID) only.
The tool will automatically find the necessary files.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
"""use the tool."""
try:
RGY = RadiusofGyration(self.path_registry)
return RGY.plot_rad_gyration(pdb_id)
except ValueError as e:
return str(e)

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _prompt_summary(self, query: str):

prompt_template = (
"You're an expert programmer and in molecular dynamics. "
"Your job is to make a script to make a simmulation "
"Your job is to make a script to make a simulation "
"in openmm. "
"Youre starting point is a base script that runs a protein on its own. "
"The protein itself doesnt require more preperation. "
Expand Down Expand Up @@ -65,7 +65,7 @@ class ModifyScriptInput(BaseModel):
query: str = Field(
...,
description=(
"Simmulation required by the user.You MUST "
"simulation required by the user.You MUST "
"specify the objective, requirements of the simulation as well "
"as on what protein you are working."
),
Expand Down
53 changes: 39 additions & 14 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ class SetUpandRunFunctionInput(BaseModel):
},
description="""Parameters for the openmm integrator.""",
)
simmulation_params: Dict[str, Any] = Field(
simulation_params: Dict[str, Any] = Field(
{
"Ensemble": "NVT",
"Number of Steps": 5000,
Expand Down Expand Up @@ -656,7 +656,7 @@ def __init__(
"constraintTolerance": 0.000001,
"solvate": False,
}
self.sim_params = self.params.get("simmulation_params", None)
self.sim_params = self.params.get("simulation_params", None)
if self.sim_params is None:
self.sim_params = {
"Ensemble": "NVT",
Expand Down Expand Up @@ -736,6 +736,13 @@ def create_simulation(self):
Sim_id=self.sim_id,
term="dcd",
)
topology_name = self.path_registry.write_file_name(
type=FileType.RECORD,
record_type="TOP",
protein_file_id=self.pdb_id,
Sim_id=self.sim_id,
term="pdb",
)

log_name = self.path_registry.write_file_name(
type=FileType.RECORD,
Expand All @@ -749,6 +756,10 @@ def create_simulation(self):
f"Simulation trajectory for protein {self.pdb_id}"
f" and simulation {self.sim_id}"
)
top_desc = (
f"Simulation topology for protein"
f"{self.pdb_id} and simulation {self.sim_id}"
)
log_desc = (
f"Simulation state log for protein {self.pdb_id} "
f"and simulation {self.sim_id}"
Expand All @@ -760,6 +771,12 @@ def create_simulation(self):
self.sim_params["record_interval_steps"],
)
)
self.simulation.reporters.append(
PDBReporter(
f"{topology_name}",
self.sim_params["record_interval_steps"],
)
)
self.simulation.reporters.append(
StateDataReporter(
f"{log_name}",
Expand All @@ -773,6 +790,7 @@ def create_simulation(self):
self.registry_records = [
("holder", f"files/records/{trajectory_name}", traj_desc),
("holder", f"files/records/{log_name}", log_desc),
("holder", f"files/records/{topology_name}", top_desc),
]

# TODO add checkpoint too?
Expand All @@ -784,6 +802,12 @@ def create_simulation(self):
self.sim_params["record_interval_steps"],
)
)
self.simulation.reporters.append(
PDBReporter(
"temp_topology.pdb",
self.sim_params["record_interval_steps"],
)
)
self.simulation.reporters.append(
StateDataReporter(
"temp_log.txt",
Expand Down Expand Up @@ -947,6 +971,7 @@ def unit_to_string(unit):
equilibrationSteps = 1000
platform = Platform.getPlatformByName('CPU')
dcdReporter = DCDReporter('trajectory.dcd', 1000)
pdbReporter = PDBReporter('trajectory.pdb', 1000)
dataReporter = StateDataReporter('log.txt', {record_interval_steps},
totalSteps=steps,
step=True, speed=True, progress=True, elapsedTime=True, remainingTime=True,
Expand Down Expand Up @@ -1041,6 +1066,7 @@ def unit_to_string(unit):
print('Simulating...')
simulation.reporters.append(dcdReporter)
simulation.reporters.append(pdbReporter)
simulation.reporters.append(dataReporter)
simulation.reporters.append(checkpointReporter)
simulation.currentStep = 0
Expand Down Expand Up @@ -1124,7 +1150,6 @@ class SetUpandRunFunction(BaseTool):

def _run(self, **input_args):
if self.path_registry is None:
print("Path registry not initialized")
return "Path registry not initialized"
input = self.check_system_params(input_args)
error = input.get("error", None)
Expand All @@ -1138,7 +1163,6 @@ def _run(self, **input_args):
if pdb_id not in self.path_registry.list_path_names():
return "No pdb_id found in input, use the file id not the file name"
except KeyError:
print("whoops no pdb_id found in input,", input)
return "No pdb_id found in input"
try:
save = input["save"] # either this simulation
Expand All @@ -1152,7 +1176,7 @@ def _run(self, **input_args):
try:
file_name = self.path_registry.write_file_name(
type=FileType.SIMULATION,
type_of_sim=input["simmulation_params"]["Ensemble"],
type_of_sim=input["simulation_params"]["Ensemble"],
protein_file_id=pdb_id,
)

Expand Down Expand Up @@ -1207,10 +1231,11 @@ def _run(self, **input_args):
for record in records:
os.rename(record[1].split("/")[-1], f"{record[1]}")
for record in records:
record[0] = self.path_registry.get_fileid( # Step necessary here to
record[1].split("/")[-1], # avoid id being repeated
FileType.RECORD,
record_list = list(record)
record_list[0] = self.path_registry.get_fileid(
record_list[1].split("/")[-1], FileType.RECORD
)
record = tuple(record_list)
self.path_registry.map_path(*record)
return (
"Simulation done! \n Summary: \n"
Expand Down Expand Up @@ -1531,7 +1556,7 @@ def _process_parameters(self, user_params, param_type="system_params"):
error_msg += msg

return processed_params, error_msg
if param_type == "simmulation_params":
if param_type == "simulation_params":
for key, value in user_params.items():
if key == "Ensemble" or key == "ensemble":
if value == "NPT":
Expand Down Expand Up @@ -1592,9 +1617,9 @@ def check_system_params(cls, values):
"Timestep": 0.004 * picoseconds,
"Pressure": 1.0 * bar,
}
simmulation_params = values.get("simmulation_params")
if simmulation_params is None:
simmulation_params = {
simulation_params = values.get("simulation_params")
if simulation_params is None:
simulation_params = {
"Ensemble": "NVT",
"Number of Steps": 10000,
"record_interval_steps": 100,
Expand All @@ -1604,7 +1629,7 @@ def check_system_params(cls, values):

# system_params = {k.lower(): v for k, v in system_params.items()}
# integrator_params = {k.lower(): v for k, v in integrator_params.items()}
# simmulation_params = {k.lower(): v for k, v in simmulation_params.items()}
# simulation_params = {k.lower(): v for k, v in simulation_params.items()}

nonbondedMethod = system_params.get("nonbondedMethod")
nonbondedCutoff = system_params.get("nonbondedCutoff")
Expand Down Expand Up @@ -1708,7 +1733,7 @@ def check_system_params(cls, values):
"save": save,
"system_params": system_params,
"integrator_params": integrator_params,
"simmulation_params": simmulation_params,
"simulation_params": simulation_params,
}
# if no error, return the values
return values
Expand Down
Loading

0 comments on commit ddf0614

Please sign in to comment.