From 9228878e0fd5fcd88ed52e3fa8a0e52c90d0cfe1 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Tue, 27 Feb 2024 09:04:41 -0800 Subject: [PATCH] Registry (#94) --- mdagent/tools/base_tools/__init__.py | 17 +- .../base_tools/analysis_tools/__init__.py | 3 +- .../base_tools/analysis_tools/plot_tools.py | 157 +- .../base_tools/analysis_tools/ppi_tools.py | 23 +- .../base_tools/analysis_tools/rmsd_tools.py | 60 +- .../base_tools/analysis_tools/vis_tools.py | 106 +- .../base_tools/preprocess_tools/__init__.py | 3 +- .../preprocess_tools/clean_tools.py | 107 +- .../base_tools/preprocess_tools/packing.py | 466 ++++++ .../base_tools/preprocess_tools/pdb_fix.py | 764 +++++++++ .../base_tools/preprocess_tools/pdb_get.py | 222 +++ .../base_tools/preprocess_tools/pdb_tools.py | 1486 ----------------- .../base_tools/simulation_tools/__init__.py | 8 +- .../simulation_tools/create_simulation.py | 85 +- .../simulation_tools/setup_and_run.py | 132 +- .../base_tools/util_tools/git_issues_tool.py | 19 +- mdagent/tools/maketools.py | 13 +- tests/test_fxns.py | 126 +- 18 files changed, 1847 insertions(+), 1950 deletions(-) create mode 100644 mdagent/tools/base_tools/preprocess_tools/packing.py create mode 100644 mdagent/tools/base_tools/preprocess_tools/pdb_fix.py create mode 100644 mdagent/tools/base_tools/preprocess_tools/pdb_get.py delete mode 100644 mdagent/tools/base_tools/preprocess_tools/pdb_tools.py diff --git a/mdagent/tools/base_tools/__init__.py b/mdagent/tools/base_tools/__init__.py index 404fd3ca..1a333e72 100644 --- a/mdagent/tools/base_tools/__init__.py +++ b/mdagent/tools/base_tools/__init__.py @@ -1,11 +1,7 @@ from .analysis_tools.plot_tools import SimulationOutputFigures from .analysis_tools.ppi_tools import PPIDistance from .analysis_tools.rmsd_tools import RMSDCalculator -from .analysis_tools.vis_tools import ( - CheckDirectoryFiles, - VisFunctions, - VisualizeProtein, -) +from .analysis_tools.vis_tools import VisFunctions, VisualizeProtein from .preprocess_tools.clean_tools import ( AddHydrogensCleaningTool, CleaningToolFunction, @@ -13,15 +9,10 @@ RemoveWaterCleaningTool, SpecializedCleanTool, ) -from .preprocess_tools.pdb_tools import ( - PackMolTool, - ProteinName2PDBTool, - SmallMolPDB, - get_pdb, -) +from .preprocess_tools.packing import PackMolTool +from .preprocess_tools.pdb_get import ProteinName2PDBTool, SmallMolPDB, get_pdb from .simulation_tools.create_simulation import ModifyBaseSimulationScriptTool from .simulation_tools.setup_and_run import ( - InstructionSummary, SetUpandRunFunction, SetUpAndRunTool, SimulationFunctions, @@ -32,9 +23,7 @@ __all__ = [ "AddHydrogensCleaningTool", - "CheckDirectoryFiles", "CleaningTools", - "InstructionSummary", "ListRegistryPaths", "MapPath2Name", "ProteinName2PDBTool", diff --git a/mdagent/tools/base_tools/analysis_tools/__init__.py b/mdagent/tools/base_tools/analysis_tools/__init__.py index 2243f0d2..7cb79fbd 100644 --- a/mdagent/tools/base_tools/analysis_tools/__init__.py +++ b/mdagent/tools/base_tools/analysis_tools/__init__.py @@ -1,13 +1,12 @@ from .plot_tools import SimulationOutputFigures from .ppi_tools import PPIDistance from .rmsd_tools import RMSDCalculator -from .vis_tools import CheckDirectoryFiles, VisFunctions, VisualizeProtein +from .vis_tools import VisFunctions, VisualizeProtein __all__ = [ "PPIDistance", "RMSDCalculator", "SimulationOutputFigures", - "CheckDirectoryFiles", "VisualizeProtein", "VisFunctions", ] diff --git a/mdagent/tools/base_tools/analysis_tools/plot_tools.py b/mdagent/tools/base_tools/analysis_tools/plot_tools.py index bf004fc0..d479fe07 100644 --- a/mdagent/tools/base_tools/analysis_tools/plot_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/plot_tools.py @@ -8,60 +8,88 @@ from mdagent.utils import PathRegistry -def process_csv(file_name): - with open(file_name, "r") as f: - reader = csv.DictReader(f) - headers = reader.fieldnames - data = list(reader) - - matched_headers = [ - (i, header) - for i, header in enumerate(headers) - if re.search(r"(step|time)", header, re.IGNORECASE) - ] - - return data, headers, matched_headers - - -def plot_data(data, headers, matched_headers): - # Get the first matched header - if matched_headers: - time_or_step = matched_headers[0][1] - xlab = "step" if "step" in time_or_step.lower() else "time" - else: - print("No 'step' or 'time' headers found.") - return - - failed_headers = [] - created_plots = [] - for header in headers: - if header != time_or_step: - try: - x = [float(row[time_or_step]) for row in data] - y = [float(row[header]) for row in data] - - header_lab = ( - header.split("(")[0].strip() if "(" in header else header - ).lower() - plot_name = f"{xlab}_vs_{header_lab}.png" - - # Generate and save the plot - plt.figure() - plt.plot(x, y) - plt.xlabel(xlab) - plt.ylabel(header) - plt.title(f"{xlab} vs {header_lab}") - plt.savefig(plot_name) - plt.close() - - created_plots.append(plot_name) - except ValueError: - failed_headers.append(header) - - if len(failed_headers) == len(headers) - 1: # -1 to account for time_or_step header - raise Exception("All plots failed due to non-numeric data.") - - return ", ".join(created_plots) +class PlottingTools: + def __init__( + self, + path_registry, + ): + self.path_registry = path_registry + self.data = None + self.headers = None + self.matched_headers = None + self.file_id = None + self.file_path = None + + def _find_file(self, file_id: str) -> None: + self.file_id = file_id + self.file_path = self.path_registry.get_mapped_path(file_id) + if not self.file_path: + raise FileNotFoundError("File not found.") + return None + + def process_csv(self) -> None: + with open(self.file_path, "r") as f: + reader = csv.DictReader(f) + self.headers = reader.fieldnames if reader.fieldnames is not None else [] + self.data = list(reader) + + self.matched_headers = [ + (i, header) + for i, header in enumerate(self.headers) + if re.search(r"(step|time)", header, re.IGNORECASE) + ] + + if not self.matched_headers or not self.headers or not self.data: + raise ValueError("File could not be processed.") + return None + + def plot_data(self) -> str: + if self.matched_headers: + time_or_step = self.matched_headers[0][1] + xlab = "step" if "step" in time_or_step.lower() else "time" + else: + raise ValueError("No timestep found.") + + failed_headers = [] + created_plots = [] + for header in self.headers: + if header != time_or_step: + try: + x = [float(row[time_or_step]) for row in self.data] + y = [float(row[header]) for row in self.data] + + header_lab = ( + header.split("(")[0].strip() if "(" in header else header + ).lower() + plot_name = f"{self.file_id}_{xlab}_vs_{header_lab}.png" + + # Generate and save the plot + plt.figure() + plt.plot(x, y) + plt.xlabel(xlab) + plt.ylabel(header) + plt.title(f"{self.file_id}_{xlab} vs {header_lab}") + plt.savefig(plot_name) + self.path_registry.map_path( + plot_name, + plot_name, + ( + f"Post Simulation Figure for {self.file_id}" + f" - {header_lab} vs {xlab}" + ), + ) + plt.close() + + created_plots.append(plot_name) + except ValueError: + failed_headers.append(header) + + if ( + len(failed_headers) == len(self.headers) - 1 + ): # -1 to account for time_or_step header + raise Exception("All plots failed due to non-numeric data.") + + return ", ".join(created_plots) class SimulationOutputFigures(BaseTool): @@ -71,24 +99,27 @@ class SimulationOutputFigures(BaseTool): simulation and create figures for all physical parameters versus timestep of the simulation. - Give this tool the path to the - csv file output from the simulation.""" + Give this tool the name of the + csv file output from the simulation. + The tool will get the exact path.""" path_registry: Optional[PathRegistry] - def _run(self, file_path: str) -> str: + def __init__(self, path_registry: Optional[PathRegistry] = None): + super().__init__() + self.path_registry = path_registry + + def _run(self, file_id: str) -> str: """use the tool.""" try: - data, headers, matched_headers = process_csv(file_path) - plot_result = plot_data(data, headers, matched_headers) + plotting_tools = PlottingTools(self.path_registry) + plotting_tools._find_file(file_id) + plotting_tools.process_csv() + plot_result = plotting_tools.plot_data() if type(plot_result) == str: return "Figures created: " + plot_result else: return "No figures created." - except ValueError: - return "No timestep data found in csv file." - except FileNotFoundError: - return "Issue with CSV file, file not found." except Exception as e: return str(e) diff --git a/mdagent/tools/base_tools/analysis_tools/ppi_tools.py b/mdagent/tools/base_tools/analysis_tools/ppi_tools.py index 3fa7146c..c267b505 100644 --- a/mdagent/tools/base_tools/analysis_tools/ppi_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/ppi_tools.py @@ -6,8 +6,10 @@ from langchain.tools import BaseTool from pydantic import BaseModel, Field +from mdagent.utils import PathRegistry -def ppi_distance(pdb_file, binding_site="protein"): + +def ppi_distance(file_path, binding_site="protein"): """ Calculates minimum heavy-atom distance between peptide (assumed to be smallest chain) and protein. Returns average distance between these two. @@ -16,7 +18,7 @@ def ppi_distance(pdb_file, binding_site="protein"): Can work with any protein-protein interaction (PPI) """ # load and find smallest chain - u = mda.Universe(pdb_file) + u = mda.Universe(file_path) peptide = None for chain in u.segments: if peptide is None or len(chain.residues) < len(peptide): @@ -49,14 +51,25 @@ class PPIDistance(BaseTool): name: str = "ppi_distance" description: str = """Useful for calculating minimum heavy-atom distance between peptide and protein. First, make sure you have valid PDB file with - any protein-protein interaction.""" + any protein-protein interaction. Give this tool the name of the file. The + tool will find the path.""" args_schema: Type[BaseModel] = PPIDistanceInputSchema + path_registry: Optional[PathRegistry] + + def __init__(self, path_registry: Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry def _run(self, pdb_file: str, binding_site: str = "protein"): - if not pdb_file.endswith(".pdb"): + if not self.path_registry: + return "Error: Path registry is not set" # this should not happen + file_path = self.path_registry.get_mapped_path(pdb_file) + if not file_path: + return f"File not found: {pdb_file}" + if not file_path.endswith(".pdb"): return "Error with input: PDB file must have .pdb extension" try: - avg_dist = ppi_distance(pdb_file, binding_site=binding_site) + avg_dist = ppi_distance(file_path, binding_site=binding_site) except ValueError as e: return ( f"ValueError: {e}. \nMake sure to provide valid PBD " diff --git a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py index 684d5f37..a6e5a1b3 100644 --- a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py @@ -9,6 +9,8 @@ from MDAnalysis.analysis import align, diffusionmap, rms from pydantic import BaseModel, Field +from mdagent.utils import PathRegistry + # all things related to RMSD as 'standard deviation' # 1 RMSD between two protein conformations or trajectories (1D scalar value) # 2. time-dependent RMSD of the whole trajectory with all or selected atoms @@ -17,16 +19,15 @@ class RMSDFunctions: - def __init__(self, pdb_file, trajectory, ref_file=None, ref_trajectory=None): - self.pdb_file = pdb_file - self.trajectory = trajectory - self.pdb_name = os.path.splitext(os.path.basename(pdb_file))[0] - self.ref_file = ref_file - self.ref_trajectory = ref_trajectory - if ref_file: - self.ref_name = os.path.splitext(os.path.basename(ref_file))[0] - else: - self.ref_name = None + def __init__(self, path_registry, pdb, traj, ref=None, ref_traj=None): + self.path_registry = path_registry + self.pdb_file = self.path_registry.get_mapped_path(pdb) + self.trajectory = self.path_registry.get_mapped_path(traj) + self.pdb_name = os.path.splitext(os.path.basename(self.pdb_file))[0] + self.ref_file = self.path_registry.get_mapped_path(ref) + self.ref_trajectory = self.path_registry.get_mapped_path(ref_traj) + if self.ref_file: + self.ref_name = os.path.splitext(os.path.basename(self.ref_file))[0] def calculate_rmsd( self, @@ -34,13 +35,9 @@ def calculate_rmsd( selection="backbone", plot=True, ): - i = 0 - base_filename = f"{rmsd_type}_{self.pdb_name}" - filename = base_filename - while os.path.exists(filename + ".csv"): - i += 1 - filename = f"{base_filename}_{i}" - self.filename = filename + if self.trajectory is None or self.pdb_file is None: + raise FileNotFoundError("PDB and trajectory files are required.") + self.filename = f"{rmsd_type}_{self.pdb_name}" if rmsd_type == "rmsd": if self.ref_file: @@ -110,6 +107,9 @@ def compute_rmsd(self, selection="backbone", plot=True): final_rmsd = R.results.rmsd[-1, 2] message = f"""Calculated RMSD for each timestep with respect\ to the initial frame. Saved to {self.filename}.csv. """ + self.path_registry.map_path( + f"{self.filename}.csv", f"{self.filename}.csv", message + ) message += f"Average RMSD is {avg_rmsd} \u212B. " message += f"Final RMSD is {final_rmsd} \u212B.\n" @@ -123,6 +123,9 @@ def compute_rmsd(self, selection="backbone", plot=True): plt.savefig(f"{self.filename}.png") # plt.close() # if you don't want to show the plot in notebooks message += f"Plotted RMSD over time. Saved to {self.filename}.png.\n" + self.path_registry.map_path( + f"{self.filename}.png", f"{self.filename}.png", message + ) return message def compute_2d_rmsd(self, selection="backbone", plot_heatmap=True): @@ -154,6 +157,9 @@ def compute_2d_rmsd(self, selection="backbone", plot_heatmap=True): delimiter=",", ) message = f"Saved pairwise RMSD matrix to {self.filename}.csv.\n" + self.path_registry.map_path( + f"{self.filename}.csv", f"{self.filename}.csv", message + ) if plot_heatmap: plt.imshow(pairwise_matrix, cmap="viridis") plt.xlabel(x_label) @@ -162,6 +168,9 @@ def compute_2d_rmsd(self, selection="backbone", plot_heatmap=True): plt.show() plt.savefig(f"{self.filename}.png") message += f"Plotted pairwise RMSD matrix. Saved to {self.filename}.png.\n" + self.path_registry.map_path( + f"{self.filename}.png", f"{self.filename}.png", message + ) return message def compute_rmsf(self, selection="backbone", plot=True): @@ -188,6 +197,9 @@ def compute_rmsf(self, selection="backbone", plot=True): comments="", ) message = f"Saved RMSF data to {self.filename}.csv.\n" + self.path_registry.map_path( + f"{self.filename}.csv", f"{self.filename}.csv", message + ) # Plot RMSF if plot: @@ -200,6 +212,9 @@ def compute_rmsf(self, selection="backbone", plot=True): plt.show() plt.savefig(f"{self.filename}.png") message += f"Plotted RMSF. Saved to {self.filename}.png.\n" + self.path_registry.map_path( + f"{self.filename}.png", f"{self.filename}.png", message + ) return message @@ -245,6 +260,11 @@ class RMSDCalculator(BaseTool): 3. root mean square fluctuation (RMSF) Make sure to provide any necessary files for a chosen RMSD type.""" args_schema: Type[BaseModel] = RMSDInputSchema + path_registry: Optional[PathRegistry] + + def __init__(self, path_registry: Optional[PathRegistry] = None): + super().__init__() + self.path_registry = path_registry def _run( self, @@ -257,13 +277,17 @@ def _run( plot: bool = True, ): try: - rmsd = RMSDFunctions(pdb_file, trajectory, ref_file, ref_trajectory) + rmsd = RMSDFunctions( + self.path_registry, pdb_file, trajectory, ref_file, ref_trajectory + ) message = rmsd.calculate_rmsd(rmsd_type, selection, plot) except ValueError as e: return ( f"ValueError: {e}. \nMake sure to provide valid PBD " "file and binding site using MDAnalysis selection syntax." ) + except FileNotFoundError as e: + return str(e) except Exception as e: return f"Something went wrong. {type(e).__name__}: {e}" return message diff --git a/mdagent/tools/base_tools/analysis_tools/vis_tools.py b/mdagent/tools/base_tools/analysis_tools/vis_tools.py index 3d9ec17b..1303be08 100644 --- a/mdagent/tools/base_tools/analysis_tools/vis_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/vis_tools.py @@ -9,15 +9,17 @@ class VisFunctions: - def list_files_in_directory(self, directory): - files = [ - f - for f in os.listdir(directory) - if os.path.isfile(os.path.join(directory, f)) - ] - return ", ".join(files) - - def run_molrender(self, cif_path): + def __init__(self, path_registry): + self.path_registry = path_registry + self.starting_files = os.listdir(".") + + def _find_png(self): + current_files = os.listdir(".") + new_files = [f for f in current_files if f not in self.starting_files] + png_files = [f for f in new_files if f.endswith(".png")] + return png_files + + def run_molrender(self, cif_path: str) -> str: """Function to run molrender, it requires node.js to be installed and the molrender package to be @@ -25,22 +27,40 @@ def run_molrender(self, cif_path): This will save .png files in the current directory.""" + self.cif_file_name = os.path.basename(cif_path) cmd = ["molrender", "all", cif_path, ".", "--format", "png"] - result = subprocess.run(cmd, capture_output=True, text=True) + try: + result = subprocess.run(cmd, capture_output=True, text=True) + except subprocess.CalledProcessError: + raise RuntimeError("molrender package not found") + file_name = self._find_png() + if not file_name: + raise FileNotFoundError("No .png files were created") + self.path_registry.map_path( + f"mol_render_{self.cif_file_name}", + file_name[0], + "Visualization of cif file {cif_file_name} as png file. using molrender.", + ) if result.returncode != 0: - return Exception(f"Error running molrender: {result.stderr}") + raise RuntimeError(f"Error running molrender: {result.stderr}") else: print(f"Output: {result.stdout}") + return ( + "Visualization using molrender complete, " + "saved as: mol_render_{self.cif_file_name}" + ) - def create_notebook(self, query, PathRegistry): + def create_notebook(self, cif_file: str) -> str: """This is for plan B tool, it will create a notebook with the code to install nglview and display the cif/pdb file.""" + self.cif_file_name = os.path.basename(cif_file) + # Create a new notebook nb = nbf.v4.new_notebook() @@ -49,10 +69,10 @@ def create_notebook(self, query, PathRegistry): # Code to import NGLview and display a file import_code = f""" - import nglview as nv - view = nv.show_file("{query}") - view - """ +import nglview as nv +view = nv.show_file("{cif_file}") +view +""" # Create new code cells install_cell = nbf.v4.new_code_cell(source=install_code) @@ -62,12 +82,14 @@ def create_notebook(self, query, PathRegistry): nb.cells.extend([install_cell, import_cell]) # Write the notebook to a file - with open("Visualization.ipynb", "w") as f: + notebook_name = self.cif_file_name.split(".")[0] + "_vis.ipynb" + with open(notebook_name, "w") as f: nbf.write(nb, f) - # add filename to registry - file_description = "Notebook to visualize cif/pdb files" - PathRegistry.map_path( - "visualize_notebook", "Visualization.ipynb", file_description + + self.path_registry.map_path( + notebook_name, + notebook_name, + f"Notebook to visualize cif/pdb file {self.cif_file_name} using nglview.", ) return "Visualization Complete" @@ -81,8 +103,7 @@ class VisualizeProtein(BaseTool): name = "PDBVisualization" description = """This tool will create a visualization of a cif - file as a png file in - the same directory OR + file as a png file OR it will create a .ipynb file with the visualization of the @@ -98,38 +119,23 @@ def __init__(self, path_registry: Optional[PathRegistry]): super().__init__() self.path_registry = path_registry - def _run(self, query: str) -> str: + def _run(self, cif_file_name: str) -> str: """use the tool.""" - vis = VisFunctions() + if not self.path_registry: + return "Error: Path registry is not set" # this should not happen + cif_path = self.path_registry.get_mapped_path(cif_file_name) + if not cif_path: + return f"File not found: {cif_file_name}" + vis = VisFunctions(self.path_registry) try: - vis.run_molrender(query) - return "Visualization created as png" - except Exception: + return vis.run_molrender(cif_path) + except (RuntimeError, FileNotFoundError) as e: + print(f"Error running molrender: {str(e)}. Using NGLView instead.") try: - vis.create_notebook(query, self.path_registry) + vis.create_notebook(cif_path) return "Visualization created as notebook" except Exception as e: - return f"An error occurred while running molrender: {str(e)}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class CheckDirectoryFiles(BaseTool): - name = "ListDirectoryFiles" - description = """This tool will - give you a list of comma - separated files in the - current directory.""" - - def _run(self, query: str) -> str: - """use the tool.""" - try: - vis = VisFunctions() - return vis.list_files_in_directory(".") - except Exception: - return "An error occurred while listing files in directory" + return f"An error occurred {str(e)}" async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" diff --git a/mdagent/tools/base_tools/preprocess_tools/__init__.py b/mdagent/tools/base_tools/preprocess_tools/__init__.py index b45ebce3..a2c538ed 100644 --- a/mdagent/tools/base_tools/preprocess_tools/__init__.py +++ b/mdagent/tools/base_tools/preprocess_tools/__init__.py @@ -5,7 +5,8 @@ RemoveWaterCleaningTool, SpecializedCleanTool, ) -from .pdb_tools import PackMolTool, ProteinName2PDBTool, SmallMolPDB, get_pdb +from .packing import PackMolTool +from .pdb_get import ProteinName2PDBTool, SmallMolPDB, get_pdb __all__ = [ "AddHydrogensCleaningTool", diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py index c5605012..589a4294 100644 --- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py @@ -1,5 +1,4 @@ import os -import re from typing import Dict, Optional, Type from langchain.tools import BaseTool @@ -11,31 +10,10 @@ class CleaningTools: - def _extract_path(self, user_input: str, path_registry: PathRegistry) -> str: - """Extract file path from user input.""" - - # Remove any leading or trailing white space - user_input = user_input.strip() - - # Remove single and double quotes from the user_input - user_input = user_input.replace("'", "") - user_input = user_input.replace('"', "") - - # First check the path registry - mapped_path = path_registry.get_mapped_path(user_input) - if mapped_path != "Name not found in path registry.": - return mapped_path - - # If not found in registry, check if it is a valid path - match = re.search(r"[a-zA-Z0-9_\-/\\:.]+(?:\.pdb|\.cif)", user_input) - - if match: - return match.group(0) - else: - raise ValueError("No valid file path found in user input.") + def __init__(self, path_registry): + self.path_registry = path_registry - def _standard_cleaning(self, pdbfile: str, path_registry: PathRegistry): - pdbfile = self._extract_path(pdbfile, path_registry) + def _standard_cleaning(self, pdbfile: str) -> str: name, end = os.path.splitext(os.path.basename(pdbfile)) end = end.lstrip(".") fixer = PDBFixer(filename=pdbfile) @@ -56,11 +34,10 @@ def _standard_cleaning(self, pdbfile: str, path_registry: PathRegistry): # add filename to registry short_name = f"tidy_{name}" file_description = "Cleaned File. Standard cleaning." - path_registry.map_path(short_name, tidy_filename, file_description) + self.path_registry.map_path(short_name, tidy_filename, file_description) return f"{file_description} Written to {tidy_filename}" - def _remove_water(self, pdbfile: str, path_registry: PathRegistry): - pdbfile = self._extract_path(pdbfile, path_registry) + def _remove_water(self, pdbfile: str) -> str: name, end = os.path.splitext(os.path.basename(pdbfile)) end = end.lstrip(".") fixer = PDBFixer(filename=pdbfile) @@ -75,13 +52,10 @@ def _remove_water(self, pdbfile: str, path_registry: PathRegistry): # add filename to registry short_name = f"tidy_{name}" file_description = "Cleaned File. Removed water." - path_registry.map_path(short_name, tidy_filename, file_description) + self.path_registry.map_path(short_name, tidy_filename, file_description) return f"{file_description} Written to {tidy_filename}" - def _add_hydrogens_and_remove_water( - self, pdbfile: str, path_registry: PathRegistry - ): - pdbfile = self._extract_path(pdbfile, path_registry) + def _add_hydrogens_and_remove_water(self, pdbfile: str) -> str: name, end = os.path.splitext(os.path.basename(pdbfile)) end = end.lstrip(".") fixer = PDBFixer(filename=pdbfile) @@ -96,11 +70,10 @@ def _add_hydrogens_and_remove_water( # add filename to registry short_name = f"tidy_{name}" file_description = "Cleaned File. Missing Hydrogens added and water removed." - path_registry.map_path(short_name, tidy_filename, file_description) + self.path_registry.map_path(short_name, tidy_filename, file_description) return f"{file_description} Written to {tidy_filename}" - def _add_hydrogens(self, pdbfile: str, path_registry: PathRegistry): - pdbfile = self._extract_path(pdbfile, path_registry) + def _add_hydrogens(self, pdbfile: str) -> str: name, end = os.path.splitext(os.path.basename(pdbfile)) end = end.lstrip(".") fixer = PDBFixer(filename=pdbfile) @@ -115,7 +88,7 @@ def _add_hydrogens(self, pdbfile: str, path_registry: PathRegistry): # add filename to registry short_name = f"tidy_{name}" file_description = "Cleaned File. Missing Hydrogens added." - path_registry.map_path(short_name, tidy_filename, file_description) + self.path_registry.map_path(short_name, tidy_filename, file_description) return f"{file_description} Written to {tidy_filename}" @@ -125,22 +98,25 @@ class SpecializedCleanTool(BaseTool): name = "StandardCleaningTool" description = """ This tool will perform a complete cleaning of a PDB or CIF file. - Input: PDB or CIF file. + Input: PDB or CIF file name Output: Cleaned PDB file - Youl will remove heterogens, add missing atoms and hydrogens, and add solvent.""" + You will remove heterogens, add missing atoms and hydrogens, and add solvent.""" path_registry: Optional[PathRegistry] def __init__(self, path_registry: Optional[PathRegistry]): super().__init__() self.path_registry = path_registry - def _run(self, query: str) -> str: + def _run(self, file_name: str) -> str: """use the tool.""" + if self.path_registry is None: + return "Path registry not initialized" try: - if self.path_registry is None: # this should not happen - return "Path registry not initialized" - clean_tools = CleaningTools() - return clean_tools._standard_cleaning(query, self.path_registry) + file_path = self.path_registry.get_mapped_path(file_name) + if file_path is None: + return "File not found" + clean_tools = CleaningTools(self.path_registry) + return clean_tools._standard_cleaning(file_path) except FileNotFoundError: return "Check your file path. File not found." except Exception as e: @@ -160,7 +136,7 @@ class RemoveWaterCleaningTool(BaseTool): to remove water and heterogens, and add hydrogens. This tool will remove water and add hydrogens in a pdb or cif file. - Input: PDB or CIF file. + Input: PDB or CIF file name. Output: Cleaned PDB file """ @@ -170,15 +146,16 @@ def __init__(self, path_registry: Optional[PathRegistry]): super().__init__() self.path_registry = path_registry - def _run(self, query: str) -> str: + def _run(self, file_name: str) -> str: """use the tool.""" + if self.path_registry is None: + return "Path registry not initialized" try: - if self.path_registry is None: # this should not happen - return "Path registry not initialized" - clean_tools = CleaningTools() - return clean_tools._add_hydrogens_and_remove_water( - query, self.path_registry - ) + file_path = self.path_registry.get_mapped_path(file_name) + if file_path is None: + return "File not found" + clean_tools = CleaningTools(self.path_registry) + return clean_tools._add_hydrogens_and_remove_water(file_path) except FileNotFoundError: return "Check your file path. File not found." except Exception as e: @@ -196,7 +173,7 @@ class AddHydrogensCleaningTool(BaseTool): description = """ ] This tool only adds hydrogens to a pdb or cif file. in a pdb or cif file - Input: PDB or CIF file. + Input: PDB or CIF file name. Output: Cleaned PDB file """ @@ -206,13 +183,16 @@ def __init__(self, path_registry: Optional[PathRegistry]): super().__init__() self.path_registry = path_registry - def _run(self, query: str) -> str: + def _run(self, file_name: str) -> str: """use the tool.""" + if self.path_registry is None: + return "Path registry not initialized" try: - if self.path_registry is None: # this should not happen - return "Path registry not initialized" - clean_tools = CleaningTools() - return clean_tools._add_hydrogens(query, self.path_registry) + file_path = self.path_registry.get_mapped_path(file_name) + if file_path is None: + return "File not found" + clean_tools = CleaningTools(self.path_registry) + return clean_tools._add_hydrogens(file_path) except FileNotFoundError: return "Check your file path. File not found." except Exception as e: @@ -267,17 +247,21 @@ class CleaningToolFunction(BaseTool): path_registry: Optional[PathRegistry] + def __init__(self, path_registry: Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + def _run(self, **input_args) -> str: """Use the tool with specified operations.""" + if self.path_registry is None: + return "Path registry not initialized" try: - ### No idea why the input is a dictionary with the key "input_args" - # instead of the arguments themselves if "input_args" in input_args.keys(): input_args = input_args["input_args"] else: input_args = input_args pdbfile_id = input_args.get("pdb_id", None) - # TODO check if pdbfile_id is a valid pdb_id from the registry + pdbfile_id = self.path_registry.get_mapped_path(pdbfile_id) if pdbfile_id is None: return """No file was provided. The input has to be a dictionary with the key 'pdb_id'""" @@ -294,7 +278,6 @@ def _run(self, **input_args) -> str: if self.path_registry is None: return "Path registry not initialized" file_description = "Cleaned File: " - CleaningTools() try: pdbfile_path = self.path_registry.get_mapped_path(pdbfile_id) if "/" in pdbfile_path: diff --git a/mdagent/tools/base_tools/preprocess_tools/packing.py b/mdagent/tools/base_tools/preprocess_tools/packing.py new file mode 100644 index 00000000..1df85cb1 --- /dev/null +++ b/mdagent/tools/base_tools/preprocess_tools/packing.py @@ -0,0 +1,466 @@ +import os +import re +import subprocess +import typing +from typing import Any, Dict, List, Type, Union + +from langchain.tools import BaseTool +from pydantic import BaseModel, Field, ValidationError + +from mdagent.utils import PathRegistry + +from .pdb_fix import Validate_Fix_PDB +from .pdb_get import MolPDB + + +def summarize_errors(errors): + error_summary = {} + + # Regular expression pattern to capture the error type and line number + pattern = r"\[!\] Offending field \((.+?)\) at line (\d+)" + + for error in errors: + match = re.search(pattern, error) + if match: + error_type, line_number = match.groups() + # If this error type hasn't been seen before, + # initialize it in the dictionary + if error_type not in error_summary: + error_summary[error_type] = {"lines": []} + error_summary[error_type]["lines"].append(line_number) + + # Format the summarized errors for display + summarized_strings = [] + for error_type, data in error_summary.items(): + line_count = len(data["lines"]) + if line_count > 3: + summarized_strings.append(f"{error_type}: total {line_count} lines") + else: + summarized_strings.append(f"{error_type}: lines: {','.join(data['lines'])}") + + return summarized_strings + + +class Molecule: + def __init__(self, filename, file_id, number_of_molecules=1, instructions=None): + self.filename = filename + self.id = file_id + self.number_of_molecules = number_of_molecules + self.instructions = instructions if instructions else [] + self.load() + + def load(self): + # load the molecule data (optional) + pass + + def get_number_of_atoms(self): + # return the number of atoms in this molecule + pass + + +class PackmolBox: + def __init__( + self, + path_registry, + file_number=1, + file_description="PDB file for simulation with: \n", + ): + self.path_registry = path_registry + self.molecules = [] + self.file_number = 1 + self.file_description = file_description + self.final_name = None + + def add_molecule(self, molecule: Molecule) -> None: + self.molecules.append(molecule) + self.file_description += f"""{molecule.number_of_molecules} of + {molecule.filename} as {molecule.instructions} \n""" + return None + + def generate_input_header(self) -> None: + # Generate the header of the input file in .inp format + orig_pdbs_ids = [ + f"{molecule.number_of_molecules}_{molecule.id}" + for molecule in self.molecules + ] + + _final_name = f'{"_and_".join(orig_pdbs_ids)}' + + self.file_description = ( + "Packed Structures of the following molecules:\n" + + "\n".join( + [ + f"Molecule ID: {molecule.id}, " + f"Number of Molecules: {molecule.number_of_molecules}" + for molecule in self.molecules + ] + ) + ) + while os.path.exists(f"files/pdb/{_final_name}_v{self.file_number}.pdb"): + self.file_number += 1 + + self.final_name = f"{_final_name}_v{self.file_number}.pdb" + with open("packmol.inp", "w") as out: + out.write("##Automatically generated by LangChain\n") + out.write("tolerance 2.0\n") + out.write("filetype pdb\n") + out.write( + f"output {self.final_name}\n" + ) # this is the name of the final file + out.close() + return None + + def generate_input(self) -> str: + input_data = [] + for molecule in self.molecules: + input_data.append(f"structure {molecule.filename}") + input_data.append(f" number {molecule.number_of_molecules}") + for idx, instruction in enumerate(molecule.instructions): + input_data.append(f" {molecule.instructions[idx]}") + input_data.append("end structure") + + # Convert list of input data to a single string + return "\n".join(input_data) + + def run_packmol(self): + validator = Validate_Fix_PDB() + # Use the generated input to execute Packmol + input_string = self.generate_input() + # Write the input to a file + with open("packmol.inp", "a") as f: + f.write(input_string) + # Here, run Packmol using the subprocess module or similar + cmd = "packmol < packmol.inp" + result = subprocess.run(cmd, shell=True, text=True, capture_output=True) + if result.returncode != 0: + print("Packmol failed to run with 'packmol < packmol.inp' command") + result = subprocess.run( + "./" + cmd, shell=True, text=True, capture_output=True + ) + if result.returncode != 0: + raise RuntimeError( + "Packmol failed to run with './packmol < packmol.inp' " + "command. Please check the input file and try again." + ) + + # validate final pdb + pdb_validation = validator.validate_pdb_format(f"{self.final_name}") + if pdb_validation[0] == 0: + # delete .inp files + # os.remove("packmol.inp") + for molecule in self.molecules: + os.remove(molecule.filename) + # name of packed pdb file + time_stamp = self.path_registry.get_timestamp()[-6:] + os.rename(self.final_name, f"files/pdb/{self.final_name}") + self.path_registry.map_path( + f"PACKED_{time_stamp}", + f"files/pdb/{self.final_name}", + self.file_description, + ) + # move file to files/pdb + print("successfull!") + return f"PDB file validated successfully. FileID: PACKED_{time_stamp}" + elif pdb_validation[0] == 1: + # format pdb_validation[1] list of errors + errors = summarize_errors(pdb_validation[1]) + # delete .inp files + + # os.remove("packmol.inp") + print("errors:", f"{errors}") + return "PDB file not validated, errors found {}".format(("\n").join(errors)) + + +# define function that takes in a list of +# molecules and a list of instructions and returns a pdb file + + +def packmol_wrapper( + path_registry, + pdbfiles: List, + files_id: List, + number_of_molecules: List, + instructions: List[List], +): + """Useful when you need to create a box + of different types of molecules molecules""" + + # create a box + box = PackmolBox(path_registry) + # add molecules to the box + for ( + pdbfile, + file_id, + number_of_molecules, + instructions, + ) in zip(pdbfiles, files_id, number_of_molecules, instructions): + molecule = Molecule(pdbfile, file_id, number_of_molecules, instructions) + box.add_molecule(molecule) + # generate input header + box.generate_input_header() + # generate input + # run packmol + print("Packing:", box.file_description, "\nThe file name is:", box.final_name) + return box.run_packmol() + + +"""Args schema for packmol_wrapper tool. Useful for OpenAI functions""" +##TODO + + +class PackmolInput(BaseModel): + pdbfiles_id: typing.Optional[typing.List[str]] = Field( + ..., description="List of PDB files id (path_registry) to pack into a box" + ) + small_molecules: typing.Optional[typing.List[str]] = Field( + [], + description=( + "List of small molecules to be packed in the system. " + "Examples: water, benzene, toluene, etc." + ), + ) + + number_of_molecules: typing.Optional[typing.List[int]] = Field( + ..., + description=( + "List of number of instances of each species to pack into the box. " + "One number per species (either protein or small molecule) " + ), + ) + instructions: typing.Optional[typing.List[List[str]]] = Field( + ..., + description=( + "List of instructions for each species. " + "One List per Molecule. " + "Every instruction should be one string like:\n" + "'inside box 0. 0. 0. 90. 90. 90.'" + ), + ) + + +class PackMolTool(BaseTool): + name: str = "packmol_tool" + description: str = ( + "Useful when you need to create a box " + "of different types of chemical species.\n" + "Three different examples:\n" + "pdbfiles_id: ['1a2b_123456']\n" + "small_molecules: ['water'] \n" + "number_of_molecules: [1, 1000]\n" + "instructions: [['fixed 0. 0. 0. 0. 0. 0. \n centerofmass'], " + "['inside box 0. 0. 0. 90. 90. 90.']]\n" + "will pack 1 molecule of 1a2b_123456 at the origin " + "and 1000 molecules of water. \n" + "pdbfiles_id: ['1a2b_123456']\n" + "number_of_molecules: [1]\n" + "instructions: [['fixed 0. 0. 0. 0. 0. 0.' \n center]]\n" + "This will fix the barocenter of protein 1a2b_123456 at " + "the center of the box with no rotation.\n" + "pdbfiles_id: ['1a2b_123456']\n" + "number_of_molecules: [1]\n" + "instructions: [['outside sphere 2.30 3.40 4.50 8.0]]\n" + "This will place the protein 1a2b_123456 outside a sphere " + "centered at 2.30 3.40 4.50 with radius 8.0\n" + ) + + args_schema: Type[BaseModel] = PackmolInput + + path_registry: typing.Optional[PathRegistry] + + def __init__(self, path_registry: typing.Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + + def _get_sm_pdbs(self, small_molecules): + all_files = self.path_registry.list_path_names() + for molecule in small_molecules: + # check path registry for molecule.pdb + if molecule not in all_files: + # download molecule using small_molecule_pdb from MolPDB + molpdb = MolPDB(self.path_registry) + molpdb.small_molecule_pdb(molecule) + print("Small molecules PDBs created successfully") + + def _run(self, **values) -> str: + """use the tool.""" + + if self.path_registry is None: # this should not happen + raise ValidationError("Path registry not initialized") + try: + values = self.validate_input(values) + except ValidationError as e: + return str(e) + error_msg = values.get("error", None) + if error_msg: + print("Error in Packmol inputs:", error_msg) + return f"Error in inputs: {error_msg}" + print("Starting Packmol Tool!") + pdbfile_ids = values.get("pdbfiles_id", []) + pdbfiles = [ + self.path_registry.get_mapped_path(pdbfile) for pdbfile in pdbfile_ids + ] + pdbfile_names = [pdbfile.split("/")[-1] for pdbfile in pdbfiles] + # copy them to the current directory with temp_ names + + pdbfile_names = [f"temp_{pdbfile_name}" for pdbfile_name in pdbfile_names] + number_of_molecules = values.get("number_of_molecules", []) + instructions = values.get("instructions", []) + small_molecules = values.get("small_molecules", []) + # make sure small molecules are all downloaded + self._get_sm_pdbs(small_molecules) + small_molecules_files = [ + self.path_registry.get_mapped_path(sm) for sm in small_molecules + ] + small_molecules_file_names = [ + small_molecule.split("/")[-1] for small_molecule in small_molecules_files + ] + small_molecules_file_names = [ + f"temp_{small_molecule_file_name}" + for small_molecule_file_name in small_molecules_file_names + ] + # append small molecules to pdbfiles + pdbfiles.extend(small_molecules_files) + pdbfile_names.extend(small_molecules_file_names) + pdbfile_ids.extend(small_molecules) + + for pdbfile, pdbfile_name in zip(pdbfiles, pdbfile_names): + os.system(f"cp {pdbfile} {pdbfile_name}") + # check if packmol is installed + cmd = "command -v packmol" + result = subprocess.run(cmd, shell=True, text=True, capture_output=True) + if result.returncode != 0: + result = subprocess.run( + "./" + cmd, shell=True, text=True, capture_output=True + ) + if result.returncode != 0: + return ( + "Packmol is not installed. Please install" + "packmol at " + "'https://m3g.github.io/packmol/download.shtml'" + "and try again." + ) + try: + return packmol_wrapper( + self.path_registry, + pdbfiles=pdbfile_names, + files_id=pdbfile_ids, + number_of_molecules=number_of_molecules, + instructions=instructions, + ) + except RuntimeError as e: + return f"Packmol failed to run with error: {e}" + + def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: + # check if is only a string + if isinstance(values, str): + print("values is a string", values) + raise ValidationError("Input must be a dictionary") + pdbfiles = values.get("pdbfiles_id", []) + small_molecules = values.get("small_molecules", []) + number_of_molecules = values.get("number_of_molecules", []) + instructions = values.get("instructions", []) + number_of_species = len(pdbfiles) + len(small_molecules) + + if not number_of_species == len(number_of_molecules): + if not number_of_species == len(instructions): + return { + "error": ( + "The length of number_of_molecules AND instructions " + "must be equal to the number of species in the system. " + f"You have {number_of_species} " + f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " + "small molecules" + ) + } + return { + "error": ( + "The length of number_of_molecules must be equal to the " + f"number of species in the system. You have {number_of_species} " + f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " + "small molecules" + ) + } + elif not number_of_species == len(instructions): + return { + "error": ( + "The length of instructions must be equal to the " + f"number of species in the system. You have {number_of_species} " + f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " + "small molecules" + ) + } + registry = PathRegistry.get_instance() + molPDB = MolPDB(registry) + for instruction in instructions: + if len(instruction) != 1: + return { + "error": ( + "Each instruction must be a single string. " + "If necessary, use newlines in a instruction string." + ) + } + # TODO enhance this validation with more packmol instructions + first_word = instruction[0].split(" ")[0] + if first_word == "center": + if len(instruction[0].split(" ")) == 1: + return { + "error": ( + "The instruction 'center' must be accompanied by more " + "instructions. Example 'fixed 0. 0. 0. 0. 0. 0.' " + "The complete instruction would be: 'center \n fixed 0. 0. " + "0. 0. 0. 0.' with a newline separating the two " + "instructions." + ) + } + elif first_word not in [ + "inside", + "outside", + "fixed", + ]: + return { + "error": ( + "The first word of each instruction must be one of " + "'inside' or 'outside' or 'fixed' \n" + "examples: center \n fixed 0. 0. 0. 0. 0. 0.,\n" + "inside box -10. 0. 0. 10. 10. 10. \n" + ) + } + + # Further validation, e.g., checking if files exist + file_ids = registry.list_path_names() + + for pdbfile_id in pdbfiles: + if "_" not in pdbfile_id: + return { + "error": ( + f"{pdbfile_id} is not a valid pdbfile_id in the path_registry" + ) + } + if pdbfile_id not in file_ids: + # look for files in the current directory + # that match some part of the pdbfile + ids_w_description = registry.list_path_names_and_descriptions() + + return { + "error": ( + f"PDB file ID {pdbfile_id} does not exist " + "in the path registry.\n" + f"This are the files IDs: {ids_w_description} " + ) + } + for small_molecule in small_molecules: + if small_molecule not in file_ids: + result = molPDB.small_molecule_pdb(small_molecule) + if "successfully" not in result: + return { + "error": ( + f"{small_molecule} could not be converted to a pdb " + "file. Try with a different name, or with the SMILES " + "of the small molecule" + ) + } + return values + + async def _arun(self, values: str) -> str: + """Use the tool asynchronously.""" + raise NotImplementedError("custom_search does not support async") diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py new file mode 100644 index 00000000..4cef4ef0 --- /dev/null +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py @@ -0,0 +1,764 @@ +import os +import re +import sys +import typing +from typing import Any, Dict, Optional, Type, Union + +from langchain.tools import BaseTool +from pdbfixer import PDBFixer +from pydantic import BaseModel, Field, ValidationError, root_validator + +from mdagent.utils import PathRegistry + +from .elements import list_of_elements + + +class PDBsummarizerfxns: + def __init__(self): + self.list_of_elements = list_of_elements + + def _record_inf(self, pdbfile): + with open(pdbfile, "r") as f: + lines = f.readlines() + remarks = [ + record_lines + for record_lines in lines + if record_lines.startswith("REMARK") + ] + atoms = [ + record_lines + for record_lines in lines + if record_lines.startswith("ATOM") + ] + box = [ + record_lines + for record_lines in lines + if record_lines.startswith("CRYST1") + ] + HETATM = [ + record_lines + for record_lines in lines + if record_lines.startswith("HETATM") + ] + + return remarks, atoms, box, HETATM + + def _num_of_dif_residues(self, pdbfile): + remarks, atoms, box, HETATM = self._record_inf(pdbfile) + residues = [atom[17:20] for atom in atoms] + residues = list(set(residues)) + return len(residues) + + # diagnosis + """Checks for the elements names in the pdb file. + Positions 76-78 of the ATOM and HETATM records""" + + def _atoms_have_elements(self, pdbfile): + _, atoms, _, _ = self._record_inf(pdbfile) + print(atoms) + elements = [atom[76:78] for atom in atoms if atom not in [" ", "", " ", " "]] + print(elements) + if len(elements) != len(atoms): + print( + ( + "No elements in the ATOM records there are" + "{len(elements)} elements and {len(atoms)}" + "atoms records" + ) + ) + return False + elements = list(set(elements)) + for element in elements: + if element not in self.list_of_elements: + print("Element not in the list of elements") + return False + return True + + def _atoms_have_tempFactor(self, pdbfile): + _, atoms, _, _ = self._record_inf(pdbfile) + tempFactor = [ + atom[60:66] + for atom in atoms + if atom[60:66] not in [" ", "", " ", " ", " ", " "] + ] + if len(tempFactor) != len(atoms): + return False + return True + + def _atoms_have_occupancy(self, pdbfile): + _, atoms, _, _ = self._record_inf(pdbfile) + occupancy = [ + atom[54:60] + for atom in atoms + if atom[54:60] not in [" ", "", " ", " ", " ", " "] + ] + if len(occupancy) != len(atoms): + return False + return True + + def _hetatom_have_occupancy(self, pdbfile): + _, _, _, HETATM = self._record_inf(pdbfile) + occupancy = [ + atom[54:60] + for atom in HETATM + if atom[54:60] not in [" ", "", " ", " ", " ", " "] + ] + if len(occupancy) != len(HETATM): + return False + return True + + def _hetatm_have_elements(self, pdbfile): + _, _, _, HETATM = self._record_inf(pdbfile) + elements = [ + atom[76:78] for atom in HETATM if atom[76:78] not in [" ", "", " ", " "] + ] + if len(elements) != len(HETATM): + print("No elements in the HETATM records") + return False + return True + + def _hetatm_have_tempFactor(self, pdbfile): + _, _, _, HETATM = self._record_inf(pdbfile) + tempFactor = [ + atom[60:66] for atom in HETATM if atom not in [" ", "", " ", " "] + ] + if len(tempFactor) != len(HETATM): + return False + return True + + """Checks for the residue names in the pdb file. + Positions 17-20 of the ATOM and HETATM records""" + + def _atoms_hetatm_have_residue_names(self, pdbfile): + _, atoms, _, HETATM = self._record_inf(pdbfile) + residues = [atom[17:20] for atom in atoms] + residues = list(set(residues)) + if len(residues) != len(atoms): + return False + residues = [atom[17:20] for atom in HETATM] + residues = list(set(residues)) + if len(residues) != len(HETATM): + return False + return True + + def _atoms_hetatm_have_occupancy(self, pdbfile): + _, atoms, _, HETATM = self._record_inf(pdbfile) + occupancy = [ + atom[54:60] + for atom in atoms + if atom not in [" ", "", " ", " ", " ", " "] + ] + if len(occupancy) != len(atoms): + return False + occupancy = [ + HET[54:60] + for HET in HETATM + if HET not in [" ", "", " ", " ", " ", " "] + ] + if len(occupancy) != len(HETATM): + return False + return True + + def _non_standard_residues(self, pdbfile): + fixer = PDBFixer(file_name=pdbfile) + fixer.findNonstandardResidues() + len(fixer.nonstandardResidues) + + def pdb_summarizer(self, pdb_file): + self.remarks, self.atoms, self.box, self.HETATM = self._record_inf(pdb_file) + self.atoms_elems = self._atoms_have_elements(pdb_file) + self.HETATM_elems = self._hetatm_have_elements(pdb_file) + self.residues = self._atoms_hetatm_have_residue_names(pdb_file) + self.atoms_tempFact = self._atoms_have_tempFactor(pdb_file) + self.num_of_residues = self._num_of_dif_residues(pdb_file) + self.HETATM_tempFact = self._hetatm_have_tempFactor(pdb_file) + + output = ( + f"PDB file: {pdb_file} has the following properties:" + "Number of residues: {pdb.num_of_residues}" + "Are elements identifiers present: {pdb.atoms}" + "Are HETATM elements identifiers present: {pdb.HETATM}" + "Are residue names present: {pdb.residues}" + "Are box dimensions present: {pdb.box}" + "Non-standard residues: {pdb.HETATM}" + ) + return output + + +class Validate_Fix_PDB: + def validate_pdb_format(self, fhandle): + """ + Compare each ATOM/HETATM line with the format defined on the + official PDB website. + + Parameters + ---------- + fhandle : a line-by-line iterator of the original PDB file. + + Returns + ------- + (int, list) + - 1 if error was found, 0 if no errors were found. + - List of error messages encountered. + """ + # check if filename is in directory + if not os.path.exists(fhandle): + return (1, ["File not found. Packmol failed to write the file."]) + errors = [] + _fmt_check = ( + ("Atm. Num.", (slice(6, 11), re.compile(r"[\d\s]+"))), + ("Alt. Loc.", (slice(11, 12), re.compile(r"\s"))), + ("Atm. Nam.", (slice(12, 16), re.compile(r"\s*[A-Z0-9]+\s*"))), + ("Spacer #1", (slice(16, 17), re.compile(r"[A-Z0-9 ]{1}"))), + ("Res. Nam.", (slice(17, 20), re.compile(r"\s*[A-Z0-9]+\s*"))), + ("Spacer #2", (slice(20, 21), re.compile(r"\s"))), + ("Chain Id.", (slice(21, 22), re.compile(r"[A-Za-z0-9 ]{1}"))), + ("Res. Num.", (slice(22, 26), re.compile(r"\s*[\d\-]+\s*"))), + ("Ins. Code", (slice(26, 27), re.compile(r"[A-Z0-9 ]{1}"))), + ("Spacer #3", (slice(27, 30), re.compile(r"\s+"))), + ("Coordn. X", (slice(30, 38), re.compile(r"\s*[\d\.\-]+\s*"))), + ("Coordn. Y", (slice(38, 46), re.compile(r"\s*[\d\.\-]+\s*"))), + ("Coordn. Z", (slice(46, 54), re.compile(r"\s*[\d\.\-]+\s*"))), + ("Occupancy", (slice(54, 60), re.compile(r"\s*[\d\.\-]+\s*"))), + ("Tmp. Fac.", (slice(60, 66), re.compile(r"\s*[\d\.\-]+\s*"))), + ("Spacer #4", (slice(66, 72), re.compile(r"\s+"))), + ("Segm. Id.", (slice(72, 76), re.compile(r"[\sA-Z0-9\-\+]+"))), + ("At. Elemt", (slice(76, 78), re.compile(r"[\sA-Z0-9\-\+]+"))), + ("At. Charg", (slice(78, 80), re.compile(r"[\sA-Z0-9\-\+]+"))), + ) + + def _make_pointer(column): + col_bg, col_en = column.start, column.stop + pt = ["^" if c in range(col_bg, col_en) else " " for c in range(80)] + return "".join(pt) + + for iline, line in enumerate(fhandle, start=1): + line = line.rstrip("\n").rstrip("\r") # CR/LF + if not line: + continue + + if line[0:6] in ["ATOM ", "HETATM"]: + # ... [rest of the code unchanged here] + linelen = len(line) + if linelen < 80: + emsg = "[!] Line {0} is short: {1} < 80\n" + sys.stdout.write(emsg.format(iline, linelen)) + + elif linelen > 80: + emsg = "[!] Line {0} is long: {1} > 80\n" + sys.stdout.write(emsg.format(iline, linelen)) + + for fname, (fcol, fcheck) in _fmt_check: + field = line[fcol] + if not fcheck.match(field): + pointer = _make_pointer(fcol) + emsg = "[!] Offending field ({0}) at line {1}\n".format( + fname, iline + ) + emsg += repr(line) + "\n" + emsg += pointer + "\n" + errors.append(emsg) + + else: + # ... [rest of the code unchanged here] + linelen = len(line) + # ... [rest of the code unchanged here] + linelen = len(line) + skip_keywords = ( + "END", + "ENDMDL", + "HEADER", + "TITLE", + "REMARK", + "CRYST1", + "MODEL", + ) + + if any(keyword in line for keyword in skip_keywords): + continue + + if linelen < 80: + emsg = "[!] Line {0} is short: {1} < 80\n" + sys.stdout.write(emsg.format(iline, linelen)) + elif linelen > 80: + emsg = "[!] Line {0} is long: {1} > 80\n" + sys.stdout.write(emsg.format(iline, linelen)) + + """ + map paths to files in path_registry before you return the string + same for all other functions you want to save files for next tools + Don't forget to import PathRegistry and add path_registry + or PathRegistry as an argument + """ + if errors: + msg = "\nTo understand your errors, read the format specification:\n" + msg += "http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM\n" + errors.append(msg) + return (1, errors) + else: + return (0, ["It *seems* everything is OK."]) + + def _fix_element_column(self, pdb_file, custom_element_dict=None): + records = ("ATOM", "HETATM", "ANISOU") + corrected_lines = [] + for line in pdb_file: + if line.startswith(records): + atom_name = line[12:16] + + if atom_name[0].isalpha() and not atom_name[2:].isdigit(): + element = atom_name.strip() + else: + atom_name = atom_name.strip() + if atom_name[0].isdigit(): + element = atom_name[1] + else: + element = atom_name[0] + + if element not in set(list_of_elements): + element = " " # empty element in case we cannot assign + + line = line[:76] + element.rjust(2) + line[78:] + corrected_lines.append(line) + + return corrected_lines + + def fix_element_column(self, pdb_file, custom_element_dict=None): + """Fixes the Element columns of a pdb file""" + + # extract Title, Header, Remarks, and Cryst1 records + file_name = pdb_file.split(".")[0] + # check if theres a file-name-fixed.pdb file + if os.path.isfile(file_name + "-fixed.pdb"): + pdb_file = file_name + "-fixed.pdb" + assert isinstance(pdb_file, str), "pdb_file must be a string" + with open(pdb_file, "r") as f: + print("I read the initial file") + pdb_file_lines = f.readlines() + # validate if pdbfile has element records + pdb = PDBsummarizerfxns() + atoms_have_elems, HETATM_have_elems = pdb._atoms_have_elements( + pdb_file + ), pdb._hetatm_have_elements(pdb_file) + if atoms_have_elems and HETATM_have_elems: + f.close() + return ( + "Element's column already filled with" + "elements, no fix needed for elements" + ) + print("I closed the initial file") + f.close() + + # fix element column + records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") + final_records = ("CONECT", "MASTER", "END") + _unchanged_records = [] + _unchanged_final_records = [] + print("pdb_file", pdb_file) + for line in pdb_file_lines: + if line.startswith(records): + _unchanged_records.append(line) + elif line.startswith(final_records): + _unchanged_final_records.append(line) + print("_unchanged_records", _unchanged_records) + new_pdb = self._fix_element_column(pdb_file_lines, custom_element_dict) + # join the linees + new_pdb = "".join(new_pdb) + # write new pdb file as pdb_file-fixed.pdb + new_pdb_file = file_name.split(".")[0] + "-fixed.pdb" + print("name of fixed pdb file", new_pdb_file) + # write the unchanged records first and then the new pdb file + assert isinstance(new_pdb_file, str), "new_pdb_file must be a string" + with open(new_pdb_file, "w") as f: + print("I wrote the new file") + f.writelines(_unchanged_records) + f.write(new_pdb) + f.writelines(_unchanged_final_records) + f.close() + try: + # read the new pdb file and check if it has element records + with open(new_pdb_file, "r") as f: + pdb_file_lines = f.readlines() + pdb = PDBsummarizerfxns() + atoms_have_elems, HETATM_have_elems = pdb._atoms_have_elements( + new_pdb_file + ), pdb._hetatm_have_elements(new_pdb_file) + if atoms_have_elems and HETATM_have_elems: + f.close() + return "Element's column fixed successfully" + else: + f.close() + return "Element's column not fixed, and i dont know why" + except Exception as e: + return f"Element's column not fixed error: {e}" + + def pad_line(self, line): + """Pad line to 80 characters in case it is shorter.""" + size_of_line = len(line) + if size_of_line < 80: + padding = 80 - size_of_line + 1 + line = line.strip("\n") + " " * padding + "\n" + return line[:81] # 80 + newline character + + def _fix_temp_factor_column(self, pdbfile, bfactor, only_fill): + """Set the temperature column in all ATOM/HETATM records to a given value. + + This function is a generator. + + Parameters + ---------- + fhandle : a line-by-line iterator of the original PDB file. + + bfactor : float + The desired bfactor. + + Yields + ------ + str (line-by-line) + The modified (or not) PDB line.""" + _pad_line = self.pad_line + records = ("ATOM", "HETATM") + corrected_lines = [] + bfactor = "{0:>6.2f}".format(bfactor) + + for line in pdbfile: + if line.startswith(records): + line = _pad_line(line) + if only_fill: + if line[60:66].strip() == "": + corrected_lines.append(line[:60] + bfactor + line[66:]) + else: + corrected_lines.append(line[:60] + bfactor + line[66:]) + else: + corrected_lines.append(line) + + return corrected_lines + + def fix_temp_factor_column(self, pdb_file, bfactor=1.00, only_fill=True): + """Fixes the tempFactor columns of a pdb file""" + + # extract Title, Header, Remarks, and Cryst1 records + # get name from pdb_file + if isinstance(pdb_file, str): + file_name = pdb_file.split(".")[0] + else: + return "pdb_file must be a string" + file_name = pdb_file.split(".")[0] + + if os.path.isfile(file_name + "-fixed.pdb"): + file_name = file_name + "-fixed.pdb" + + assert isinstance(file_name, str), "pdb_file must be a string" + with open(file_name, "r") as f: + print("im reading the files temp factor") + pdb_file_lines = f.readlines() + # validate if pdbfile has temp factors + pdb = PDBsummarizerfxns() + atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_tempFactor( + pdb_file + ), pdb._hetatm_have_tempFactor(pdb_file) + if atoms_have_bfactor and HETATM_have_bfactor and only_fill: + # print("Im closing the file temp factor") + f.close() + return ( + "TempFact column filled with bfactor already," + "no fix needed for temp factor" + ) + f.close() + # fix element column + records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") + final_records = ("CONECT", "MASTER", "END") + _unchanged_final_records = [] + _unchanged_records = [] + for line in pdb_file_lines: + if line.startswith(records): + _unchanged_records.append(line) + elif line.startswith(final_records): + _unchanged_final_records.append(line) + + new_pdb = self._fix_temp_factor_column(pdb_file_lines, bfactor, only_fill) + # join the linees + new_pdb = "".join(new_pdb) + # write new pdb file as pdb_file-fixed.pdb + new_pdb_file = file_name + "-fixed.pdb" + # organize columns: + # HEADER, TITLE, REMARKS, CRYST1, ATOM, HETATM, CONECT, MASTER, END + + assert isinstance(new_pdb_file, str), "new_pdb_file must be a string" + # write new pdb file as pdb_file-fixed.pdb + with open(new_pdb_file, "w") as f: + f.writelines(_unchanged_records) + f.write(new_pdb) + f.writelines(_unchanged_final_records) + f.close() + try: + # read the new pdb file and check if it has element records + with open(new_pdb_file, "r") as f: + pdb_file = f.readlines() + pdb = PDBsummarizerfxns() + atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_tempFactor( + new_pdb_file + ), pdb._hetatm_have_tempFactor(new_pdb_file) + if atoms_have_bfactor and HETATM_have_bfactor: + f.close() + return "TempFact fixed successfully" + else: + f.close() + return "TempFact column not fixed" + except Exception as e: + return f"Couldnt read written file TempFact column not fixed error: {e}" + + def _fix_occupancy_column(self, pdbfile, occupancy, only_fill): + """ + Set the occupancy column in all ATOM/HETATM records to a given value. + + Non-ATOM/HETATM lines are give as are. This function is a generator. + + Parameters + ---------- + fhandle : a line-by-line iterator of the original PDB file. + + occupancy : float + The desired occupancy value + + Yields + ------ + str (line-by-line) + The modified (or not) PDB line. + """ + + records = ("ATOM", "HETATM") + corrected_lines = [] + occupancy = "{0:>6.2f}".format(occupancy) + for line in pdbfile: + if line.startswith(records): + line = self.pad_line(line) + if only_fill: + if line[54:60].strip() == "": + corrected_lines.append(line[:54] + occupancy + line[60:]) + else: + corrected_lines.append(line[:54] + occupancy + line[60:]) + else: + corrected_lines.append(line) + + return corrected_lines + + def fix_occupancy_columns(self, pdb_file, occupancy=1.0, only_fill=True): + """Fixes the occupancy columns of a pdb file""" + # extract Title, Header, Remarks, and Cryst1 records + # get name from pdb_file + file_name = pdb_file.split(".")[0] + if os.path.isfile(file_name + "-fixed.pdb"): + file_name = file_name + "-fixed.pdb" + + assert isinstance(pdb_file, str), "pdb_file must be a string" + with open(file_name, "r") as f: + pdb_file_lines = f.readlines() + # validate if pdbfile has occupancy + pdb = PDBsummarizerfxns() + atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_occupancy( + file_name + ), pdb._hetatom_have_occupancy(file_name) + if atoms_have_bfactor and HETATM_have_bfactor and only_fill: + f.close() + return ( + "Occupancy column filled with occupancy" + "already, no fix needed for occupancy" + ) + f.close() + # fix element column + records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") + final_records = ("CONECT", "MASTER", "END") + _unchanged_records = [] + _unchanged_final_records = [] + for line in pdb_file_lines: + if line.startswith(records): + _unchanged_records.append(line) + elif line.startswith(final_records): + _unchanged_final_records.append(line) + + new_pdb = self._fix_occupancy_column(pdb_file_lines, occupancy, only_fill) + # join the linees + new_pdb = "".join(new_pdb) + # write new pdb file as pdb_file-fixed.pdb + new_pdb_file = file_name + "-fixed.pdb" + + # write new pdb file as pdb_file-fixed.pdb + assert isinstance(new_pdb_file, str), "new_pdb_file must be a string" + with open(new_pdb_file, "w") as f: + f.writelines(_unchanged_records) + f.write(new_pdb) + f.writelines(_unchanged_final_records) + f.close() + try: + # read the new pdb file and check if it has element records + with open(new_pdb_file, "r") as f: + pdb_file = f.readlines() + pdb = PDBsummarizerfxns() + atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_tempFactor( + new_pdb_file + ), pdb._hetatm_have_tempFactor(new_pdb_file) + if atoms_have_bfactor and HETATM_have_bfactor: + f.close() + return "Occupancy fixed successfully" + else: + f.close() + return "Occupancy column not fixed" + except Exception: + return "Couldnt read file Occupancy's column not fixed" + + def apply_fixes(self, pdbfile, query): + # Define a mapping between query keys and functions. + # If a function requires additional arguments from the query, + # define it as a lambda. + FUNCTION_MAP = { + "ElemColum": lambda pdbfile, params: self.fix_element_column(pdbfile), + "tempFactor": lambda pdbfile, params: self.fix_temp_factor_column( + pdbfile, *params + ), + "Occupancy": lambda pdbfile, params: self.fix_occupancy_columns( + pdbfile, *params + ), + } + # Iterate through the keys and functions in FUNCTION_MAP. + for key, func in FUNCTION_MAP.items(): + # Check if the current key is in the query and is not None. + params = query.get(key) + if params is not None: + # If it is, call the function with + # pdbfile and the parameters from the query. + func(pdbfile, params) + + return "PDB file fixed" + + +class PDBFilesFixInp(BaseModel): + pdbfile: str = Field(..., description="PDB file to be fixed") + ElemColum: typing.Optional[bool] = Field( + False, + description=( + "List of fixes to be applied. If None, a" + "validation of what fixes are needed is performed." + ), + ) + tempFactor: typing.Optional[typing.Tuple[float, bool]] = Field( + (...), + description=( + "Tuple of ( float, bool)" + "first arg is the" + "value to be set as the tempFill, and third arg indicates" + "if only empty TempFactor columns have to be filled" + ), + ) + Occupancy: typing.Optional[typing.Tuple[float, bool]] = Field( + (...), + description=( + "Tuple of (bool, float, bool)" + "where first arg indicates if Occupancy" + "fix has to be applied, second arg is the" + "value to be set, and third arg indicates" + "if only empty Occupancy columns have to be filled" + ), + ) + + @root_validator + def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: + if isinstance(values, str): + print("values is a string", values) + raise ValidationError("Input must be a dictionary") + + pdbfile = values.get("pdbfiles", "") + occupancy = values.get("occupancy") + tempFactor = values.get("tempFactor") + ElemColum = values.get("ElemColum") + + if occupancy is None and tempFactor is None and ElemColum is None: + if pdbfile == "": + return {"error": "No inputs given, failed use of tool."} + else: + return values + else: + if occupancy: + if len(occupancy) != 2: + return { + "error": ( + "if you want to fix the occupancy" + "column argument must be a tuple of (bool, float)" + ) + } + if not isinstance(occupancy[0], float): + return {"error": "occupancy first arg must be a float"} + if not isinstance(occupancy[1], bool): + return {"error": "occupancy second arg must be a bool"} + if tempFactor: + if len(tempFactor != 2): + return { + "error": ( + "if you want to fix the tempFactor" + "column argument must be a tuple of (float, bool)" + ) + } + if not isinstance(tempFactor[0], bool): + return {"error": "occupancy first arg must be a float"} + if not isinstance(tempFactor[1], float): + return {"error": "tempFactor second arg must be a float"} + if ElemColum is not None: + if not isinstance(ElemColum[1], bool): + return {"error": "ElemColum must be a bool"} + return values + + +class FixPDBFile(BaseTool): + name: str = "PDBFileFixer" + description: str = "Fixes PDB files columns if needed" + args_schema: Type[BaseModel] = PDBFilesFixInp + + path_registry: Optional[PathRegistry] + + def __init__(self, path_registry: Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + + def _run(self, query: Dict): + """use the tool.""" + if self.path_registry is None: + raise ValidationError("Path registry not initialized") + pdb_ff = Validate_Fix_PDB() + error_msg = query.get("error") + if error_msg: + return error_msg + pdbfile = query.pop("pdbfile") + if len(query.keys()) == 0: + validation = pdb_ff.validate_pdb_format(pdbfile) + if validation[0] == 0: + return "PDB file is valid, no need to fix it" + + if validation[0] == 1: + # Convert summarized_errors into a set for O(1) lookups + error_set = set(validation[1]) + + # Apply Fixes + if "At. Elem." in error_set: + pdb_ff.fix_element_column(pdbfile) + if "Tmp. Fac." in error_set: + pdb_ff.fix_temp_factor_column(pdbfile) + if "Occupancy" in error_set: + pdb_ff.fix_occupancy_columns(pdbfile) + + validate = pdb_ff.validate_pdb_format(pdbfile + "-fixed.pdb") + if validate[0] == 0: + name = pdbfile + "-fixed.pdb" + description = "PDB file fixed" + self.path_registry.map_path(name, name, description) + return "PDB file fixed" + else: + return "PDB not fully fixed" + else: + pdb_ff.apply_fixes(pdbfile, query) + validate = pdb_ff.validate_pdb_format(pdbfile + "-fixed.pdb") + if validate[0] == 0: + name = pdbfile + "-fixed.pdb" + description = "PDB file fixed" + self.path_registry.map_path(name, name, description) + return "PDB file fixed" + else: + return "PDB not fully fixed" diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_get.py b/mdagent/tools/base_tools/preprocess_tools/pdb_get.py new file mode 100644 index 00000000..546aade5 --- /dev/null +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_get.py @@ -0,0 +1,222 @@ +import os +from typing import Optional + +import requests +import streamlit as st +from langchain.tools import BaseTool +from rdkit import Chem + +from mdagent.utils import FileType, PathRegistry + + +def get_pdb(query_string: str, path_registry: PathRegistry): + """ + Search RSCB's protein data bank using the given query string + and return the path to pdb file in either CIF or PDB format + """ + if path_registry is None: + path_registry = PathRegistry.get_instance() + url = "https://search.rcsb.org/rcsbsearch/v2/query?json={search-request}" + query = { + "query": { + "type": "terminal", + "service": "full_text", + "parameters": {"value": query_string}, + }, + "return_type": "entry", + } + r = requests.post(url, json=query) + if r.status_code == 204: + return None + if "cif" in query_string or "CIF" in query_string: + filetype = "cif" + else: + filetype = "pdb" + if "result_set" in r.json() and len(r.json()["result_set"]) > 0: + pdbid = r.json()["result_set"][0]["identifier"] + print(f"PDB file found with this ID: {pdbid}") + st.markdown(f"PDB file found with this ID: {pdbid}", unsafe_allow_html=True) + url = f"https://files.rcsb.org/download/{pdbid}.{filetype}" + pdb = requests.get(url) + filename = path_registry.write_file_name( + FileType.PROTEIN, + protein_name=pdbid, + description="raw", + file_format=filetype, + ) + file_id = path_registry.get_fileid(filename, FileType.PROTEIN) + directory = "files/pdb" + # Create the directory if it does not exist + if not os.path.exists(directory): + os.makedirs(directory) + + with open(f"{directory}/{filename}", "w") as file: + file.write(pdb.text) + path_registry.map_path( + file_id, f"{directory}/{filename}", "PDB file downloaded from RSCB" + ) + + return filename, file_id + return None + + +class ProteinName2PDBTool(BaseTool): + name = "PDBFileDownloader" + description = ( + "This tool downloads PDB (Protein Data Bank) or" + "CIF (Crystallographic Information File) files using" + "a protein's common name (NOT a small molecule)." + "When a specific file type, either PDB or CIF," + "is requested, add file type to the query string with space." + "Input: Commercial name of the protein or file without" + "file extension" + "Output: Corresponding PDB or CIF file" + ) + path_registry: Optional[PathRegistry] + + def __init__(self, path_registry: Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + + def _run(self, query: str) -> str: + """Use the tool.""" + try: + if self.path_registry is None: # this should not happen + return "Path registry not initialized" + filename, pdbfile_id = get_pdb(query, self.path_registry) + if pdbfile_id is None: + return "Name2PDB tool failed to find and download PDB file." + else: + self.path_registry.map_path( + pdbfile_id, + f"files/pdb/{filename}", + f"PDB file downloaded from RSCB, PDBFile ID: {pdbfile_id}", + ) + return f"Name2PDB tool successful. downloaded the PDB file:{pdbfile_id}" + except Exception as e: + return f"Something went wrong. {e}" + + async def _arun(self, query) -> str: + """Use the tool asynchronously.""" + raise NotImplementedError("this tool does not support async") + + +class MolPDB: + def __init__(self, path_registry): + self.path_registry = path_registry + + def is_smiles(self, text: str) -> bool: + try: + m = Chem.MolFromSmiles(text, sanitize=False) + if m is None: + return False + return True + except Exception: + return False + + def largest_mol( + self, smiles: str + ) -> ( + str + ): # from https://github.com/ur-whitelab/chemcrow-public/blob/main/chemcrow/utils.py + ss = smiles.split(".") + ss.sort(key=lambda a: len(a)) + while not self.is_smiles(ss[-1]): + rm = ss[-1] + ss.remove(rm) + return ss[-1] + + def molname2smiles( + self, query: str + ) -> ( + str + ): # from https://github.com/ur-whitelab/chemcrow-public/blob/main/chemcrow/tools/databases.py + url = " https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}" + r = requests.get(url.format(query, "property/IsomericSMILES/JSON")) + # convert the response to a json object + data = r.json() + # return the SMILES string + try: + smi = data["PropertyTable"]["Properties"][0]["IsomericSMILES"] + except KeyError: + return ( + "Could not find a molecule matching the text." + "One possible cause is that the input is incorrect, " + "input one molecule at a time." + ) + # remove salts + return Chem.CanonSmiles(self.largest_mol(smi)) + + def smiles2name(self, smi: str) -> str: + try: + smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi), canonical=True) + except Exception: + return "Invalid SMILES string" + # query the PubChem database + r = requests.get( + "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/" + + smi + + "/synonyms/JSON" + ) + data = r.json() + try: + name = data["InformationList"]["Information"][0]["Synonym"][0] + except KeyError: + return "Unknown Molecule" + return name + + def small_molecule_pdb(self, mol_str: str) -> str: + # takes in molecule name or smiles (converts to smiles if name) + # writes pdb file name.pdb (gets name from smiles if possible) + # output is done message + ps = Chem.SmilesParserParams() + ps.removeHs = False + try: + if self.is_smiles(mol_str): + m = Chem.MolFromSmiles(mol_str) + mol_name = self.smiles2name(mol_str) + else: # if input is not smiles, try getting smiles + smi = self.molname2smiles(mol_str) + m = Chem.MolFromSmiles(smi) + mol_name = mol_str + try: # only if needed + m = Chem.AddHs(m) + except Exception: + pass + Chem.AllChem.EmbedMolecule(m) + file_name = f"files/pdb/{mol_name}.pdb" + Chem.MolToPDBFile(m, file_name) + self.path_registry.map_path( + mol_name, file_name, f"pdb file for the small molecule {mol_name}" + ) + return ( + f"PDB file for {mol_str} successfully created and saved to {file_name}." + ) + except Exception: + print( + "There was an error getting pdb. Please input a single molecule name." + f"{mol_str},{mol_name}, {smi}" + ) + return ( + "There was an error getting pdb. Please input a single molecule name." + ) + + +class SmallMolPDB(BaseTool): + name = "SmallMoleculePDB" + description = ( + "Creates a PDB file for a small molecule" + "Use this tool when you need to use a small molecule in a simulation." + "Input can be a molecule name or a SMILES string." + ) + path_registry: Optional[PathRegistry] + + def __init__(self, path_registry: Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + + def _run(self, mol_str: str) -> str: + """use the tool.""" + mol_pdb = MolPDB(self.path_registry) + output = mol_pdb.small_molecule_pdb(mol_str) + return output diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py deleted file mode 100644 index 7c9a5b2c..00000000 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ /dev/null @@ -1,1486 +0,0 @@ -import os -import re -import subprocess -import sys -import typing -from typing import Any, Dict, List, Optional, Type, Union - -import requests -import streamlit as st -from langchain.tools import BaseTool -from pdbfixer import PDBFixer -from pydantic import BaseModel, Field, ValidationError, root_validator -from rdkit import Chem - -from mdagent.utils import FileType, PathRegistry - -from .elements import list_of_elements - - -def get_pdb(query_string, path_registry=None): - """ - Search RSCB's protein data bank using the given query string - and return the path to pdb file in either CIF or PDB format - """ - if path_registry is None: - path_registry = PathRegistry.get_instance() - url = "https://search.rcsb.org/rcsbsearch/v2/query?json={search-request}" - query = { - "query": { - "type": "terminal", - "service": "full_text", - "parameters": {"value": query_string}, - }, - "return_type": "entry", - } - r = requests.post(url, json=query) - if r.status_code == 204: - return None - if "cif" in query_string or "CIF" in query_string: - filetype = "cif" - else: - filetype = "pdb" - if "result_set" in r.json() and len(r.json()["result_set"]) > 0: - pdbid = r.json()["result_set"][0]["identifier"] - print(f"PDB file found with this ID: {pdbid}") - st.markdown(f"PDB file found with this ID: {pdbid}", unsafe_allow_html=True) - url = f"https://files.rcsb.org/download/{pdbid}.{filetype}" - pdb = requests.get(url) - filename = path_registry.write_file_name( - FileType.PROTEIN, - protein_name=pdbid, - description="raw", - file_format=filetype, - ) - file_id = path_registry.get_fileid(filename, FileType.PROTEIN) - directory = "files/pdb" - # Create the directory if it does not exist - if not os.path.exists(directory): - os.makedirs(directory) - - with open(f"{directory}/{filename}", "w") as file: - file.write(pdb.text) - - return filename, file_id - return None - - -class ProteinName2PDBTool(BaseTool): - name = "PDBFileDownloader" - description = ( - "This tool downloads PDB (Protein Data Bank) or" - "CIF (Crystallographic Information File) files using" - "a protein's common name (NOT a small molecule)." - "When a specific file type, either PDB or CIF," - "is requested, add file type to the query string with space." - "Input: Commercial name of the protein or file without" - "file extension" - "Output: Corresponding PDB or CIF file" - ) - path_registry: Optional[PathRegistry] - - def __init__(self, path_registry: Optional[PathRegistry]): - super().__init__() - self.path_registry = path_registry - - def _run(self, query: str) -> str: - """Use the tool.""" - try: - if self.path_registry is None: # this should not happen - return "Path registry not initialized" - filename, pdbfile_id = get_pdb(query, self.path_registry) - if pdbfile_id is None: - return "Name2PDB tool failed to find and download PDB file." - else: - self.path_registry.map_path( - pdbfile_id, - f"files/pdb/{filename}", - f"PDB file downloaded from RSCB, PDBFile ID: {pdbfile_id}", - ) - return f"Name2PDB tool successful. downloaded the PDB file:{pdbfile_id}" - except Exception as e: - return f"Something went wrong. {e}" - - async def _arun(self, query) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("this tool does not support async") - - -"""validate_pdb_format: validates a pdb file against the pdb format specification - packmol_wrapper: takes in a list of pdb files, a - list of number of molecules, a list of instructions, and a list of small molecules - and returns a packed pdb file - Molecule: class that represents a molecule (helpful for packmol - PackmolBox: class that represents a box of molecules (helpful for packmol) - summarize_errors: function that summarizes the errors found by validate_pdb_format - _extract_path: function that extracts a file path from a string - _standard_cleaning: function that cleans a pdb file using pdbfixer)""" - -########PDB Validation######### - - -def validate_pdb_format(fhandle): - """ - Compare each ATOM/HETATM line with the format defined on the - official PDB website. - - Parameters - ---------- - fhandle : a line-by-line iterator of the original PDB file. - - Returns - ------- - (int, list) - - 1 if error was found, 0 if no errors were found. - - List of error messages encountered. - """ - # check if filename is in directory - if not os.path.exists(fhandle): - return (1, ["File not found. Packmol failed to write the file."]) - errors = [] - _fmt_check = ( - ("Atm. Num.", (slice(6, 11), re.compile(r"[\d\s]+"))), - ("Alt. Loc.", (slice(11, 12), re.compile(r"\s"))), - ("Atm. Nam.", (slice(12, 16), re.compile(r"\s*[A-Z0-9]+\s*"))), - ("Spacer #1", (slice(16, 17), re.compile(r"[A-Z0-9 ]{1}"))), - ("Res. Nam.", (slice(17, 20), re.compile(r"\s*[A-Z0-9]+\s*"))), - ("Spacer #2", (slice(20, 21), re.compile(r"\s"))), - ("Chain Id.", (slice(21, 22), re.compile(r"[A-Za-z0-9 ]{1}"))), - ("Res. Num.", (slice(22, 26), re.compile(r"\s*[\d\-]+\s*"))), - ("Ins. Code", (slice(26, 27), re.compile(r"[A-Z0-9 ]{1}"))), - ("Spacer #3", (slice(27, 30), re.compile(r"\s+"))), - ("Coordn. X", (slice(30, 38), re.compile(r"\s*[\d\.\-]+\s*"))), - ("Coordn. Y", (slice(38, 46), re.compile(r"\s*[\d\.\-]+\s*"))), - ("Coordn. Z", (slice(46, 54), re.compile(r"\s*[\d\.\-]+\s*"))), - ("Occupancy", (slice(54, 60), re.compile(r"\s*[\d\.\-]+\s*"))), - ("Tmp. Fac.", (slice(60, 66), re.compile(r"\s*[\d\.\-]+\s*"))), - ("Spacer #4", (slice(66, 72), re.compile(r"\s+"))), - ("Segm. Id.", (slice(72, 76), re.compile(r"[\sA-Z0-9\-\+]+"))), - ("At. Elemt", (slice(76, 78), re.compile(r"[\sA-Z0-9\-\+]+"))), - ("At. Charg", (slice(78, 80), re.compile(r"[\sA-Z0-9\-\+]+"))), - ) - - def _make_pointer(column): - col_bg, col_en = column.start, column.stop - pt = ["^" if c in range(col_bg, col_en) else " " for c in range(80)] - return "".join(pt) - - for iline, line in enumerate(fhandle, start=1): - line = line.rstrip("\n").rstrip("\r") # CR/LF - if not line: - continue - - if line[0:6] in ["ATOM ", "HETATM"]: - # ... [rest of the code unchanged here] - linelen = len(line) - if linelen < 80: - emsg = "[!] Line {0} is short: {1} < 80\n" - sys.stdout.write(emsg.format(iline, linelen)) - - elif linelen > 80: - emsg = "[!] Line {0} is long: {1} > 80\n" - sys.stdout.write(emsg.format(iline, linelen)) - - for fname, (fcol, fcheck) in _fmt_check: - field = line[fcol] - if not fcheck.match(field): - pointer = _make_pointer(fcol) - emsg = "[!] Offending field ({0}) at line {1}\n".format( - fname, iline - ) - emsg += repr(line) + "\n" - emsg += pointer + "\n" - errors.append(emsg) - - else: - # ... [rest of the code unchanged here] - linelen = len(line) - # ... [rest of the code unchanged here] - linelen = len(line) - skip_keywords = ( - "END", - "ENDMDL", - "HEADER", - "TITLE", - "REMARK", - "CRYST1", - "MODEL", - ) - - if any(keyword in line for keyword in skip_keywords): - continue - - if linelen < 80: - emsg = "[!] Line {0} is short: {1} < 80\n" - sys.stdout.write(emsg.format(iline, linelen)) - elif linelen > 80: - emsg = "[!] Line {0} is long: {1} > 80\n" - sys.stdout.write(emsg.format(iline, linelen)) - - """ - map paths to files in path_registry before you return the string - same for all other functions you want to save files for next tools - Don't forget to import PathRegistry and add path_registry - or PathRegistry as an argument - """ - if errors: - msg = "\nTo understand your errors, read the format specification:\n" - msg += "http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM\n" - errors.append(msg) - return (1, errors) - else: - return (0, ["It *seems* everything is OK."]) - - -##########################PACKMOL############################### - - -def summarize_errors(errors): - error_summary = {} - - # Regular expression pattern to capture the error type and line number - pattern = r"\[!\] Offending field \((.+?)\) at line (\d+)" - - for error in errors: - match = re.search(pattern, error) - if match: - error_type, line_number = match.groups() - # If this error type hasn't been seen before, - # initialize it in the dictionary - if error_type not in error_summary: - error_summary[error_type] = {"lines": []} - error_summary[error_type]["lines"].append(line_number) - - # Format the summarized errors for display - summarized_strings = [] - for error_type, data in error_summary.items(): - line_count = len(data["lines"]) - if line_count > 3: - summarized_strings.append(f"{error_type}: total {line_count} lines") - else: - summarized_strings.append(f"{error_type}: lines: {','.join(data['lines'])}") - - return summarized_strings - - -class Molecule: - def __init__(self, filename, file_id, number_of_molecules=1, instructions=None): - self.filename = filename - self.id = file_id - self.number_of_molecules = number_of_molecules - self.instructions = instructions if instructions else [] - self.load() - - def load(self): - # load the molecule data (optional) - pass - - def get_number_of_atoms(self): - # return the number of atoms in this molecule - pass - - -class PackmolBox: - def __init__( - self, file_number=1, file_description="PDB file for simulation with: \n" - ): - self.molecules = [] - self.file_number = 1 - self.file_description = file_description - self.final_name = None - - def add_molecule(self, molecule): - self.molecules.append(molecule) - self.file_description += f"""{molecule.number_of_molecules} of - {molecule.filename} as {molecule.instructions} \n""" - - def generate_input_header(self): - # Generate the header of the input file in .inp format - orig_pdbs_ids = [ - f"{molecule.number_of_molecules}_{molecule.id}" - for molecule in self.molecules - ] - - _final_name = f'{"_and_".join(orig_pdbs_ids)}' - - self.file_description = ( - "Packed Structures of the following molecules:\n" - + "\n".join( - [ - f"Molecule ID: {molecule.id}, " - f"Number of Molecules: {molecule.number_of_molecules}" - for molecule in self.molecules - ] - ) - ) - while os.path.exists(f"files/pdb/{_final_name}_v{self.file_number}.pdb"): - self.file_number += 1 - - self.final_name = f"{_final_name}_v{self.file_number}.pdb" - with open("packmol.inp", "w") as out: - out.write("##Automatically generated by LangChain\n") - out.write("tolerance 2.0\n") - out.write("filetype pdb\n") - out.write( - f"output {self.final_name}\n" - ) # this is the name of the final file - out.close() - - def generate_input(self): - input_data = [] - for molecule in self.molecules: - input_data.append(f"structure {molecule.filename}") - input_data.append(f" number {molecule.number_of_molecules}") - for idx, instruction in enumerate(molecule.instructions): - input_data.append(f" {molecule.instructions[idx]}") - input_data.append("end structure") - - # Convert list of input data to a single string - return "\n".join(input_data) - - def run_packmol(self, PathRegistry): - # Use the generated input to execute Packmol - input_string = self.generate_input() - # Write the input to a file - with open("packmol.inp", "a") as f: - f.write(input_string) - # Here, run Packmol using the subprocess module or similar - cmd = "packmol < packmol.inp" - result = subprocess.run(cmd, shell=True, text=True, capture_output=True) - if result.returncode != 0: - print("Packmol failed to run with 'packmol < packmol.inp' command") - result = subprocess.run( - "./" + cmd, shell=True, text=True, capture_output=True - ) - if result.returncode != 0: - print("Packmol failed to run with './packmol < packmol.inp' command") - return ( - "Packmol failed to run. Please check the input file and try again." - ) - - # validate final pdb - pdb_validation = validate_pdb_format(f"{self.final_name}") - if pdb_validation[0] == 0: - # delete .inp files - # os.remove("packmol.inp") - for molecule in self.molecules: - os.remove(molecule.filename) - # name of packed pdb file - time_stamp = PathRegistry.get_timestamp()[-6:] - os.rename(self.final_name, f"files/pdb/{self.final_name}") - PathRegistry.map_path( - f"PACKED_{time_stamp}", - f"files/pdb/{self.final_name}", - self.file_description, - ) - # move file to files/pdb - print("successfull!") - return f"PDB file validated successfully. FileID: PACKED_{time_stamp}" - elif pdb_validation[0] == 1: - # format pdb_validation[1] list of errors - errors = summarize_errors(pdb_validation[1]) - # delete .inp files - - # os.remove("packmol.inp") - print("errors:", f"{errors}") - return "PDB file not validated, errors found {}".format(("\n").join(errors)) - - -# define function that takes in a list of -# molecules and a list of instructions and returns a pdb file - - -def packmol_wrapper( - PathRegistry, - pdbfiles: List, - files_id: List, - number_of_molecules: List, - instructions: List[List], -): - """Useful when you need to create a box - of different types of molecules molecules""" - - # create a box - box = PackmolBox() - # add molecules to the box - for ( - pdbfile, - file_id, - number_of_molecules, - instructions, - ) in zip(pdbfiles, files_id, number_of_molecules, instructions): - molecule = Molecule(pdbfile, file_id, number_of_molecules, instructions) - box.add_molecule(molecule) - # generate input header - box.generate_input_header() - # generate input - # run packmol - print("Packing:", box.file_description, "\nThe file name is:", box.final_name) - return box.run_packmol(PathRegistry) - - -"""Args schema for packmol_wrapper tool. Useful for OpenAI functions""" -##TODO - - -class PackmolInput(BaseModel): - pdbfiles_id: typing.Optional[typing.List[str]] = Field( - ..., description="List of PDB files id (path_registry) to pack into a box" - ) - small_molecules: typing.Optional[typing.List[str]] = Field( - [], - description=( - "List of small molecules to be packed in the system. " - "Examples: water, benzene, toluene, etc." - ), - ) - - number_of_molecules: typing.Optional[typing.List[int]] = Field( - ..., - description=( - "List of number of instances of each species to pack into the box. " - "One number per species (either protein or small molecule) " - ), - ) - instructions: typing.Optional[typing.List[List[str]]] = Field( - ..., - description=( - "List of instructions for each species. " - "One List per Molecule. " - "Every instruction should be one string like:\n" - "'inside box 0. 0. 0. 90. 90. 90.'" - ), - ) - - -class PackMolTool(BaseTool): - name: str = "packmol_tool" - description: str = ( - "Useful when you need to create a box " - "of different types of chemical species.\n" - "Three different examples:\n" - "pdbfiles_id: ['1a2b_123456']\n" - "small_molecules: ['water'] \n" - "number_of_molecules: [1, 1000]\n" - "instructions: [['fixed 0. 0. 0. 0. 0. 0. \n centerofmass'], " - "['inside box 0. 0. 0. 90. 90. 90.']]\n" - "will pack 1 molecule of 1a2b_123456 at the origin " - "and 1000 molecules of water. \n" - "pdbfiles_id: ['1a2b_123456']\n" - "number_of_molecules: [1]\n" - "instructions: [['fixed 0. 0. 0. 0. 0. 0.' \n center]]\n" - "This will fix the barocenter of protein 1a2b_123456 at " - "the center of the box with no rotation.\n" - "pdbfiles_id: ['1a2b_123456']\n" - "number_of_molecules: [1]\n" - "instructions: [['outside sphere 2.30 3.40 4.50 8.0]]\n" - "This will place the protein 1a2b_123456 outside a sphere " - "centered at 2.30 3.40 4.50 with radius 8.0\n" - ) - - args_schema: Type[BaseModel] = PackmolInput - - path_registry: typing.Optional[PathRegistry] - - def __init__(self, path_registry: typing.Optional[PathRegistry]): - super().__init__() - self.path_registry = path_registry - - def _get_sm_pdbs(self, small_molecules): - all_files = self.path_registry.list_path_names() - for molecule in small_molecules: - # check path registry for molecule.pdb - if molecule not in all_files: - # download molecule using small_molecule_pdb from MolPDB - molpdb = MolPDB() - molpdb.small_molecule_pdb(molecule, self.path_registry) - print("Small molecules PDBs created successfully") - - def _run(self, **values) -> str: - """use the tool.""" - - if self.path_registry is None: # this should not happen - raise ValidationError("Path registry not initialized") - try: - values = self.validate_input(values) - except ValidationError as e: - return str(e) - error_msg = values.get("error", None) - if error_msg: - print("Error in Packmol inputs:", error_msg) - return f"Error in inputs: {error_msg}" - print("Starting Packmol Tool!") - pdbfile_ids = values.get("pdbfiles_id", []) - pdbfiles = [ - self.path_registry.get_mapped_path(pdbfile) for pdbfile in pdbfile_ids - ] - pdbfile_names = [pdbfile.split("/")[-1] for pdbfile in pdbfiles] - # copy them to the current directory with temp_ names - - pdbfile_names = [f"temp_{pdbfile_name}" for pdbfile_name in pdbfile_names] - number_of_molecules = values.get("number_of_molecules", []) - instructions = values.get("instructions", []) - small_molecules = values.get("small_molecules", []) - # make sure small molecules are all downloaded - self._get_sm_pdbs(small_molecules) - small_molecules_files = [ - self.path_registry.get_mapped_path(sm) for sm in small_molecules - ] - small_molecules_file_names = [ - small_molecule.split("/")[-1] for small_molecule in small_molecules_files - ] - small_molecules_file_names = [ - f"temp_{small_molecule_file_name}" - for small_molecule_file_name in small_molecules_file_names - ] - # append small molecules to pdbfiles - pdbfiles.extend(small_molecules_files) - pdbfile_names.extend(small_molecules_file_names) - pdbfile_ids.extend(small_molecules) - - for pdbfile, pdbfile_name in zip(pdbfiles, pdbfile_names): - os.system(f"cp {pdbfile} {pdbfile_name}") - # check if packmol is installed - cmd = "command -v packmol" - result = subprocess.run(cmd, shell=True, text=True, capture_output=True) - if result.returncode != 0: - result = subprocess.run( - "./" + cmd, shell=True, text=True, capture_output=True - ) - if result.returncode != 0: - return ( - "Packmol is not installed. Please install" - "packmol at " - "'https://m3g.github.io/packmol/download.shtml'" - "and try again." - ) - - return packmol_wrapper( - self.path_registry, - pdbfiles=pdbfile_names, - files_id=pdbfile_ids, - number_of_molecules=number_of_molecules, - instructions=instructions, - ) - - def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: - # check if is only a string - if isinstance(values, str): - print("values is a string", values) - raise ValidationError("Input must be a dictionary") - pdbfiles = values.get("pdbfiles_id", []) - small_molecules = values.get("small_molecules", []) - number_of_molecules = values.get("number_of_molecules", []) - instructions = values.get("instructions", []) - number_of_species = len(pdbfiles) + len(small_molecules) - - if not number_of_species == len(number_of_molecules): - if not number_of_species == len(instructions): - return { - "error": ( - "The length of number_of_molecules AND instructions " - "must be equal to the number of species in the system. " - f"You have {number_of_species} " - f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " - "small molecules" - ) - } - return { - "error": ( - "The length of number_of_molecules must be equal to the " - f"number of species in the system. You have {number_of_species} " - f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " - "small molecules" - ) - } - elif not number_of_species == len(instructions): - return { - "error": ( - "The length of instructions must be equal to the " - f"number of species in the system. You have {number_of_species} " - f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " - "small molecules" - ) - } - - molPDB = MolPDB() - for instruction in instructions: - if len(instruction) != 1: - return { - "error": ( - "Each instruction must be a single string. " - "If necessary, use newlines in a instruction string." - ) - } - # TODO enhance this validation with more packmol instructions - first_word = instruction[0].split(" ")[0] - if first_word == "center": - if len(instruction[0].split(" ")) == 1: - return { - "error": ( - "The instruction 'center' must be accompanied by more " - "instructions. Example 'fixed 0. 0. 0. 0. 0. 0.' " - "The complete instruction would be: 'center \n fixed 0. 0. " - "0. 0. 0. 0.' with a newline separating the two " - "instructions." - ) - } - elif first_word not in [ - "inside", - "outside", - "fixed", - ]: - return { - "error": ( - "The first word of each instruction must be one of " - "'inside' or 'outside' or 'fixed' \n" - "examples: center \n fixed 0. 0. 0. 0. 0. 0.,\n" - "inside box -10. 0. 0. 10. 10. 10. \n" - ) - } - - # Further validation, e.g., checking if files exist - registry = PathRegistry() - file_ids = registry.list_path_names() - - for pdbfile_id in pdbfiles: - if "_" not in pdbfile_id: - return { - "error": ( - f"{pdbfile_id} is not a valid pdbfile_id in the path_registry" - ) - } - if pdbfile_id not in file_ids: - # look for files in the current directory - # that match some part of the pdbfile - ids_w_description = registry.list_path_names_and_descriptions() - - return { - "error": ( - f"PDB file ID {pdbfile_id} does not exist " - "in the path registry.\n" - f"This are the files IDs: {ids_w_description} " - ) - } - for small_molecule in small_molecules: - if small_molecule not in file_ids: - result = molPDB.small_molecule_pdb(small_molecule, registry) - if "successfully" not in result: - return { - "error": ( - f"{small_molecule} could not be converted to a pdb " - "file. Try with a different name, or with the SMILES " - "of the small molecule" - ) - } - return values - - async def _arun(self, values: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -########VALIDATION AND FIXING PDB FILES######################## - - -class PDBsummarizerfxns: - def __init__(self): - self.list_of_elements = list_of_elements - - def _record_inf(self, pdbfile): - with open(pdbfile, "r") as f: - lines = f.readlines() - remarks = [ - record_lines - for record_lines in lines - if record_lines.startswith("REMARK") - ] - atoms = [ - record_lines - for record_lines in lines - if record_lines.startswith("ATOM") - ] - box = [ - record_lines - for record_lines in lines - if record_lines.startswith("CRYST1") - ] - HETATM = [ - record_lines - for record_lines in lines - if record_lines.startswith("HETATM") - ] - - return remarks, atoms, box, HETATM - - def _num_of_dif_residues(self, pdbfile): - remarks, atoms, box, HETATM = self._record_inf(pdbfile) - residues = [atom[17:20] for atom in atoms] - residues = list(set(residues)) - return len(residues) - - # diagnosis - """Checks for the elements names in the pdb file. - Positions 76-78 of the ATOM and HETATM records""" - - def _atoms_have_elements(self, pdbfile): - _, atoms, _, _ = self._record_inf(pdbfile) - print(atoms) - elements = [atom[76:78] for atom in atoms if atom not in [" ", "", " ", " "]] - print(elements) - if len(elements) != len(atoms): - print( - ( - "No elements in the ATOM records there are" - "{len(elements)} elements and {len(atoms)}" - "atoms records" - ) - ) - return False - elements = list(set(elements)) - for element in elements: - if element not in self.list_of_elements: - print("Element not in the list of elements") - return False - return True - - def _atoms_have_tempFactor(self, pdbfile): - _, atoms, _, _ = self._record_inf(pdbfile) - tempFactor = [ - atom[60:66] - for atom in atoms - if atom[60:66] not in [" ", "", " ", " ", " ", " "] - ] - if len(tempFactor) != len(atoms): - return False - return True - - def _atoms_have_occupancy(self, pdbfile): - _, atoms, _, _ = self._record_inf(pdbfile) - occupancy = [ - atom[54:60] - for atom in atoms - if atom[54:60] not in [" ", "", " ", " ", " ", " "] - ] - if len(occupancy) != len(atoms): - return False - return True - - def _hetatom_have_occupancy(self, pdbfile): - _, _, _, HETATM = self._record_inf(pdbfile) - occupancy = [ - atom[54:60] - for atom in HETATM - if atom[54:60] not in [" ", "", " ", " ", " ", " "] - ] - if len(occupancy) != len(HETATM): - return False - return True - - def _hetatm_have_elements(self, pdbfile): - _, _, _, HETATM = self._record_inf(pdbfile) - elements = [ - atom[76:78] for atom in HETATM if atom[76:78] not in [" ", "", " ", " "] - ] - if len(elements) != len(HETATM): - print("No elements in the HETATM records") - return False - return True - - def _hetatm_have_tempFactor(self, pdbfile): - _, _, _, HETATM = self._record_inf(pdbfile) - tempFactor = [ - atom[60:66] for atom in HETATM if atom not in [" ", "", " ", " "] - ] - if len(tempFactor) != len(HETATM): - return False - return True - - """Checks for the residue names in the pdb file. - Positions 17-20 of the ATOM and HETATM records""" - - def _atoms_hetatm_have_residue_names(self, pdbfile): - _, atoms, _, HETATM = self._record_inf(pdbfile) - residues = [atom[17:20] for atom in atoms] - residues = list(set(residues)) - if len(residues) != len(atoms): - return False - residues = [atom[17:20] for atom in HETATM] - residues = list(set(residues)) - if len(residues) != len(HETATM): - return False - return True - - def _atoms_hetatm_have_occupancy(self, pdbfile): - _, atoms, _, HETATM = self._record_inf(pdbfile) - occupancy = [ - atom[54:60] - for atom in atoms - if atom not in [" ", "", " ", " ", " ", " "] - ] - if len(occupancy) != len(atoms): - return False - occupancy = [ - HET[54:60] - for HET in HETATM - if HET not in [" ", "", " ", " ", " ", " "] - ] - if len(occupancy) != len(HETATM): - return False - return True - - def _non_standard_residues(self, pdbfile): - fixer = PDBFixer(file_name=pdbfile) - fixer.findNonstandardResidues() - len(fixer.nonstandardResidues) - - -def pdb_summarizer(pdb_file): - pdb = PDBsummarizerfxns() - pdb.remarks, pdb.atoms, pdb.box, pdb.HETATM = pdb._record_inf(pdb_file) - pdb.atoms_elems = pdb._atoms_have_elements(pdb_file) - pdb.HETATM_elems = pdb._hetatm_have_elements(pdb_file) - pdb.residues = pdb._atoms_hetatm_have_residue_names(pdb_file) - pdb.atoms_tempFact = pdb._atoms_have_tempFactor(pdb_file) - pdb.num_of_residues = pdb._num_of_dif_residues(pdb_file) - pdb.HETATM_tempFact = pdb._hetatm_have_tempFactor(pdb_file) - - output = ( - f"PDB file: {pdb_file} has the following properties:" - "Number of residues: {pdb.num_of_residues}" - "Are elements identifiers present: {pdb.atoms}" - "Are HETATM elements identifiers present: {pdb.HETATM}" - "Are residue names present: {pdb.residues}" - "Are box dimensions present: {pdb.box}" - "Non-standard residues: {pdb.HETATM}" - ) - return output - - -def _fix_element_column(pdb_file, custom_element_dict=None): - records = ("ATOM", "HETATM", "ANISOU") - corrected_lines = [] - for line in pdb_file: - if line.startswith(records): - atom_name = line[12:16] - - if atom_name[0].isalpha() and not atom_name[2:].isdigit(): - element = atom_name.strip() - else: - atom_name = atom_name.strip() - if atom_name[0].isdigit(): - element = atom_name[1] - else: - element = atom_name[0] - - if element not in set(list_of_elements): - element = " " # empty element in case we cannot assign - - line = line[:76] + element.rjust(2) + line[78:] - corrected_lines.append(line) - - return corrected_lines - - -def fix_element_column(pdb_file, custom_element_dict=None): - """Fixes the Element columns of a pdb file""" - - # extract Title, Header, Remarks, and Cryst1 records - file_name = pdb_file.split(".")[0] - # check if theres a file-name-fixed.pdb file - if os.path.isfile(file_name + "-fixed.pdb"): - pdb_file = file_name + "-fixed.pdb" - assert isinstance(pdb_file, str), "pdb_file must be a string" - with open(pdb_file, "r") as f: - print("I read the initial file") - pdb_file_lines = f.readlines() - # validate if pdbfile has element records - pdb = PDBsummarizerfxns() - atoms_have_elems, HETATM_have_elems = pdb._atoms_have_elements( - pdb_file - ), pdb._hetatm_have_elements(pdb_file) - if atoms_have_elems and HETATM_have_elems: - f.close() - return ( - "Element's column already filled with" - "elements, no fix needed for elements" - ) - print("I closed the initial file") - f.close() - - # fix element column - records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") - final_records = ("CONECT", "MASTER", "END") - _unchanged_records = [] - _unchanged_final_records = [] - print("pdb_file", pdb_file) - for line in pdb_file_lines: - if line.startswith(records): - _unchanged_records.append(line) - elif line.startswith(final_records): - _unchanged_final_records.append(line) - print("_unchanged_records", _unchanged_records) - new_pdb = _fix_element_column(pdb_file_lines, custom_element_dict) - # join the linees - new_pdb = "".join(new_pdb) - # write new pdb file as pdb_file-fixed.pdb - new_pdb_file = file_name.split(".")[0] + "-fixed.pdb" - print("name of fixed pdb file", new_pdb_file) - # write the unchanged records first and then the new pdb file - assert isinstance(new_pdb_file, str), "new_pdb_file must be a string" - with open(new_pdb_file, "w") as f: - print("I wrote the new file") - f.writelines(_unchanged_records) - f.write(new_pdb) - f.writelines(_unchanged_final_records) - f.close() - try: - # read the new pdb file and check if it has element records - with open(new_pdb_file, "r") as f: - pdb_file_lines = f.readlines() - pdb = PDBsummarizerfxns() - atoms_have_elems, HETATM_have_elems = pdb._atoms_have_elements( - new_pdb_file - ), pdb._hetatm_have_elements(new_pdb_file) - if atoms_have_elems and HETATM_have_elems: - f.close() - return "Element's column fixed successfully" - else: - f.close() - return "Element's column not fixed, and i dont know why" - except Exception as e: - return f"Element's column not fixed error: {e}" - - -class FixElementColumnArgs(BaseTool): - # arguments of fix_element_column - pdb_file: str = Field(..., description="PDB file to be fixed") - custom_element_dict: dict = Field( - None, - description=( - "Custom element dictionary. If None," "the default dictionary is used" - ), - ) - - -def pad_line(line): - """Pad line to 80 characters in case it is shorter.""" - size_of_line = len(line) - if size_of_line < 80: - padding = 80 - size_of_line + 1 - line = line.strip("\n") + " " * padding + "\n" - return line[:81] # 80 + newline character - - -def _fix_temp_factor_column(pdbfile, bfactor, only_fill): - """Set the temperature column in all ATOM/HETATM records to a given value. - - This function is a generator. - - Parameters - ---------- - fhandle : a line-by-line iterator of the original PDB file. - - bfactor : float - The desired bfactor. - - Yields - ------ - str (line-by-line) - The modified (or not) PDB line.""" - _pad_line = pad_line - records = ("ATOM", "HETATM") - corrected_lines = [] - bfactor = "{0:>6.2f}".format(bfactor) - - for line in pdbfile: - if line.startswith(records): - line = _pad_line(line) - if only_fill: - if line[60:66].strip() == "": - corrected_lines.append(line[:60] + bfactor + line[66:]) - else: - corrected_lines.append(line[:60] + bfactor + line[66:]) - else: - corrected_lines.append(line) - - return corrected_lines - - -def fix_temp_factor_column(pdb_file, bfactor=1.00, only_fill=True): - """Fixes the tempFactor columns of a pdb file""" - - # extract Title, Header, Remarks, and Cryst1 records - # get name from pdb_file - if isinstance(pdb_file, str): - file_name = pdb_file.split(".")[0] - else: - return "pdb_file must be a string" - file_name = pdb_file.split(".")[0] - - if os.path.isfile(file_name + "-fixed.pdb"): - file_name = file_name + "-fixed.pdb" - - assert isinstance(file_name, str), "pdb_file must be a string" - with open(file_name, "r") as f: - print("im reading the files temp factor") - pdb_file_lines = f.readlines() - # validate if pdbfile has temp factors - pdb = PDBsummarizerfxns() - atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_tempFactor( - pdb_file - ), pdb._hetatm_have_tempFactor(pdb_file) - if atoms_have_bfactor and HETATM_have_bfactor and only_fill: - # print("Im closing the file temp factor") - f.close() - return ( - "TempFact column filled with bfactor already," - "no fix needed for temp factor" - ) - f.close() - # fix element column - records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") - final_records = ("CONECT", "MASTER", "END") - _unchanged_final_records = [] - _unchanged_records = [] - for line in pdb_file_lines: - if line.startswith(records): - _unchanged_records.append(line) - elif line.startswith(final_records): - _unchanged_final_records.append(line) - - new_pdb = _fix_temp_factor_column(pdb_file_lines, bfactor, only_fill) - # join the linees - new_pdb = "".join(new_pdb) - # write new pdb file as pdb_file-fixed.pdb - new_pdb_file = file_name + "-fixed.pdb" - # organize columns HEADER, TITLE, REMARKS, CRYST1, ATOM, HETATM, CONECT, MASTER, END - - assert isinstance(new_pdb_file, str), "new_pdb_file must be a string" - # write new pdb file as pdb_file-fixed.pdb - with open(new_pdb_file, "w") as f: - f.writelines(_unchanged_records) - f.write(new_pdb) - f.writelines(_unchanged_final_records) - f.close() - try: - # read the new pdb file and check if it has element records - with open(new_pdb_file, "r") as f: - pdb_file = f.readlines() - pdb = PDBsummarizerfxns() - atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_tempFactor( - new_pdb_file - ), pdb._hetatm_have_tempFactor(new_pdb_file) - if atoms_have_bfactor and HETATM_have_bfactor: - f.close() - return "TempFact fixed successfully" - else: - f.close() - return "TempFact column not fixed" - except Exception as e: - return f"Couldnt read written file TempFact column not fixed error: {e}" - - -class FixTempFactorColumnArgs(BaseTool): - # arguments of fix_element_column - pdb_file: str = Field(..., description="PDB file to be fixed") - bfactor: float = Field(1.0, description="Bfactor value to use") - only_fill: bool = Field( - True, - description=( - "Only fill empty bfactor columns." - "Avoids replacing existing values." - "False if you want to replace all values" - "with the bfactor value" - ), - ) - - -def _fix_occupancy_column(pdbfile, occupancy, only_fill): - """ - Set the occupancy column in all ATOM/HETATM records to a given value. - - Non-ATOM/HETATM lines are give as are. This function is a generator. - - Parameters - ---------- - fhandle : a line-by-line iterator of the original PDB file. - - occupancy : float - The desired occupancy value - - Yields - ------ - str (line-by-line) - The modified (or not) PDB line. - """ - - records = ("ATOM", "HETATM") - corrected_lines = [] - occupancy = "{0:>6.2f}".format(occupancy) - for line in pdbfile: - if line.startswith(records): - line = pad_line(line) - if only_fill: - if line[54:60].strip() == "": - corrected_lines.append(line[:54] + occupancy + line[60:]) - else: - corrected_lines.append(line[:54] + occupancy + line[60:]) - else: - corrected_lines.append(line) - - return corrected_lines - - -def fix_occupancy_columns(pdb_file, occupancy=1.0, only_fill=True): - """Fixes the occupancy columns of a pdb file""" - # extract Title, Header, Remarks, and Cryst1 records - # get name from pdb_file - file_name = pdb_file.split(".")[0] - if os.path.isfile(file_name + "-fixed.pdb"): - file_name = file_name + "-fixed.pdb" - - assert isinstance(pdb_file, str), "pdb_file must be a string" - with open(file_name, "r") as f: - pdb_file_lines = f.readlines() - # validate if pdbfile has occupancy - pdb = PDBsummarizerfxns() - atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_occupancy( - file_name - ), pdb._hetatom_have_occupancy(file_name) - if atoms_have_bfactor and HETATM_have_bfactor and only_fill: - f.close() - return ( - "Occupancy column filled with occupancy" - "already, no fix needed for occupancy" - ) - f.close() - # fix element column - records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") - final_records = ("CONECT", "MASTER", "END") - _unchanged_records = [] - _unchanged_final_records = [] - for line in pdb_file_lines: - if line.startswith(records): - _unchanged_records.append(line) - elif line.startswith(final_records): - _unchanged_final_records.append(line) - - new_pdb = _fix_occupancy_column(pdb_file_lines, occupancy, only_fill) - # join the linees - new_pdb = "".join(new_pdb) - # write new pdb file as pdb_file-fixed.pdb - new_pdb_file = file_name + "-fixed.pdb" - - # write new pdb file as pdb_file-fixed.pdb - assert isinstance(new_pdb_file, str), "new_pdb_file must be a string" - with open(new_pdb_file, "w") as f: - f.writelines(_unchanged_records) - f.write(new_pdb) - f.writelines(_unchanged_final_records) - f.close() - try: - # read the new pdb file and check if it has element records - with open(new_pdb_file, "r") as f: - pdb_file = f.readlines() - pdb = PDBsummarizerfxns() - atoms_have_bfactor, HETATM_have_bfactor = pdb._atoms_have_tempFactor( - new_pdb_file - ), pdb._hetatm_have_tempFactor(new_pdb_file) - if atoms_have_bfactor and HETATM_have_bfactor: - f.close() - return "Occupancy fixed successfully" - else: - f.close() - return "Occupancy column not fixed" - except Exception: - return "Couldnt read file Occupancy's column not fixed" - - -class FixOccupancyColumnArgs(BaseTool): - # arguments of fix_element_column - pdb_file: str = Field(..., description="PDB file to be fixed") - occupancy: float = Field(1.0, description="Occupancy value to be set") - only_fill: bool = Field( - True, - description=( - "Only fill empty occupancy columns." - "Avoids replacing existing values." - "False if you want to replace all" - "values with the occupancy value" - ), - ) - - -# Define a mapping between query keys and functions. -# If a function requires additional arguments from the query, define it as a lambda. -FUNCTION_MAP = { - "ElemColum": lambda pdbfile, params: fix_element_column(pdbfile), - "tempFactor": lambda pdbfile, params: fix_temp_factor_column(pdbfile, *params), - "Occupancy": lambda pdbfile, params: fix_occupancy_columns(pdbfile, *params), -} - - -def apply_fixes(pdbfile, query): - # Iterate through the keys and functions in FUNCTION_MAP. - for key, func in FUNCTION_MAP.items(): - # Check if the current key is in the query and is not None. - params = query.get(key) - if params is not None: - # If it is, call the function with - # pdbfile and the parameters from the query. - func(pdbfile, params) - - return "PDB file fixed" - - -class PDBFilesFixInp(BaseModel): - pdbfile: str = Field(..., description="PDB file to be fixed") - ElemColum: typing.Optional[bool] = Field( - False, - description=( - "List of fixes to be applied. If None, a" - "validation of what fixes are needed is performed." - ), - ) - tempFactor: typing.Optional[typing.Tuple[float, bool]] = Field( - (...), - description=( - "Tuple of ( float, bool)" - "first arg is the" - "value to be set as the tempFill, and third arg indicates" - "if only empty TempFactor columns have to be filled" - ), - ) - Occupancy: typing.Optional[typing.Tuple[float, bool]] = Field( - (...), - description=( - "Tuple of (bool, float, bool)" - "where first arg indicates if Occupancy" - "fix has to be applied, second arg is the" - "value to be set, and third arg indicates" - "if only empty Occupancy columns have to be filled" - ), - ) - - @root_validator - def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: - if isinstance(values, str): - print("values is a string", values) - raise ValidationError("Input must be a dictionary") - - pdbfile = values.get("pdbfiles", "") - occupancy = values.get("occupancy") - tempFactor = values.get("tempFactor") - ElemColum = values.get("ElemColum") - - if occupancy is None and tempFactor is None and ElemColum is None: - if pdbfile == "": - return {"error": "No inputs given, failed use of tool."} - else: - return values - else: - if occupancy: - if len(occupancy) != 2: - return { - "error": ( - "if you want to fix the occupancy" - "column argument must be a tuple of (bool, float)" - ) - } - if not isinstance(occupancy[0], float): - return {"error": "occupancy first arg must be a float"} - if not isinstance(occupancy[1], bool): - return {"error": "occupancy second arg must be a bool"} - if tempFactor: - if len(tempFactor != 2): - return { - "error": ( - "if you want to fix the tempFactor" - "column argument must be a tuple of (float, bool)" - ) - } - if not isinstance(tempFactor[0], bool): - return {"error": "occupancy first arg must be a float"} - if not isinstance(tempFactor[1], float): - return {"error": "tempFactor second arg must be a float"} - if ElemColum is not None: - if not isinstance(ElemColum[1], bool): - return {"error": "ElemColum must be a bool"} - return values - - -class FixPDBFile(BaseTool): - name: str = "PDBFileFixer" - description: str = "Fixes PDB files columns if needed" - args_schema: Type[BaseModel] = PDBFilesFixInp - - path_registry: Optional[PathRegistry] - - def __init__(self, path_registry: Optional[PathRegistry]): - super().__init__() - self.path_registry = path_registry - - def _run(self, query: Dict): - """use the tool.""" - if self.path_registry is None: - raise ValidationError("Path registry not initialized") - error_msg = query.get("error") - if error_msg: - return error_msg - pdbfile = query.pop("pdbfile") - if len(query.keys()) == 0: - validation = validate_pdb_format(pdbfile) - if validation[0] == 0: - return "PDB file is valid, no need to fix it" - - if validation[0] == 1: - # Convert summarized_errors into a set for O(1) lookups - error_set = set(validation[1]) - - # Apply Fixes - if "At. Elem." in error_set: - fix_element_column(pdbfile) - if "Tmp. Fac." in error_set: - fix_temp_factor_column(pdbfile) - if "Occupancy" in error_set: - fix_occupancy_columns(pdbfile) - - validate = validate_pdb_format(pdbfile + "-fixed.pdb") - if validate[0] == 0: - name = pdbfile + "-fixed.pdb" - description = "PDB file fixed" - self.path_registry.map_path(name, name, description) - return "PDB file fixed" - else: - return "PDB not fully fixed" - else: - apply_fixes(pdbfile, query) - validate = validate_pdb_format(pdbfile + "-fixed.pdb") - if validate[0] == 0: - name = pdbfile + "-fixed.pdb" - description = "PDB file fixed" - self.path_registry.map_path(name, name, description) - return "PDB file fixed" - else: - return "PDB not fully fixed" - - -class MolPDB: - def is_smiles(self, text: str) -> bool: - try: - m = Chem.MolFromSmiles(text, sanitize=False) - if m is None: - return False - return True - except Exception: - return False - - def largest_mol( - self, smiles: str - ) -> ( - str - ): # from https://github.com/ur-whitelab/chemcrow-public/blob/main/chemcrow/utils.py - ss = smiles.split(".") - ss.sort(key=lambda a: len(a)) - while not self.is_smiles(ss[-1]): - rm = ss[-1] - ss.remove(rm) - return ss[-1] - - def molname2smiles( - self, query: str - ) -> ( - str - ): # from https://github.com/ur-whitelab/chemcrow-public/blob/main/chemcrow/tools/databases.py - url = " https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}" - r = requests.get(url.format(query, "property/IsomericSMILES/JSON")) - # convert the response to a json object - data = r.json() - # return the SMILES string - try: - smi = data["PropertyTable"]["Properties"][0]["IsomericSMILES"] - except KeyError: - return ( - "Could not find a molecule matching the text." - "One possible cause is that the input is incorrect, " - "input one molecule at a time." - ) - # remove salts - return Chem.CanonSmiles(self.largest_mol(smi)) - - def smiles2name(self, smi: str) -> str: - try: - smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi), canonical=True) - except Exception: - return "Invalid SMILES string" - # query the PubChem database - r = requests.get( - "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/" - + smi - + "/synonyms/JSON" - ) - data = r.json() - try: - name = data["InformationList"]["Information"][0]["Synonym"][0] - except KeyError: - return "Unknown Molecule" - return name - - def small_molecule_pdb(self, mol_str: str, path_registry) -> str: - # takes in molecule name or smiles (converts to smiles if name) - # writes pdb file name.pdb (gets name from smiles if possible) - # output is done message - ps = Chem.SmilesParserParams() - ps.removeHs = False - try: - if self.is_smiles(mol_str): - m = Chem.MolFromSmiles(mol_str) - mol_name = self.smiles2name(mol_str) - else: # if input is not smiles, try getting smiles - smi = self.molname2smiles(mol_str) - m = Chem.MolFromSmiles(smi) - mol_name = mol_str - try: # only if needed - m = Chem.AddHs(m) - except Exception: # TODO: we should be more specific here - pass - Chem.AllChem.EmbedMolecule(m) - file_name = f"files/pdb/{mol_name}.pdb" - Chem.MolToPDBFile(m, file_name) - # add to path registry - if path_registry: - _ = path_registry.map_path( - mol_name, file_name, f"pdb file for the small molecule {mol_name}" - ) - return ( - f"PDB file for {mol_str} successfully created and saved to {file_name}." - ) - except Exception: # TODO: we should be more specific here - print( - "There was an error getting pdb. Please input a single molecule name." - f"{mol_str},{mol_name}, {smi}" - ) - return ( - "There was an error getting pdb. Please input a single molecule name." - ) - - -class SmallMolPDB(BaseTool): - name = "SmallMoleculePDB" - description = ( - "Creates a PDB file for a small molecule" - "Use this tool when you need to use a small molecule in a simulation." - "Input can be a molecule name or a SMILES string." - ) - path_registry: Optional[PathRegistry] - - def __init__(self, path_registry: Optional[PathRegistry]): - super().__init__() - self.path_registry = path_registry - - def _run(self, mol_str: str) -> str: - """use the tool.""" - mol_pdb = MolPDB() - output = mol_pdb.small_molecule_pdb(mol_str, self.path_registry) - return output diff --git a/mdagent/tools/base_tools/simulation_tools/__init__.py b/mdagent/tools/base_tools/simulation_tools/__init__.py index 913f1a4a..af0d099d 100644 --- a/mdagent/tools/base_tools/simulation_tools/__init__.py +++ b/mdagent/tools/base_tools/simulation_tools/__init__.py @@ -1,14 +1,8 @@ from .create_simulation import ModifyBaseSimulationScriptTool -from .setup_and_run import ( - InstructionSummary, - SetUpandRunFunction, - SetUpAndRunTool, - SimulationFunctions, -) +from .setup_and_run import SetUpandRunFunction, SetUpAndRunTool, SimulationFunctions __all__ = [ "SetUpAndRunTool", - "InstructionSummary", "SimulationFunctions", "ModifyBaseSimulationScriptTool", "SetUpandRunFunction", diff --git a/mdagent/tools/base_tools/simulation_tools/create_simulation.py b/mdagent/tools/base_tools/simulation_tools/create_simulation.py index 42a32267..639b89c0 100644 --- a/mdagent/tools/base_tools/simulation_tools/create_simulation.py +++ b/mdagent/tools/base_tools/simulation_tools/create_simulation.py @@ -17,79 +17,8 @@ class ModifyScriptUtils: def __init__(self, llm): self.llm = llm - Examples = [ - """ -from openmm.app import * -from openmm import * -from openmm.unit import * -from sys import stdout - -pdb = PDBFile("1AKI.pdb") - -#We need to define the forcefield we want to use. -#We will use the Amber14 forcefield and the TIP3P-FB water model. - -# Specify the forcefield -forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml') - -#This PDB file contains some crystal water molecules which we want to strip out. -#This can be done using the Modeller class. We also add in any missing H atoms. -modeller = Modeller(pdb.topology, pdb.positions) -modeller.deleteWater() -residues=modeller.addHydrogens(forcefield) - -#We can use the addSolvent method to add water molecules -modeller.addSolvent(forcefield, padding=1.0*nanometer) - -#We now need to combine our molecular topology and the forcefield -#to create a complete description of the system. This is done using -# the ForceField object’s createSystem() function. We then create the integrator, -# and combine the integrator and system to create the Simulation object. -# Finally we set the initial atomic positions. - -system = forcefield.createSystem(modeller.topology, nonbondedMethod=PME, -nonbondedCutoff=1.0*nanometer, constraints=HBonds) -integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.004*picoseconds) -simulation = Simulation(modeller.topology, system, integrator) -simulation.context.setPositions(modeller.positions) - -#It is a good idea to run local energy minimization at the start of a simulation, -# since the coordinates in the PDB file might produce very large forces - -print("Minimizing energy") -simulation.minimizeEnergy() - -#To get output from our simulation we can add reporters. -# We use PDBReporter to write the coorinates every 1000 timesteps -# to “output.pdb” and we use StateDataReporter to print the timestep, -# potential energy, temperature, and volume to the screen and to -# a file called “md_log.txt”. - -simulation.reporters.append(PDBReporter('output.pdb', 1000)) -simulation.reporters.append(StateDataReporter(stdout, 1000, step=True, - potentialEnergy=True, temperature=True, volume=True)) -simulation.reporters.append(StateDataReporter("md_log.txt", 100, step=True, - potentialEnergy=True, temperature=True, volume=True)) - -#We are using a Langevin integrator which means we are simulating in the NVT ensemble. -# To equilibrate the temperature we just need to run the -# simulation for a number of timesteps. -print("Running NVT") -simulation.step(10000) - -#To run our simulation in the NPT ensemble we -# need to add in a barostat to control the pressure. We can use MonteCarloBarostat -system.addForce(MonteCarloBarostat(1*bar, 300*kelvin)) -simulation.context.reinitialize(preserveState=True) - - -print("Running NPT") -simulation.step(10000) - """ - ] - - def _prompt_summary(self, query: str, llm: BaseLanguageModel = None): - if not llm: + def _prompt_summary(self, query: str): + if not self.llm: raise ValueError("No language model provided at ModifyScriptTool") prompt_template = ( @@ -120,7 +49,7 @@ def _prompt_summary(self, query: str, llm: BaseLanguageModel = None): prompt = PromptTemplate( template=prompt_template, input_variables=["base_script", "query"] ) - llm_chain = LLMChain(prompt=prompt, llm=llm) + llm_chain = LLMChain(prompt=prompt, llm=self.llm) return llm_chain.invoke(query) @@ -161,15 +90,13 @@ def __init__(self, path_registry: Optional[PathRegistry], llm): self.llm = llm def _run(self, *args, **input): - if self.llm is None: # this should not happen - print("No language model provided at ModifyScriptTool") - return "llm not initialized" if len(args) > 0: return ( "This tool expects you to provide the input as a " "dictionary: {'query': 'your query', 'script': 'script id'}" ) - + if not self.path_registry: + return "No path registry provided" # this should not happen base_script_id = input.get("script") if not base_script_id: return "No id provided. The keys for the input are: " "query' and 'script'" @@ -187,7 +114,7 @@ def _run(self, *args, **input): description = input.get("query") answer = utils._prompt_summary( - query={"base_script": base_script, "query": description}, llm=self.llm + query={"base_script": base_script, "query": description} ) script = answer["text"] thoughts, new_script = script.split("SCRIPT:") diff --git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py index 808d9ca1..48a54012 100644 --- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py +++ b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py @@ -10,7 +10,6 @@ import langchain import streamlit as st -from langchain.base_language import BaseLanguageModel from langchain.chains import LLMChain from langchain.prompts import PromptTemplate from langchain.tools import BaseTool @@ -110,9 +109,27 @@ class SimulationFunctions: - llm = langchain.chat_models.ChatOpenAI( - temperature=0.05, model_name="gpt-4", request_timeout=1000, max_tokens=2000 - ) + def __init__( + self, + path_registry, + temperature: float = 0.05, + model_name: str = "gpt-4", + request_timeout: int = 1000, + max_tokens: int = 2000, + ): + self.path_registry = path_registry + self.temperature = temperature + self.model_name = model_name + self.request_timeout = request_timeout + self.max_tokens = max_tokens + + self.llm = langchain.chat_models.ChatOpenAI( + temperature=self.temperature, + model_name=self.model_name, + request_timeout=self.request_timeout, + max_tokens=self.request_timeout, + ) + #######==================System Congifuration==================######## # System Configuration initialization. @@ -187,7 +204,7 @@ def _define_integrator( return integrator - def _prompt_summary(self, query: str, llm: BaseLanguageModel = llm): + def _prompt_summary(self, query: str): prompt_template = """Your input is the original query. Your task is to parse through the user query. and provide a summary of the file path input, @@ -252,11 +269,11 @@ def _prompt_summary(self, query: str, llm: BaseLanguageModel = llm): you may fill in with the default, but explicitly state so. Here is the information:{query}""" prompt = PromptTemplate(template=prompt_template, input_variables=["query"]) - llm_chain = LLMChain(prompt=prompt, llm=llm) + llm_chain = LLMChain(prompt=prompt, llm=self.llm) return llm_chain.run(" ".join(query)) - def _save_to_file(self, summary: str, filename: str, PathRegistry): + def _save_to_file(self, summary: str, filename: str): """Parse the summary string and save it to a file in JSON format.""" # Split the summary into lines @@ -274,11 +291,11 @@ def _save_to_file(self, summary: str, filename: str, PathRegistry): # add filename to registry file_description = "Simulation Parameters" - PathRegistry.map_path(filename, filename, file_description) + self.path_registry.map_path(filename, filename, file_description) - def _instruction_summary(self, query: str, PathRegistry): + def _instruction_summary(self, query: str): summary = self._prompt_summary(query) - self._save_to_file(summary, "simulation_parameters.json", PathRegistry) + self._save_to_file(summary, "simulation_parameters.json") return summary def _setup_simulation_from_json(self, file_name): @@ -287,7 +304,7 @@ def _setup_simulation_from_json(self, file_name): params = json.load(f) return params - def _setup_and_run_simulation(self, query, PathRegistry): + def _setup_and_run_simulation(self, query): # Load the force field # ask for inputs from the user params = self._setup_simulation_from_json(query) @@ -325,8 +342,8 @@ def _setup_and_run_simulation(self, query, PathRegistry): # adding forcefield to registry # Load the PDB file - cleantools = CleaningTools() - pdbfile = cleantools._extract_path(params["File Path"]) + CleaningTools(self.path_registry) + pdbfile = self.path_registry.get_mapped_path(params["File Path"]) name = pdbfile.split(".")[0] end = pdbfile.split(".")[1] if end == "pdb": @@ -430,12 +447,12 @@ def _setup_and_run_simulation(self, query, PathRegistry): # add filenames to registry file_name1 = "simulation_trajectory.pdb" file_description1 = "Simulation PDB, containing the simulation trajectory" - PathRegistry.map_path(file_name1, f"{name}.pdb", file_description1) + self.path_registry.map_path(file_name1, f"{name}.pdb", file_description1) file_name2 = "simulation_data.csv" file_description2 = ( "Simulation Data, containing step, potential energy, and temperature" ) - PathRegistry.map_path(file_name2, f"{name}.csv", file_description2) + self.path_registry.map_path(file_name2, f"{name}.csv", file_description2) return simulation @@ -456,8 +473,7 @@ def _extract_parameters_path(self): class SetUpAndRunTool(BaseTool): name = "SetUpAndRunTool" - description = """This tool can only run after InstructionSummary - This tool will set up the simulation objects + description = """This tool will set up the simulation objects and run the simulation. It will ask for the parameters path. input: json file @@ -477,7 +493,7 @@ def _run(self, query: str) -> str: try: if self.path_registry is None: # this should not happen return "Registry not initialized" - sim_fxns = SimulationFunctions() + sim_fxns = SimulationFunctions(path_registry=self.path_registry) parameters = sim_fxns._extract_parameters_path() except ValueError as e: @@ -497,7 +513,7 @@ def _run(self, query: str) -> str: self.log("Are you sure you want to run the simulation? (y/n)") response = input("yes or no: ") if response.lower() in ["yes", "y"]: - sim_fxns._setup_and_run_simulation(parameters, self.path_registry) + sim_fxns._setup_and_run_simulation(parameters) else: return "Simulation interrupted due to human input" return "Simulation Completed, simulation trajectory and data files saved." @@ -513,51 +529,6 @@ async def _arun(self, query: str) -> str: raise NotImplementedError("custom_search does not support async") -class InstructionSummary(BaseTool): - name = "Instruction Summary" - description = """This tool will summarize the instructions - given by the human. This is the first tool you will - use, unless you dont have a .cif or .pdb file in - which case you have to download one first. - Input: Instructions or original query. - Output: Summary of instructions""" - path_registry: Optional[PathRegistry] - - def __init__( - self, - path_registry: Optional[PathRegistry], - ): - super().__init__() - self.path_registry = path_registry - - def _run(self, query: str) -> str: - # first check if there is any .cif or .pdb files in the directory - # if there is, then ask for instructions - if self.path_registry is None: # this should not happen - return "Registry not initialized" - files = os.listdir(".") - pdb_cif_files = [f for f in files if f.endswith(".pdb") or f.endswith(".cif")] - pdb_cif_files_tidy = [ - f - for f in files - if (f.endswith(".pdb") or f.endswith(".cif")) and "tidy" in f - ] - if len(pdb_cif_files_tidy) != 0: - path = pdb_cif_files_tidy[0] - else: - path = pdb_cif_files[0] - sim_fxns = SimulationFunctions() - summary = sim_fxns._prompt_summary(query + "the pdbfile is" + path) - sim_fxns._save_to_file( - summary, "simulation_parameters_summary.json", self.path_registry - ) - return summary - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - #######==================System Configuration==================######## # System Configuration class SetUpandRunFunctionInput(BaseModel): @@ -1747,39 +1718,6 @@ async def _arun(self, query: str) -> str: raise NotImplementedError("custom_search does not support async") -########==================Integrator==================######## -# TODO integrate this functions into the OPENMMsimulation class -# Integrator -def _define_integrator( - integrator_type="LangevinMiddle", - temperature=300 * kelvin, - friction=1.0 / picoseconds, - timestep=0.004 * picoseconds, - **kwargs, -): - # Create a dictionary to hold integrator parameters - integrator_params = { - "temperature": temperature, - "friction": friction, - "timestep": timestep, - } - - # Update integrator_params with any additional parameters provided - integrator_params.update(kwargs) - - # Create the integrator - if integrator_type == "LangevinMiddle": - integrator = LangevinMiddleIntegrator(**integrator_params) - elif integrator_type == "Verlet": - integrator = VerletIntegrator(**integrator_params) - elif integrator_type == "Brownian": - integrator = BrownianIntegrator(**integrator_params) - else: - raise Exception("Integrator type not recognized") - - return integrator - - def create_simulation_input(pdb_path, forcefield_files): """ This function takes a PDB file path and a list of forcefield files. diff --git a/mdagent/tools/base_tools/util_tools/git_issues_tool.py b/mdagent/tools/base_tools/util_tools/git_issues_tool.py index fd7aeefa..10b8deee 100644 --- a/mdagent/tools/base_tools/util_tools/git_issues_tool.py +++ b/mdagent/tools/base_tools/util_tools/git_issues_tool.py @@ -2,7 +2,6 @@ import requests import tiktoken -from langchain.base_language import BaseLanguageModel from langchain.chains import LLMChain from langchain.prompts import PromptTemplate from langchain.tools import BaseTool @@ -14,10 +13,18 @@ class GitToolFunctions: """Class to store the functions of the tool.""" - """chain that can be used the tools for summarization or classification""" - llm_ = _make_llm(model="gpt-3.5-turbo-16k", temp=0.05, verbose=False) - - def _prompt_summary(self, query: str, output: str, llm: BaseLanguageModel = llm_): + def __init__( + self, + model: str = "gpt-3.5-turbo-16k", + temp: float = 0.05, + verbose: bool = False, + ): + self.model = model + self.temp = temp + self.verbose = verbose + self.llm = _make_llm(model=self.model, temp=self.temp, verbose=self.verbose) + + def _prompt_summary(self, query: str, output: str): prompt_template = """You're receiving the following github issues and comments. They come after looking for issues in the openmm repo for the query: {query}. @@ -47,7 +54,7 @@ def _prompt_summary(self, query: str, output: str, llm: BaseLanguageModel = llm_ prompt = PromptTemplate( template=prompt_template, input_variables=["query", "output"] ) - llm_chain = LLMChain(prompt=prompt, llm=llm) + llm_chain = LLMChain(prompt=prompt, llm=self.llm) return llm_chain.run({"query": query, "output": output}) diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index 15933aed..14511a44 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -9,14 +9,12 @@ from langchain.embeddings.openai import OpenAIEmbeddings from langchain.tools import BaseTool, StructuredTool from langchain.vectorstores import Chroma -from langchain_experimental.tools import PythonREPLTool from pydantic import BaseModel, Field from mdagent.subagents import Iterator, SubAgentInitializer, SubAgentSettings from mdagent.utils import PathRegistry, _make_llm from .base_tools import ( - CheckDirectoryFiles, CleaningToolFunction, ListRegistryPaths, ModifyBaseSimulationScriptTool, @@ -67,7 +65,7 @@ def make_all_tools( path_instance = PathRegistry.get_instance() # get instance first if llm: all_tools += agents.load_tools(["llm-math"], llm) - all_tools += [PythonREPLTool()] # or PythonREPLTool(llm=llm)? + # all_tools += [PythonREPLTool()] all_tools += [ ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm) ] @@ -79,18 +77,15 @@ def make_all_tools( # add base tools base_tools = [ CleaningToolFunction(path_registry=path_instance), - CheckDirectoryFiles(), ListRegistryPaths(path_registry=path_instance), - # MapPath2Name(path_registry=path_instance), ProteinName2PDBTool(path_registry=path_instance), PackMolTool(path_registry=path_instance), SmallMolPDB(path_registry=path_instance), VisualizeProtein(path_registry=path_instance), - PPIDistance(), - RMSDCalculator(), + PPIDistance(path_registry=path_instance), + RMSDCalculator(path_registry=path_instance), SetUpandRunFunction(path_registry=path_instance), - ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm), - SimulationOutputFigures(), + SimulationOutputFigures(path_registry=path_instance), ] if subagent_settings is None: subagent_settings = SubAgentSettings(path_registry=path_instance) diff --git a/tests/test_fxns.py b/tests/test_fxns.py index 19b528e9..852de22f 100644 --- a/tests/test_fxns.py +++ b/tests/test_fxns.py @@ -12,8 +12,9 @@ VisFunctions, get_pdb, ) -from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv -from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool +from mdagent.tools.base_tools.analysis_tools.plot_tools import PlottingTools +from mdagent.tools.base_tools.preprocess_tools.packing import PackMolTool +from mdagent.tools.base_tools.preprocess_tools.pdb_get import MolPDB from mdagent.utils import FileType, PathRegistry warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -37,44 +38,46 @@ def path_to_cif(): @pytest.fixture -def cleaning_fxns(): - return CleaningTools() +def fibronectin(): + return "fibronectin pdb" @pytest.fixture -def molpdb(): - return MolPDB() +def get_registry(): + return PathRegistry() -# Test simulation tools @pytest.fixture -def sim_fxns(): - return SimulationFunctions() +def sim_fxns(get_registry): + return SimulationFunctions(get_registry) -# Test visualization tools @pytest.fixture -def vis_fxns(): - return VisFunctions() +def plotting_tools(get_registry): + return PlottingTools(get_registry) -# Test MD utility tools @pytest.fixture -def fibronectin(): - return "fibronectin pdb" +def vis_fxns(get_registry): + return VisFunctions(get_registry) @pytest.fixture -def get_registry(): - return PathRegistry() +def packmol(get_registry): + return PackMolTool(get_registry) @pytest.fixture -def packmol(get_registry): - return PackMolTool(get_registry) +def molpdb(get_registry): + return MolPDB(get_registry) -def test_process_csv(): +@pytest.fixture +def cleaning_fxns(get_registry): + return CleaningTools(get_registry) + + +def test_process_csv(plotting_tools): mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" mock_reader = MagicMock() mock_reader.fieldnames = ["Time", "Value1", "Value2"] @@ -84,19 +87,23 @@ def test_process_csv(): {"Time": "2", "Value1": "15", "Value2": "25"}, ] ) - + plotting_tools.file_path = "mock_file.csv" + plotting_tools.file_name = "mock_file.csv" with patch("builtins.open", mock_open(read_data=mock_csv_content)): with patch("csv.DictReader", return_value=mock_reader): - data, headers, matched_headers = process_csv("mock_file.csv") - - assert headers == ["Time", "Value1", "Value2"] - assert len(matched_headers) == 1 - assert matched_headers[0][1] == "Time" - assert len(data) == 2 - assert data[0]["Time"] == "1" and data[0]["Value1"] == "10" + plotting_tools.process_csv() + + assert plotting_tools.headers == ["Time", "Value1", "Value2"] + assert len(plotting_tools.matched_headers) == 1 + assert plotting_tools.matched_headers[0][1] == "Time" + assert len(plotting_tools.data) == 2 + assert ( + plotting_tools.data[0]["Time"] == "1" + and plotting_tools.data[0]["Value1"] == "10" + ) -def test_plot_data(): +def test_plot_data(plotting_tools): # Test successful plot generation data_success = [ {"Time": "1", "Value1": "10", "Value2": "20"}, @@ -112,7 +119,10 @@ def test_plot_data(): ), patch( "matplotlib.pyplot.close" ): - created_plots = plot_data(data_success, headers, matched_headers) + plotting_tools.data = data_success + plotting_tools.headers = headers + plotting_tools.matched_headers = matched_headers + created_plots = plotting_tools.plot_data() assert "time_vs_value1.png" in created_plots assert "time_vs_value2.png" in created_plots @@ -122,8 +132,12 @@ def test_plot_data(): {"Time": "2", "Value1": "C", "Value2": "D"}, ] + plotting_tools.data = data_failure + plotting_tools.headers = headers + plotting_tools.matched_headers = matched_headers + with pytest.raises(Exception) as excinfo: - plot_data(data_failure, headers, matched_headers) + plotting_tools.plot_data() assert "All plots failed due to non-numeric data." in str(excinfo.value) @@ -133,14 +147,29 @@ def test_run_molrender(path_to_cif, vis_fxns): assert result == "Visualization created" -def test_create_notebook(path_to_cif, vis_fxns, get_registry): - result = vis_fxns.create_notebook(path_to_cif, get_registry) +def test_find_png(vis_fxns): + vis_fxns.starting_files = os.listdir(".") + test_file = "test_image.png" + with open(test_file, "w") as f: + f.write("") + png_files = vis_fxns._find_png() + assert test_file in png_files + + os.remove(test_file) + + +def test_create_notebook(path_to_cif, vis_fxns): + result = vis_fxns.create_notebook(path_to_cif) + path_to_notebook = path_to_cif.split(".")[0] + "_vis.ipynb" + os.remove(path_to_notebook) assert result == "Visualization Complete" -def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry): - result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif, get_registry) - assert "Cleaned File" in result # just want to make sur the function ran +def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns): + result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif) + path_to_cleaned_file = "tidy_" + path_to_cif + os.remove(path_to_cleaned_file) + assert "Cleaned File" in result @patch("os.path.exists") @@ -302,14 +331,14 @@ def test_map_path(): assert result == "Path successfully mapped to name: new_name" -def test_small_molecule_pdb(molpdb, get_registry): +def test_small_molecule_pdb(molpdb): # Test with a valid SMILES string valid_smiles = "C1=CC=CC=C1" # Benzene expected_output = ( "PDB file for C1=CC=CC=C1 successfully created and saved to " "files/pdb/benzene.pdb." ) - assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output + assert molpdb.small_molecule_pdb(valid_smiles) == expected_output assert os.path.exists("files/pdb/benzene.pdb") os.remove("files/pdb/benzene.pdb") # Clean up @@ -319,26 +348,23 @@ def test_small_molecule_pdb(molpdb, get_registry): expected_output = ( "There was an error getting pdb. Please input a single molecule name." ) - assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output - assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output + assert molpdb.small_molecule_pdb(invalid_smiles) == expected_output + assert molpdb.small_molecule_pdb(invalid_name) == expected_output # test with valid molecule name valid_name = "water" expected_output = ( "PDB file for water successfully created and " "saved to files/pdb/water.pdb." ) - assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output + assert molpdb.small_molecule_pdb(valid_name) == expected_output assert os.path.exists("files/pdb/water.pdb") os.remove("files/pdb/water.pdb") # Clean up def test_packmol_sm_download_called(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - path_registry._remove_path_from_json("benzene") - path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") + packmol.path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") with patch( - "mdagent.tools.base_tools.preprocess_tools.pdb_tools.PackMolTool._get_sm_pdbs", + "mdagent.tools.base_tools.preprocess_tools.packing.PackMolTool._get_sm_pdbs", new=MagicMock(), ) as mock_get_sm_pdbs: test_values = { @@ -358,9 +384,8 @@ def test_packmol_sm_download_called(packmol): def test_packmol_download_only(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - path_registry._remove_path_from_json("benzene") + packmol.path_registry._remove_path_from_json("water") + packmol.path_registry._remove_path_from_json("benzene") small_molecules = ["water", "benzene"] packmol._get_sm_pdbs(small_molecules) assert os.path.exists("files/pdb/water.pdb") @@ -370,8 +395,7 @@ def test_packmol_download_only(packmol): def test_packmol_download_only_once(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") + packmol.path_registry._remove_path_from_json("water") small_molecules = ["water"] packmol._get_sm_pdbs(small_molecules) assert os.path.exists("files/pdb/water.pdb")