diff --git a/mdagent/tools/base_tools/analysis_tools/plot_tools.py b/mdagent/tools/base_tools/analysis_tools/plot_tools.py index d479fe07..0dfe79de 100644 --- a/mdagent/tools/base_tools/analysis_tools/plot_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/plot_tools.py @@ -1,11 +1,12 @@ import csv +import os import re from typing import Optional import matplotlib.pyplot as plt from langchain.tools import BaseTool -from mdagent.utils import PathRegistry +from mdagent.utils import FileType, PathRegistry class PlottingTools: @@ -61,18 +62,29 @@ def plot_data(self) -> str: header_lab = ( header.split("(")[0].strip() if "(" in header else header ).lower() - plot_name = f"{self.file_id}_{xlab}_vs_{header_lab}.png" - # Generate and save the plot plt.figure() plt.plot(x, y) plt.xlabel(xlab) plt.ylabel(header) plt.title(f"{self.file_id}_{xlab} vs {header_lab}") - plt.savefig(plot_name) + fig_vs = f"{xlab}vs{header_lab}" + # PR: Mapping involves writing file name -> get file id + plot_name = self.path_registry.write_file_name( + type=FileType.FIGURE, + Log_id=self.file_id, + fig_analysis=fig_vs, + file_format="png", + ) + plot_id = self.path_registry.get_fileid( + file_name=plot_name, type=FileType.FIGURE + ) + if not os.path.exists("files/figures"): # PR: Needed to avoid error + os.makedirs("files/figures") + plt.savefig(f"files/figures/{plot_name}") self.path_registry.map_path( - plot_name, - plot_name, + plot_id, + f"files/figures/{plot_name}", ( f"Post Simulation Figure for {self.file_id}" f" - {header_lab} vs {xlab}" diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 32a18a20..798f5863 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -1,3 +1,4 @@ +import os from typing import Optional import matplotlib.pyplot as plt @@ -5,7 +6,7 @@ import numpy as np from langchain.tools import BaseTool -from mdagent.utils import PathRegistry +from mdagent.utils import FileType, PathRegistry class RadiusofGyration: @@ -71,20 +72,28 @@ def rad_gyration_average(self, pdb_id: str) -> str: def plot_rad_gyration(self, pdb_id: str) -> str: _ = self.rad_gyration_per_frame(pdb_id) rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) - plot_name = f"{self.pdb_id}_rgy.png" + fig_analysis = f"rgy_{self.pdb_id}" + plot_name = self.path_registry.write_file_name( + type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png" + ) + plot_id = self.path_registry.get_fileid( + file_name=plot_name, type=FileType.FIGURE + ) plt.plot(rg_per_frame) plt.xlabel("Frame") plt.ylabel("Radius of Gyration (nm)") plt.title(f"{pdb_id} - Radius of Gyration Over Time") - plt.savefig(plot_name) + if not os.path.exists("files/figures"): + os.makedirs("files/figures") + plt.savefig(f"files/figures/{plot_name}") self.path_registry.map_path( - f"{self.pdb_id}_radii_of_gyration_plot", + plot_id, plot_name, description=f"Plot of radii of gyration over time for {self.pdb_id}", ) - return "Plot saved as: " + f"{plot_name}.png" + return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}" class RadiusofGyrationAverage(BaseTool): diff --git a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py index a6e5a1b3..0205f0c0 100644 --- a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py @@ -9,7 +9,7 @@ from MDAnalysis.analysis import align, diffusionmap, rms from pydantic import BaseModel, Field -from mdagent.utils import PathRegistry +from mdagent.utils import FileType, PathRegistry # all things related to RMSD as 'standard deviation' # 1 RMSD between two protein conformations or trajectories (1D scalar value) @@ -120,11 +120,26 @@ def compute_rmsd(self, selection="backbone", plot=True): plt.title("Time-Dependent RMSD") plt.legend() plt.show() - plt.savefig(f"{self.filename}.png") + if not os.path.exists("files/figures"): # PR: Needed to avoid error + os.makedirs("files/figures") + plot_name = self.path_registry.write_file_name( + type=FileType.FIGURE, + fig_analysis=self.filename, + file_format="png", + ) + plot_id = self.path_registry.get_fileid( + file_name=plot_name, type=FileType.FIGURE + ) + plt.savefig(f"files/figures/{plot_name}.png") # plt.close() # if you don't want to show the plot in notebooks - message += f"Plotted RMSD over time. Saved to {self.filename}.png.\n" + # PRComment: Getting description only for the plot + plot_message = ( + f"Plotted RMSD over time for{self.pdb_file}." + f" Saved with plot id {plot_id}.\n" + ) + message += plot_message self.path_registry.map_path( - f"{self.filename}.png", f"{self.filename}.png", message + plot_id, f"files/figures/{plot_name}", plot_message ) return message diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py index 589a4294..ec7fbe36 100644 --- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py @@ -1,10 +1,10 @@ import os -from typing import Dict, Optional, Type +from typing import Optional, Type from langchain.tools import BaseTool from openmm.app import PDBFile, PDBxFile from pdbfixer import PDBFixer -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field from mdagent.utils import FileType, PathRegistry @@ -227,12 +227,6 @@ class CleaningToolFunctionInput(BaseModel): ) add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.") - @root_validator - def validate_query(cls, values) -> Dict: - """Check that the input is valid.""" - - return values - class CleaningToolFunction(BaseTool): name = "CleaningToolFunction" diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py index 4cef4ef0..b63af3f3 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py @@ -660,7 +660,7 @@ class PDBFilesFixInp(BaseModel): ), ) - @root_validator + @root_validator(skip_on_failure=True) def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: if isinstance(values, str): print("values is a string", values) diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index 4c65e0cb..84d75d96 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -10,7 +10,7 @@ class FileType(Enum): PROTEIN = 1 SIMULATION = 2 RECORD = 3 - SOLVENT = 4 + FIGURE = 4 UNKNOWN = 5 @@ -29,7 +29,7 @@ def __init__(self): def _init_path_registry(self): base_directory = "files" - subdirectories = ["pdb", "records", "simulations", "solvents"] + subdirectories = ["pdb", "records", "simulations", "figures"] existing_registry = self._load_existing_registry() file_names_in_registry = [] if existing_registry != {}: @@ -61,10 +61,10 @@ def _init_path_registry(self): else "" ) ) - elif file_type == FileType.SOLVENT: + elif file_type == FileType.FIGURE: name_parts = file_name.split("_") - solvent_name = name_parts[0] - description = f"Solvent {solvent_name} pdb file. " + figure_name = name_parts[0] + description = f"Figure {figure_name} pdb file. " else: description = "Auto-Registered during registry init." self.map_path( @@ -93,8 +93,8 @@ def _determine_file_type(self, subdir): return FileType.RECORD elif subdir == "simulations": return FileType.SIMULATION - elif subdir == "solvents": - return FileType.SOLVENT + elif subdir == "figures": + return FileType.FIGURE else: return FileType.UNKNOWN # or some default value @@ -239,10 +239,16 @@ def get_fileid(self, file_name: str, type: FileType): num += 1 rec_id = "rec" + f"{num}" + "_" + timestamp_digits return rec_id - if type == FileType.SOLVENT: - return parts + "_" + timestamp_digits + if type == FileType.FIGURE: + num = 0 + fig_id = "fig" + f"{num}" + "_" + timestamp_digits + while fig_id in current_ids: + num += 1 + fig_id = "fig" + f"{num}" + "_" + timestamp_digits + return fig_id def write_file_name(self, type: FileType, **kwargs): + # PR: I know this looks messy, it is, im adding as things keep coming :c time_stamp = self.get_timestamp() protein_name = kwargs.get("protein_name", None) description = kwargs.get("description", "No description provided") @@ -251,8 +257,10 @@ def write_file_name(self, type: FileType, **kwargs): type_of_sim = kwargs.get("type_of_sim", None) conditions = kwargs.get("conditions", None) Sim_id = kwargs.get("Sim_id", None) + Log_id = kwargs.get("Log_id", None) modified = kwargs.get("modified", False) term = kwargs.get("term", "term") # Default term if not provided + fig_analysis = kwargs.get("fig_analysis", None) file_name = "" if type == FileType.PROTEIN: file_name += f"{protein_name}_{description}_{time_stamp}.{file_format}" @@ -272,6 +280,25 @@ def write_file_name(self, type: FileType, **kwargs): file_name = ( f"{record_type_name}_{Sim_id}_{protein_file_id}_" f"{time_stamp}.{term}" ) + if type == FileType.FIGURE: + if fig_analysis: + if Sim_id: + file_name += ( + f"FIG_{fig_analysis}_{Sim_id}_{time_stamp}.{file_format}" + ) + elif Log_id: + file_name += ( + f"FIG_{fig_analysis}_{Log_id}_{time_stamp}.{file_format}" + ) + else: + file_name += f"FIG_{fig_analysis}_{time_stamp}.{file_format}" + else: + if Sim_id: + file_name += f"FIG_{Sim_id}_{time_stamp}.{file_format}" + elif Log_id: + file_name += f"FIG_{Log_id}_{time_stamp}.{file_format}" + else: + file_name += f"FIG_{time_stamp}.{file_format}" if file_name == "": file_name += "ErrorDuringNaming_error.py" diff --git a/tests/test_analysis_tools.py b/tests/test_analysis_tools.py index 1ce3ab8d..48750af8 100644 --- a/tests/test_analysis_tools.py +++ b/tests/test_analysis_tools.py @@ -86,8 +86,8 @@ def test_plot_data(plotting_tools): plotting_tools.headers = headers plotting_tools.matched_headers = matched_headers created_plots = plotting_tools.plot_data() - assert "time_vs_value1.png" in created_plots - assert "time_vs_value2.png" in created_plots + assert "FIG_timevsvalue1" in created_plots + assert "FIG_timevsvalue2" in created_plots # Test failure due to non-numeric data data_failure = [ diff --git a/tests/test_util_tools.py b/tests/test_util_tools.py index 3b95eaea..b61077ca 100644 --- a/tests/test_util_tools.py +++ b/tests/test_util_tools.py @@ -91,6 +91,45 @@ def test_write_file_name_record(path_registry, todays_date): assert file_name.endswith(".dcd") +def test_write_file_name_figure_1(path_registry, todays_date): + file_name = path_registry.write_file_name( + FileType.FIGURE, + Sim_id="SIM456", + time_stamp=todays_date, + file_format="png", + irrelevant="irrelevant", + ) + assert "FIG_SIM456_" in file_name + assert todays_date in file_name + assert file_name.endswith(".png") + + +def test_write_file_name_figure_2(path_registry, todays_date): + file_name = path_registry.write_file_name( + FileType.FIGURE, + Log_id="LOG_123456", + time_stamp=todays_date, + file_format="jpg", + irrelevant="irrelevant", + ) + assert "FIG_LOG_123456_" in file_name + assert todays_date in file_name + assert file_name.endswith(".jpg") + + +def test_write_file_name_figure_3(path_registry, todays_date): + file_name = path_registry.write_file_name( + FileType.FIGURE, + Log_id="LOG_123456", + fig_analysis="randomanalytic", + file_format="jpg", + irrelevant="irrelevant", + ) + assert "FIG_randomanalytic_LOG_123456_" in file_name + assert todays_date in file_name + assert file_name.endswith(".jpg") + + def test_map_path(path_registry): mock_json_data = { "existing_name": {