Skip to content

Commit

Permalink
fix test_setup + fix rdg tool init
Browse files Browse the repository at this point in the history
  • Loading branch information
Jgmedina95 committed Mar 18, 2024
1 parent 4fc8d16 commit 31863d1
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 15 deletions.
9 changes: 4 additions & 5 deletions mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ class RDFToolInput(BaseModel):
description="Selections for RDF. Do not use for now. As "
"it will only calculate RDF for protein and water molecules.",
)
# atom_indices: Optional[List[int]] = Field(
# None, description="Atom indices to load in the trajectory"
# )

# TODO: Add pairs of atoms to calculate RDF within the tool
##pairs: Optional[str] = Field(None, description="Pairs of atoms to calculate RDF ")

Expand All @@ -45,8 +43,9 @@ class RDFTool(BaseTool):
args_schema = RDFToolInput
path_registry: Optional[PathRegistry]

# def __init__(self, path_registry: PathRegistry):
# self.path_registry = path_registry
def __init__(self, path_registry: PathRegistry):
super().__init__()
self.path_registry = path_registry

def _run(self, input):
try:
Expand Down
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Dict, Optional, Type
from typing import Optional, Type

from langchain.tools import BaseTool
from openmm.app import PDBFile, PDBxFile
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field

from mdagent.utils import FileType, PathRegistry

Expand Down Expand Up @@ -227,12 +227,6 @@ class CleaningToolFunctionInput(BaseModel):
)
add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.")

@root_validator(skip_on_failure=True)
def validate_query(cls, values) -> Dict:
"""Check that the input is valid."""

return values


class CleaningToolFunction(BaseTool):
name = "CleaningToolFunction"
Expand Down
1 change: 0 additions & 1 deletion mdagent/tools/base_tools/preprocess_tools/pdb_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,6 @@ class PDBFilesFixInp(BaseModel):
@root_validator(skip_on_failure=True)
def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
if isinstance(values, str):
print("values is a string", values)
raise ValidationError("Input must be a dictionary")

pdbfile = values.get("pdbfiles", "")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def setupandrun(get_registry):
return SetUpandRunFunction(get_registry("raw", False))
return SetUpandRunFunction(path_registry=get_registry("raw", False))


def test_parse_cutoff(setupandrun):
Expand Down

0 comments on commit 31863d1

Please sign in to comment.