Skip to content

Commit

Permalink
Improving path registry (#66)
Browse files Browse the repository at this point in the history
* adding path registry utils and change in get pdb: write_file_name, map_path, getid

* Change Clean/scripting/download file tools to use path_registry

* added tests to writefilename and map path
  • Loading branch information
Jgmedina95 authored Jan 10, 2024
1 parent 16b69dc commit 501bd82
Show file tree
Hide file tree
Showing 11 changed files with 2,244 additions and 66 deletions.
7 changes: 6 additions & 1 deletion mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import csv
import re
from typing import Optional

import matplotlib.pyplot as plt
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry


def process_csv(file_name):
with open(file_name, "r") as f:
Expand Down Expand Up @@ -64,13 +67,15 @@ def plot_data(data, headers, matched_headers):
class SimulationOutputFigures(BaseTool):
name = "PostSimulationFigures"
description = """This tool will take
a csv file output from an openmm
a csv file id output from an openmm
simulation and create figures for
all physical parameters
versus timestep of the simulation.
Give this tool the path to the
csv file output from the simulation."""

path_registry: Optional[PathRegistry]

def _run(self, file_path: str) -> str:
"""use the tool."""
try:
Expand Down
72 changes: 47 additions & 25 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, root_validator

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


class CleaningTools:
Expand Down Expand Up @@ -226,7 +226,7 @@ async def _arun(self, query: str) -> str:
class CleaningToolFunctionInput(BaseModel):
"""Input model for CleaningToolFunction"""

pdb_path: str = Field(..., description="Path to PDB or CIF file")
pdb_id: str = Field(..., description="ID of the pdb/cif file in the path registry")
output_path: Optional[str] = Field(..., description="Path to the output file")
replace_nonstandard_residues: bool = Field(
True, description="Whether to replace nonstandard residues with standard ones. "
Expand Down Expand Up @@ -277,10 +277,10 @@ def _run(self, **input_args) -> str:
input_args = input_args["input_args"]
else:
input_args = input_args
pdbfile_path = input_args.get("pdb_path", None)
if pdbfile_path is None:
return """No file path provided.
The input has to be a dictionary with the key 'pdb_path'"""
pdbfile_id = input_args.get("pdb_id", None)
if pdbfile_id is None:
return """No file was provided.
The input has to be a dictionary with the key 'pdb_id'"""
remove_heterogens = input_args.get("remove_heterogens", True)
remove_water = input_args.get("remove_water", True)
add_hydrogens = input_args.get("add_hydrogens", True)
Expand All @@ -289,17 +289,23 @@ def _run(self, **input_args) -> str:
"replace_nonstandard_residues", True
)
add_missing_atoms = input_args.get("add_missing_atoms", True)
output_path = input_args.get("output_path", None)
input_args.get("output_path", None)

if self.path_registry is None:
return "Path registry not initialized"
file_description = "Cleaned File: "
clean_tools = CleaningTools()
pdbfile = clean_tools._extract_path(pdbfile_path, self.path_registry)
name = pdbfile.split(".")[0]
end = pdbfile.split(".")[1]
CleaningTools()
try:
pdbfile = self.path_registry.get_mapped_path(pdbfile_id)
if "/" in pdbfile:
pdbfile_name = pdbfile.split("/")[-1]
name = pdbfile_name.split("_")[0]
end = pdbfile_name.split(".")[1]
print(f"pdbfile: {pdbfile}", f"name: {name}", f"end: {end}")
except Exception as e:
print(f"error retrieving from path_registry, trying to read file {e}")
return "File not found in path registry. "
fixer = PDBFixer(filename=pdbfile)

try:
fixer.findMissingResidues()
except Exception:
Expand All @@ -321,6 +327,7 @@ def _run(self, **input_args) -> str:
try:
if replace_nonstandard_residues:
fixer.replaceNonstandardResidues()
file_description += " Replaced Nonstandard Residues. "
except Exception:
print("error at replaceNonstandardResidues")
try:
Expand All @@ -343,26 +350,41 @@ def _run(self, **input_args) -> str:
"Missing Atoms Added and replaces nonstandard residues. "
)
file_mode = "w" if add_hydrogens else "a"
if output_path:
file_name = output_path
else:
version = 1
while os.path.exists(f"tidy_{name}v{version}.{end}"):
version += 1

file_name = f"tidy_{name}v{version}.{end}"

file_name = self.path_registry.write_file_name(
type=FileType.PROTEIN,
protein_name=name,
description="Clean",
file_format=end,
)
file_id = self.path_registry.get_fileid(file_name, FileType.PROTEIN)
# if output_path:
# file_name = output_path
# else:
# version = 1
# while os.path.exists(f"tidy_{name}v{version}.{end}"):
# version += 1
#
# file_name = f"tidy_{name}v{version}.{end}"
directory = "files/pdb"
if not os.path.exists(directory):
os.makedirs(directory)
if end == "pdb":
PDBFile.writeFile(
fixer.topology, fixer.positions, open(file_name, file_mode)
fixer.topology,
fixer.positions,
open(f"{directory}/{file_name}", file_mode),
)
elif end == "cif":
PDBxFile.writeFile(
fixer.topology, fixer.positions, open(file_name, file_mode)
fixer.topology,
fixer.positions,
open(f"{directory}/{file_name}", file_mode),
)

self.path_registry.map_path(file_name, file_name, file_description)
return f"{file_description} written to {file_name}"
self.path_registry.map_path(
file_id, f"{directory}/{file_name}", file_description
)
return f"{file_id} written to {directory}/{file_name}"
except FileNotFoundError:
return "Check your file path. File not found."
except Exception as e:
Expand Down
34 changes: 24 additions & 10 deletions mdagent/tools/base_tools/preprocess_tools/pdb_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, ValidationError, root_validator

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


def get_pdb(query_string, path_registry=None):
Expand Down Expand Up @@ -41,13 +41,22 @@ def get_pdb(query_string, path_registry=None):
print(f"PDB file found with this ID: {pdbid}")
url = f"https://files.rcsb.org/download/{pdbid}.{filetype}"
pdb = requests.get(url)
filename = f"{pdbid}.{filetype}"
with open(filename, "w") as file:
filename = path_registry.write_file_name(
FileType.PROTEIN,
protein_name=pdbid,
description="raw",
file_format=filetype,
)
file_id = path_registry.get_fileid(filename, FileType.PROTEIN)
directory = "files/pdb"
# Create the directory if it does not exist
if not os.path.exists(directory):
os.makedirs(directory)

with open(f"{directory}/{filename}", "w") as file:
file.write(pdb.text)
print(f"{filename} is created.")
file_description = f"PDB file downloaded from RSCB, PDB ID: {pdbid}"
path_registry.map_path(filename, filename, file_description)
return filename

return filename, file_id
return None


Expand All @@ -73,11 +82,16 @@ def _run(self, query: str) -> str:
try:
if self.path_registry is None: # this should not happen
return "Path registry not initialized"
pdb = get_pdb(query, self.path_registry)
if pdb is None:
filename, pdbfile_id = get_pdb(query, self.path_registry)
if pdbfile_id is None:
return "Name2PDB tool failed to find and download PDB file."
else:
return f"Name2PDB tool successfully downloaded the PDB file: {pdb}"
self.path_registry.map_path(
pdbfile_id,
f"files/pdb/{filename}",
f"PDB file downloaded from RSCB, PDBFile ID: {pdbfile_id}",
)
return f"Name2PDB tool successful. downloaded the PDB file:{pdbfile_id}"
except Exception as e:
return f"Something went wrong. {e}"

Expand Down
35 changes: 27 additions & 8 deletions mdagent/tools/base_tools/simulation_tools/create_simulation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import textwrap
from typing import Optional

Expand Down Expand Up @@ -131,8 +132,13 @@ def remove_leading_spaces(self, text):


class ModifyScriptInput(BaseModel):
query: str = Field(..., description="Simmulation required by the user")
script: str = Field(..., description=" path to the base script file")
query: str = Field(
...,
description="""Simmulation required by the user.You MUST
specify the objective, requirements of the simulation as well
as on what protein you are working.""",
)
script: str = Field(..., description=" simulation ID of the base script file")


class ModifyBaseSimulationScriptTool(BaseTool):
Expand All @@ -150,10 +156,17 @@ def __init__(self, path_registry: Optional[PathRegistry], llm: BaseLanguageModel
self.llm = llm

def _run(self, **input):
base_script_path = input.get("script")
if not base_script_path:
return """No script provided. The keys for the input are:
base_script_id = input.get("script")
if not base_script_id:
return """No id provided. The keys for the input are:
'query' and 'script'"""
try:
base_script_path = self.path_registry.get_mapped_path(base_script_id)
parts = base_script_path.split("/")
if len(parts) > 1:
parts[-1]
except Exception as e:
return f"Error getting path from file id: {e}"
with open(base_script_path, "r") as file:
base_script = file.read()
base_script = "".join(base_script)
Expand All @@ -172,11 +185,17 @@ def _run(self, **input):
script_content = script_content.replace("```", "#")
script_content = textwrap.dedent(script_content).strip()
# Write to file
filename = "modified_simul.py"
with open(filename, "w") as file:
filename = self.path_registry.write_file_name(
type="SIMULATION", Sim_id=base_script_id, modified=True
)
file_id = self.path_registry.get_fileid(filename, type="SIMULATION")
directory = "files/simulations"
if not os.path.exists(directory):
os.makedirs(directory)
with open(f"{directory}/{filename}", "w") as file:
file.write(script_content)

self.path_registry.map_path(filename, filename, description)
self.path_registry.map_path(file_id, filename, description)
return "Script modified successfully"

async def _arun(self, query) -> str:
Expand Down
Loading

0 comments on commit 501bd82

Please sign in to comment.