Skip to content

Commit

Permalink
Figures path registry (#108)
Browse files Browse the repository at this point in the history
* replace solvent with figure in path registry

* add the path registry to the rmsd, plot and rgy tools so their figures are mapped in the registry

* save files in path registry in rgy tool

* add unit tests
  • Loading branch information
Jgmedina95 authored Mar 15, 2024
1 parent 982372f commit ee28a70
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 35 deletions.
24 changes: 18 additions & 6 deletions mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import csv
import os
import re
from typing import Optional

import matplotlib.pyplot as plt
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


class PlottingTools:
Expand Down Expand Up @@ -61,18 +62,29 @@ def plot_data(self) -> str:
header_lab = (
header.split("(")[0].strip() if "(" in header else header
).lower()
plot_name = f"{self.file_id}_{xlab}_vs_{header_lab}.png"

# Generate and save the plot
plt.figure()
plt.plot(x, y)
plt.xlabel(xlab)
plt.ylabel(header)
plt.title(f"{self.file_id}_{xlab} vs {header_lab}")
plt.savefig(plot_name)
fig_vs = f"{xlab}vs{header_lab}"
# PR: Mapping involves writing file name -> get file id
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE,
Log_id=self.file_id,
fig_analysis=fig_vs,
file_format="png",
)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if not os.path.exists("files/figures"): # PR: Needed to avoid error
os.makedirs("files/figures")
plt.savefig(f"files/figures/{plot_name}")
self.path_registry.map_path(
plot_name,
plot_name,
plot_id,
f"files/figures/{plot_name}",
(
f"Post Simulation Figure for {self.file_id}"
f" - {header_lab} vs {xlab}"
Expand Down
19 changes: 14 additions & 5 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from typing import Optional

import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


class RadiusofGyration:
Expand Down Expand Up @@ -71,20 +72,28 @@ def rad_gyration_average(self, pdb_id: str) -> str:
def plot_rad_gyration(self, pdb_id: str) -> str:
    """Plot the per-frame radius of gyration and register the figure.

    The figure is written once, under ``files/figures``, using a name
    produced by the path registry, and is then mapped in the registry under
    a freshly generated figure id.

    Args:
        pdb_id: Identifier passed to ``rad_gyration_per_frame`` and used in
            the plot title.

    Returns:
        A message containing the saved file name and the registry plot id.
    """
    # Called for its side effect only; presumably this (re)writes
    # ``self.rgy_file`` -- confirm against rad_gyration_per_frame.
    _ = self.rad_gyration_per_frame(pdb_id)
    rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)

    fig_analysis = f"rgy_{self.pdb_id}"
    plot_name = self.path_registry.write_file_name(
        type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png"
    )
    plot_id = self.path_registry.get_fileid(
        file_name=plot_name, type=FileType.FIGURE
    )
    figures_dir = "files/figures"
    plot_path = f"{figures_dir}/{plot_name}"

    plt.plot(rg_per_frame)
    plt.xlabel("Frame")
    plt.ylabel("Radius of Gyration (nm)")
    plt.title(f"{pdb_id} - Radius of Gyration Over Time")

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(figures_dir, exist_ok=True)
    plt.savefig(plot_path)

    # Register the path the file was actually written to, matching how the
    # plotting and RMSD tools register their figures.
    self.path_registry.map_path(
        plot_id,
        plot_path,
        description=f"Plot of radii of gyration over time for {self.pdb_id}",
    )
    # plot_name already carries its ".png" extension from write_file_name.
    return f"Plot saved as: {plot_name} with plot ID {plot_id}"


class RadiusofGyrationAverage(BaseTool):
Expand Down
23 changes: 19 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from MDAnalysis.analysis import align, diffusionmap, rms
from pydantic import BaseModel, Field

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry

# all things related to RMSD as 'standard deviation'
# 1 RMSD between two protein conformations or trajectories (1D scalar value)
Expand Down Expand Up @@ -120,11 +120,26 @@ def compute_rmsd(self, selection="backbone", plot=True):
plt.title("Time-Dependent RMSD")
plt.legend()
plt.show()
plt.savefig(f"{self.filename}.png")
if not os.path.exists("files/figures"): # PR: Needed to avoid error
os.makedirs("files/figures")
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE,
fig_analysis=self.filename,
file_format="png",
)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
plt.savefig(f"files/figures/{plot_name}.png")
# plt.close() # if you don't want to show the plot in notebooks
message += f"Plotted RMSD over time. Saved to {self.filename}.png.\n"
# PRComment: Getting description only for the plot
plot_message = (
f"Plotted RMSD over time for{self.pdb_file}."
f" Saved with plot id {plot_id}.\n"
)
message += plot_message
self.path_registry.map_path(
f"{self.filename}.png", f"{self.filename}.png", message
plot_id, f"files/figures/{plot_name}", plot_message
)
return message

Expand Down
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Dict, Optional, Type
from typing import Optional, Type

from langchain.tools import BaseTool
from openmm.app import PDBFile, PDBxFile
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field

from mdagent.utils import FileType, PathRegistry

Expand Down Expand Up @@ -227,12 +227,6 @@ class CleaningToolFunctionInput(BaseModel):
)
add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.")

@root_validator
def validate_query(cls, values) -> Dict:
"""Check that the input is valid."""

return values


class CleaningToolFunction(BaseTool):
name = "CleaningToolFunction"
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/preprocess_tools/pdb_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ class PDBFilesFixInp(BaseModel):
),
)

@root_validator
@root_validator(skip_on_failure=True)
def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
if isinstance(values, str):
print("values is a string", values)
Expand Down
45 changes: 36 additions & 9 deletions mdagent/utils/path_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class FileType(Enum):
PROTEIN = 1
SIMULATION = 2
RECORD = 3
SOLVENT = 4
FIGURE = 4
UNKNOWN = 5


Expand All @@ -29,7 +29,7 @@ def __init__(self):

def _init_path_registry(self):
base_directory = "files"
subdirectories = ["pdb", "records", "simulations", "solvents"]
subdirectories = ["pdb", "records", "simulations", "figures"]
existing_registry = self._load_existing_registry()
file_names_in_registry = []
if existing_registry != {}:
Expand Down Expand Up @@ -61,10 +61,10 @@ def _init_path_registry(self):
else ""
)
)
elif file_type == FileType.SOLVENT:
elif file_type == FileType.FIGURE:
name_parts = file_name.split("_")
solvent_name = name_parts[0]
description = f"Solvent {solvent_name} pdb file. "
figure_name = name_parts[0]
description = f"Figure {figure_name} pdb file. "
else:
description = "Auto-Registered during registry init."
self.map_path(
Expand Down Expand Up @@ -93,8 +93,8 @@ def _determine_file_type(self, subdir):
return FileType.RECORD
elif subdir == "simulations":
return FileType.SIMULATION
elif subdir == "solvents":
return FileType.SOLVENT
elif subdir == "figures":
return FileType.FIGURE
else:
return FileType.UNKNOWN # or some default value

Expand Down Expand Up @@ -239,10 +239,16 @@ def get_fileid(self, file_name: str, type: FileType):
num += 1
rec_id = "rec" + f"{num}" + "_" + timestamp_digits
return rec_id
if type == FileType.SOLVENT:
return parts + "_" + timestamp_digits
if type == FileType.FIGURE:
num = 0
fig_id = "fig" + f"{num}" + "_" + timestamp_digits
while fig_id in current_ids:
num += 1
fig_id = "fig" + f"{num}" + "_" + timestamp_digits
return fig_id

def write_file_name(self, type: FileType, **kwargs):
# PR: I know this looks messy, it is, im adding as things keep coming :c
time_stamp = self.get_timestamp()
protein_name = kwargs.get("protein_name", None)
description = kwargs.get("description", "No description provided")
Expand All @@ -251,8 +257,10 @@ def write_file_name(self, type: FileType, **kwargs):
type_of_sim = kwargs.get("type_of_sim", None)
conditions = kwargs.get("conditions", None)
Sim_id = kwargs.get("Sim_id", None)
Log_id = kwargs.get("Log_id", None)
modified = kwargs.get("modified", False)
term = kwargs.get("term", "term") # Default term if not provided
fig_analysis = kwargs.get("fig_analysis", None)
file_name = ""
if type == FileType.PROTEIN:
file_name += f"{protein_name}_{description}_{time_stamp}.{file_format}"
Expand All @@ -272,6 +280,25 @@ def write_file_name(self, type: FileType, **kwargs):
file_name = (
f"{record_type_name}_{Sim_id}_{protein_file_id}_" f"{time_stamp}.{term}"
)
if type == FileType.FIGURE:
if fig_analysis:
if Sim_id:
file_name += (
f"FIG_{fig_analysis}_{Sim_id}_{time_stamp}.{file_format}"
)
elif Log_id:
file_name += (
f"FIG_{fig_analysis}_{Log_id}_{time_stamp}.{file_format}"
)
else:
file_name += f"FIG_{fig_analysis}_{time_stamp}.{file_format}"
else:
if Sim_id:
file_name += f"FIG_{Sim_id}_{time_stamp}.{file_format}"
elif Log_id:
file_name += f"FIG_{Log_id}_{time_stamp}.{file_format}"
else:
file_name += f"FIG_{time_stamp}.{file_format}"

if file_name == "":
file_name += "ErrorDuringNaming_error.py"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_plot_data(plotting_tools):
plotting_tools.headers = headers
plotting_tools.matched_headers = matched_headers
created_plots = plotting_tools.plot_data()
assert "time_vs_value1.png" in created_plots
assert "time_vs_value2.png" in created_plots
assert "FIG_timevsvalue1" in created_plots
assert "FIG_timevsvalue2" in created_plots

# Test failure due to non-numeric data
data_failure = [
Expand Down
39 changes: 39 additions & 0 deletions tests/test_util_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,45 @@ def test_write_file_name_record(path_registry, todays_date):
assert file_name.endswith(".dcd")


def test_write_file_name_figure_1(path_registry, todays_date):
    """Figure name built from a Sim_id only: ``FIG_<Sim_id>_<timestamp>.<ext>``."""
    naming_kwargs = {
        "Sim_id": "SIM456",
        "time_stamp": todays_date,
        "file_format": "png",
        "irrelevant": "irrelevant",  # unrecognised keys should be ignored
    }
    generated = path_registry.write_file_name(FileType.FIGURE, **naming_kwargs)

    assert "FIG_SIM456_" in generated
    assert todays_date in generated
    assert generated.endswith(".png")


def test_write_file_name_figure_2(path_registry, todays_date):
    """Figure name built from a Log_id only: ``FIG_<Log_id>_<timestamp>.<ext>``."""
    naming_kwargs = {
        "Log_id": "LOG_123456",
        "time_stamp": todays_date,
        "file_format": "jpg",
        "irrelevant": "irrelevant",  # unrecognised keys should be ignored
    }
    generated = path_registry.write_file_name(FileType.FIGURE, **naming_kwargs)

    assert "FIG_LOG_123456_" in generated
    assert todays_date in generated
    assert generated.endswith(".jpg")


def test_write_file_name_figure_3(path_registry, todays_date):
    """Figure name with an analysis tag: ``FIG_<analysis>_<Log_id>_<timestamp>.<ext>``."""
    naming_kwargs = {
        "Log_id": "LOG_123456",
        "fig_analysis": "randomanalytic",
        "file_format": "jpg",
        "irrelevant": "irrelevant",  # unrecognised keys should be ignored
    }
    generated = path_registry.write_file_name(FileType.FIGURE, **naming_kwargs)

    assert "FIG_randomanalytic_LOG_123456_" in generated
    assert todays_date in generated
    assert generated.endswith(".jpg")


def test_map_path(path_registry):
mock_json_data = {
"existing_name": {
Expand Down

0 comments on commit ee28a70

Please sign in to comment.