Skip to content

Commit

Permalink
merging main to UnitTestsJ
Browse files Browse the repository at this point in the history
  • Loading branch information
Jgmedina95 committed Mar 18, 2024
2 parents e86d25a + ee28a70 commit c0f9aa8
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 103 deletions.
24 changes: 18 additions & 6 deletions mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import csv
import os
import re
from typing import Optional

import matplotlib.pyplot as plt
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


class PlottingTools:
Expand Down Expand Up @@ -61,18 +62,29 @@ def plot_data(self) -> str:
header_lab = (
header.split("(")[0].strip() if "(" in header else header
).lower()
plot_name = f"{self.file_id}_{xlab}_vs_{header_lab}.png"

# Generate and save the plot
plt.figure()
plt.plot(x, y)
plt.xlabel(xlab)
plt.ylabel(header)
plt.title(f"{self.file_id}_{xlab} vs {header_lab}")
plt.savefig(plot_name)
fig_vs = f"{xlab}vs{header_lab}"
# PR: Mapping involves writing file name -> get file id
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE,
Log_id=self.file_id,
fig_analysis=fig_vs,
file_format="png",
)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if not os.path.exists("files/figures"): # PR: Needed to avoid error
os.makedirs("files/figures")
plt.savefig(f"files/figures/{plot_name}")
self.path_registry.map_path(
plot_name,
plot_name,
plot_id,
f"files/figures/{plot_name}",
(
f"Post Simulation Figure for {self.file_id}"
f" - {header_lab} vs {xlab}"
Expand Down
19 changes: 14 additions & 5 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from typing import Optional

import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


class RadiusofGyration:
Expand Down Expand Up @@ -71,20 +72,28 @@ def rad_gyration_average(self, pdb_id: str) -> str:
def plot_rad_gyration(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
plot_name = f"{self.pdb_id}_rgy.png"
fig_analysis = f"rgy_{self.pdb_id}"
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png"
)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)

plt.plot(rg_per_frame)
plt.xlabel("Frame")
plt.ylabel("Radius of Gyration (nm)")
plt.title(f"{pdb_id} - Radius of Gyration Over Time")

plt.savefig(plot_name)
if not os.path.exists("files/figures"):
os.makedirs("files/figures")
plt.savefig(f"files/figures/{plot_name}")
self.path_registry.map_path(
f"{self.pdb_id}_radii_of_gyration_plot",
plot_id,
plot_name,
description=f"Plot of radii of gyration over time for {self.pdb_id}",
)
return "Plot saved as: " + f"{plot_name}.png"
return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}"


class RadiusofGyrationAverage(BaseTool):
Expand Down
23 changes: 19 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from MDAnalysis.analysis import align, diffusionmap, rms
from pydantic import BaseModel, Field

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry

# all things related to RMSD as 'standard deviation'
# 1 RMSD between two protein conformations or trajectories (1D scalar value)
Expand Down Expand Up @@ -120,11 +120,26 @@ def compute_rmsd(self, selection="backbone", plot=True):
plt.title("Time-Dependent RMSD")
plt.legend()
plt.show()
plt.savefig(f"{self.filename}.png")
if not os.path.exists("files/figures"): # PR: Needed to avoid error
os.makedirs("files/figures")
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE,
fig_analysis=self.filename,
file_format="png",
)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
plt.savefig(f"files/figures/{plot_name}.png")
# plt.close() # if you don't want to show the plot in notebooks
message += f"Plotted RMSD over time. Saved to {self.filename}.png.\n"
# PRComment: Getting description only for the plot
plot_message = (
f"Plotted RMSD over time for{self.pdb_file}."
f" Saved with plot id {plot_id}.\n"
)
message += plot_message
self.path_registry.map_path(
f"{self.filename}.png", f"{self.filename}.png", message
plot_id, f"files/figures/{plot_name}", plot_message
)
return message

Expand Down
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Dict, Optional, Type
from typing import Optional, Type

from langchain.tools import BaseTool
from openmm.app import PDBFile, PDBxFile
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field

from mdagent.utils import FileType, PathRegistry

Expand Down Expand Up @@ -227,12 +227,6 @@ class CleaningToolFunctionInput(BaseModel):
)
add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.")

@root_validator(skip_on_failure=True)
def validate_query(cls, values) -> Dict:
"""Check that the input is valid."""

return values


class CleaningToolFunction(BaseTool):
name = "CleaningToolFunction"
Expand Down
45 changes: 36 additions & 9 deletions mdagent/utils/path_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class FileType(Enum):
PROTEIN = 1
SIMULATION = 2
RECORD = 3
SOLVENT = 4
FIGURE = 4
UNKNOWN = 5


Expand All @@ -29,7 +29,7 @@ def __init__(self):

def _init_path_registry(self):
base_directory = "files"
subdirectories = ["pdb", "records", "simulations", "solvents"]
subdirectories = ["pdb", "records", "simulations", "figures"]
existing_registry = self._load_existing_registry()
file_names_in_registry = []
if existing_registry != {}:
Expand Down Expand Up @@ -61,10 +61,10 @@ def _init_path_registry(self):
else ""
)
)
elif file_type == FileType.SOLVENT:
elif file_type == FileType.FIGURE:
name_parts = file_name.split("_")
solvent_name = name_parts[0]
description = f"Solvent {solvent_name} pdb file. "
figure_name = name_parts[0]
description = f"Figure {figure_name} pdb file. "
else:
description = "Auto-Registered during registry init."
self.map_path(
Expand Down Expand Up @@ -93,8 +93,8 @@ def _determine_file_type(self, subdir):
return FileType.RECORD
elif subdir == "simulations":
return FileType.SIMULATION
elif subdir == "solvents":
return FileType.SOLVENT
elif subdir == "figures":
return FileType.FIGURE
else:
return FileType.UNKNOWN # or some default value

Expand Down Expand Up @@ -239,10 +239,16 @@ def get_fileid(self, file_name: str, type: FileType):
num += 1
rec_id = "rec" + f"{num}" + "_" + timestamp_digits
return rec_id
if type == FileType.SOLVENT:
return parts + "_" + timestamp_digits
if type == FileType.FIGURE:
num = 0
fig_id = "fig" + f"{num}" + "_" + timestamp_digits
while fig_id in current_ids:
num += 1
fig_id = "fig" + f"{num}" + "_" + timestamp_digits
return fig_id

def write_file_name(self, type: FileType, **kwargs):
# PR: I know this looks messy, it is, im adding as things keep coming :c
time_stamp = self.get_timestamp()
protein_name = kwargs.get("protein_name", None)
description = kwargs.get("description", "No description provided")
Expand All @@ -251,8 +257,10 @@ def write_file_name(self, type: FileType, **kwargs):
type_of_sim = kwargs.get("type_of_sim", None)
conditions = kwargs.get("conditions", None)
Sim_id = kwargs.get("Sim_id", None)
Log_id = kwargs.get("Log_id", None)
modified = kwargs.get("modified", False)
term = kwargs.get("term", "term") # Default term if not provided
fig_analysis = kwargs.get("fig_analysis", None)
file_name = ""
if type == FileType.PROTEIN:
file_name += f"{protein_name}_{description}_{time_stamp}.{file_format}"
Expand All @@ -272,6 +280,25 @@ def write_file_name(self, type: FileType, **kwargs):
file_name = (
f"{record_type_name}_{Sim_id}_{protein_file_id}_" f"{time_stamp}.{term}"
)
if type == FileType.FIGURE:
if fig_analysis:
if Sim_id:
file_name += (
f"FIG_{fig_analysis}_{Sim_id}_{time_stamp}.{file_format}"
)
elif Log_id:
file_name += (
f"FIG_{fig_analysis}_{Log_id}_{time_stamp}.{file_format}"
)
else:
file_name += f"FIG_{fig_analysis}_{time_stamp}.{file_format}"
else:
if Sim_id:
file_name += f"FIG_{Sim_id}_{time_stamp}.{file_format}"
elif Log_id:
file_name += f"FIG_{Log_id}_{time_stamp}.{file_format}"
else:
file_name += f"FIG_{time_stamp}.{file_format}"

if file_name == "":
file_name += "ErrorDuringNaming_error.py"
Expand Down
18 changes: 6 additions & 12 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,21 @@
from mdagent.subagents.agents.skill import SkillManager
from mdagent.subagents.subagent_fxns import Iterator
from mdagent.subagents.subagent_setup import SubAgentSettings
from mdagent.utils import PathRegistry


@pytest.fixture
def path_registry():
return PathRegistry()
def skill_manager(get_registry):
return SkillManager(path_registry=get_registry("raw", False))


@pytest.fixture
def skill_manager(path_registry):
return SkillManager(path_registry=path_registry)
def action(get_registry):
return Action(get_registry("raw", False))


@pytest.fixture
def action(path_registry):
return Action(path_registry)


@pytest.fixture
def iterator(path_registry):
settings = SubAgentSettings(path_registry=path_registry)
def iterator(get_registry):
settings = SubAgentSettings(path_registry=get_registry("raw", False))
return Iterator(subagent_settings=settings)


Expand Down
14 changes: 4 additions & 10 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,16 @@

from mdagent.tools.base_tools import VisFunctions
from mdagent.tools.base_tools.analysis_tools.plot_tools import PlottingTools
from mdagent.utils import PathRegistry


@pytest.fixture
def get_registry():
return PathRegistry()


@pytest.fixture
def plotting_tools(get_registry):
return PlottingTools(get_registry)
return PlottingTools(get_registry("raw", False))


@pytest.fixture
def vis_fxns(get_registry):
return VisFunctions(get_registry)
return VisFunctions(get_registry("raw", False))


@pytest.fixture
Expand Down Expand Up @@ -86,8 +80,8 @@ def test_plot_data(plotting_tools):
plotting_tools.headers = headers
plotting_tools.matched_headers = matched_headers
created_plots = plotting_tools.plot_data()
assert "time_vs_value1.png" in created_plots
assert "time_vs_value2.png" in created_plots
assert "FIG_timevsvalue1" in created_plots
assert "FIG_timevsvalue2" in created_plots

# Test failure due to non-numeric data
data_failure = [
Expand Down
12 changes: 3 additions & 9 deletions tests/test_pdb_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,25 @@
from mdagent.tools.base_tools import get_pdb
from mdagent.tools.base_tools.preprocess_tools.packing import PackMolTool
from mdagent.tools.base_tools.preprocess_tools.pdb_get import MolPDB
from mdagent.utils import PathRegistry


@pytest.fixture
def fibronectin():
return "fibronectin pdb"


@pytest.fixture
def get_registry():
return PathRegistry()


@pytest.fixture
def molpdb(get_registry):
return MolPDB(get_registry)
return MolPDB(get_registry("raw", False))


@pytest.fixture
def packmol(get_registry):
return PackMolTool(get_registry)
return PackMolTool(get_registry("raw", False))


def test_getpdb(fibronectin, get_registry):
name, _ = get_pdb(fibronectin, get_registry)
name, _ = get_pdb(fibronectin, get_registry("raw", False))
assert name.endswith(".pdb")


Expand Down
8 changes: 1 addition & 7 deletions tests/test_simulation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,11 @@
import pytest

from mdagent.tools.base_tools import SimulationFunctions
from mdagent.utils import PathRegistry


@pytest.fixture
def get_registry():
return PathRegistry()


@pytest.fixture
def sim_fxns(get_registry):
return SimulationFunctions(get_registry)
return SimulationFunctions(get_registry("raw", False))


@patch("os.path.exists")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_subagents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ def set_env():
load_dotenv()


def test_subagent_setup():
settings = SubAgentSettings(path_registry=None)
def test_subagent_setup(get_registry):
settings = SubAgentSettings(get_registry("raw", False))
initializer = SubAgentInitializer(settings)
subagents = initializer.create_iteration_agents()
action = subagents["action"]
Expand Down
Loading

0 comments on commit c0f9aa8

Please sign in to comment.