diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index 7233b199..00e8a457 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -471,10 +471,11 @@ def __init__(self, path_registry: typing.Optional[PathRegistry]): self.path_registry = path_registry def _get_sm_pdbs(self, small_molecules): + all_files = self.path_registry.list_path_names() + print(all_files) for molecule in small_molecules: # check path registry for molecule.pdb - exists = self.path_registry._check_json_content(molecule) - if not exists: + if molecule not in all_files: # download molecule using small_molecule_pdb from MolPDB molpdb = MolPDB() molpdb.small_molecule_pdb(molecule, self.path_registry) @@ -1575,7 +1576,7 @@ def smiles2name(self, smi: str) -> str: return "Unknown Molecule" return name - def small_molecule_pdb(self, mol_str: str, path_registry=None) -> str: + def small_molecule_pdb(self, mol_str: str, path_registry) -> str: # takes in molecule name or smiles (converts to smiles if name) # writes pdb file name.pdb (gets name from smiles if possible) # output is done message @@ -1598,7 +1599,7 @@ def small_molecule_pdb(self, mol_str: str, path_registry=None) -> str: Chem.MolToPDBFile(m, file_name) # add to path registry if path_registry: - _ = path_registry.map_math( + _ = path_registry.map_path( file_name, file_name, f"pdb file for the small molecule {mol_name}" ) return ( diff --git a/tests/test_fxns.py b/tests/test_fxns.py index d4427749..92629d0d 100644 --- a/tests/test_fxns.py +++ b/tests/test_fxns.py @@ -1,5 +1,6 @@ import json import os +import time import warnings from unittest.mock import MagicMock, mock_open, patch @@ -294,13 +295,13 @@ def test_map_path(): assert result == "Path successfully mapped to name: new_name" -def test_small_molecule_pdb(molpdb): +def test_small_molecule_pdb(molpdb, get_registry): # Test with a valid SMILES string valid_smiles = "C1=CC=CC=C1" # Benzene expected_output = ( "PDB file for C1=CC=CC=C1 successfully created and saved to benzene.pdb." ) - assert molpdb.small_molecule_pdb(valid_smiles) == expected_output + assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output assert os.path.exists("benzene.pdb") os.remove("benzene.pdb") # Clean up @@ -310,13 +311,13 @@ def test_small_molecule_pdb(molpdb): expected_output = ( "There was an error getting pdb. Please input a single molecule name." ) - assert molpdb.small_molecule_pdb(invalid_smiles) == expected_output - assert molpdb.small_molecule_pdb(invalid_name) == expected_output + assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output + assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output # test with valid molecule name valid_name = "water" expected_output = "PDB file for water successfully created and saved to water.pdb." - assert molpdb.small_molecule_pdb(valid_name) == expected_output + assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output assert os.path.exists("water.pdb") os.remove("water.pdb") # Clean up @@ -342,9 +343,30 @@ def test_packmol_sm_download_called(packmol): def test_packmol_download_only(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water.pdb") + path_registry._remove_path_from_json("benzene.pdb") small_molecules = ["water", "benzene"] packmol._get_sm_pdbs(small_molecules) assert os.path.exists("water.pdb") assert os.path.exists("benzene.pdb") os.remove("water.pdb") os.remove("benzene.pdb") + + +def test_packmol_download_only_once(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water.pdb") + small_molecules = ["water"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("water.pdb") + water_time = os.path.getmtime("water.pdb") + time.sleep(5) + + # Call the function again with the same molecule + packmol._get_sm_pdbs(small_molecules) + water_time_after = os.path.getmtime("water.pdb") + + assert water_time == water_time_after + # Clean up + os.remove("water.pdb")