From 8a19fed6bf4edfeea082f1ac74b0cf57b0580c12 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 21 Feb 2024 13:04:50 -0800 Subject: [PATCH 01/12] split test files up --- tests/test_sims_and_clean.py | 186 +++++++++++++++++++++++++++++++++++ tests/test_tools.py | 118 ++++++++++++++++++++++ tests/test_utils.py | 155 +++++++++++++++++++++++++++++ 3 files changed, 459 insertions(+) create mode 100644 tests/test_sims_and_clean.py create mode 100644 tests/test_tools.py create mode 100644 tests/test_utils.py diff --git a/tests/test_sims_and_clean.py b/tests/test_sims_and_clean.py new file mode 100644 index 00000000..d4d55f0c --- /dev/null +++ b/tests/test_sims_and_clean.py @@ -0,0 +1,186 @@ +import os +import time +import warnings +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from mdagent.tools.base_tools import ( + CleaningTools, + SimulationFunctions, +) +from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool +from mdagent.utils import PathRegistry + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") + + +@pytest.fixture +def path_to_cif(): + # Save original working directory + original_cwd = os.getcwd() + + # Change current working directory to the directory where the CIF file is located + tests_dir = os.path.dirname(os.path.abspath(__file__)) + os.chdir(tests_dir) + + # Yield the filename only + filename_only = "3pqr.cif" + yield filename_only + + # Restore original working directory after the test is done + os.chdir(original_cwd) + + +@pytest.fixture +def cleaning_fxns(): + return CleaningTools() + + +@pytest.fixture +def molpdb(): + return MolPDB() + + +# Test simulation tools +@pytest.fixture +def sim_fxns(): + return SimulationFunctions() + + +@pytest.fixture +def get_registry(): + return PathRegistry() + + +@pytest.fixture +def packmol(get_registry): + return PackMolTool(get_registry) + + + +def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry): + result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif, get_registry) + assert "Cleaned File" in result # just want to make sur the function ran + + +@patch("os.path.exists") +@patch("os.listdir") +def test_extract_parameters_path(mock_listdir, mock_exists, sim_fxns): + # Test when parameters.json exists + mock_exists.return_value = True + assert sim_fxns._extract_parameters_path() == "simulation_parameters_summary.json" + mock_exists.assert_called_once_with("simulation_parameters_summary.json") + mock_exists.reset_mock() # Reset the mock for the next scenario + + # Test when parameters.json does not exist, but some_parameters.json does + mock_exists.return_value = False + mock_listdir.return_value = ["some_parameters.json", "other_file.txt"] + assert sim_fxns._extract_parameters_path() == "some_parameters.json" + + # Test when no appropriate file exists + mock_listdir.return_value = ["other_file.json", "other_file.txt"] + with pytest.raises(ValueError) as e: + sim_fxns._extract_parameters_path() + assert str(e.value) == "No parameters.json file found in directory." + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data='{"param1": "value1", "param2": "value2"}', +) +@patch("json.load") +def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns): + # Define the mock behavior for json.load + mock_json_load.return_value = {"param1": "value1", "param2": "value2"} + params = sim_fxns._setup_simulation_from_json("test_file.json") + mock_file_open.assert_called_once_with("test_file.json", "r") + mock_json_load.assert_called_once() + assert params == {"param1": "value1", "param2": "value2"} + + +def test_small_molecule_pdb(molpdb, get_registry): + # Test with a valid SMILES string + valid_smiles = "C1=CC=CC=C1" # Benzene + expected_output = ( + "PDB file for C1=CC=CC=C1 successfully created and saved to " + "files/pdb/benzene.pdb." + ) + assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output + assert os.path.exists("files/pdb/benzene.pdb") + os.remove("files/pdb/benzene.pdb") # Clean up + + # test with invalid SMILES string and invalid molecule name + invalid_smiles = "C1=CC=CC=C1X" + invalid_name = "NotAMolecule" + expected_output = ( + "There was an error getting pdb. Please input a single molecule name." + ) + assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output + assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output + + # test with valid molecule name + valid_name = "water" + expected_output = ( + "PDB file for water successfully created and " "saved to files/pdb/water.pdb." + ) + assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output + assert os.path.exists("files/pdb/water.pdb") + os.remove("files/pdb/water.pdb") # Clean up + + +def test_packmol_sm_download_called(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + path_registry._remove_path_from_json("benzene") + path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") + with patch( + "mdagent.tools.base_tools.preprocess_tools.pdb_tools.PackMolTool._get_sm_pdbs", + new=MagicMock(), + ) as mock_get_sm_pdbs: + test_values = { + "pdbfiles_id": ["1A3N_144150"], + "small_molecules": ["water", "benzene"], + "number_of_molecules": [1, 10, 10], + "instructions": [ + ["inside box 0. 0. 0. 100. 100. 100."], + ["inside box 0. 0. 0. 100. 100. 100."], + ["inside box 0. 0. 0. 100. 100. 100."], + ], + } + + packmol._run(**test_values) + + mock_get_sm_pdbs.assert_called_with(["water", "benzene"]) + + +def test_packmol_download_only(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + path_registry._remove_path_from_json("benzene") + small_molecules = ["water", "benzene"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("files/pdb/water.pdb") + assert os.path.exists("files/pdb/benzene.pdb") + os.remove("files/pdb/water.pdb") + os.remove("files/pdb/benzene.pdb") + + +def test_packmol_download_only_once(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + small_molecules = ["water"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("files/pdb/water.pdb") + water_time = os.path.getmtime("files/pdb/water.pdb") + time.sleep(5) + + # Call the function again with the same molecule + packmol._get_sm_pdbs(small_molecules) + water_time_after = os.path.getmtime("files/pdb/water.pdb") + + assert water_time == water_time_after + # Clean up + os.remove("files/pdb/water.pdb") + diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..2347b332 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,118 @@ +import os +import warnings +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from mdagent.tools.base_tools import ( + VisFunctions, + get_pdb, +) +from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv +from mdagent.utils import PathRegistry + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") + + +@pytest.fixture +def path_to_cif(): + # Save original working directory + original_cwd = os.getcwd() + + # Change current working directory to the directory where the CIF file is located + tests_dir = os.path.dirname(os.path.abspath(__file__)) + os.chdir(tests_dir) + + # Yield the filename only + filename_only = "3pqr.cif" + yield filename_only + + # Restore original working directory after the test is done + os.chdir(original_cwd) + + +# Test visualization tools +@pytest.fixture +def vis_fxns(): + return VisFunctions() + + +# Test MD utility tools +@pytest.fixture +def fibronectin(): + return "fibronectin pdb" + + +@pytest.fixture +def get_registry(): + return PathRegistry() + + +def test_process_csv(): + mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" + mock_reader = MagicMock() + mock_reader.fieldnames = ["Time", "Value1", "Value2"] + mock_reader.__iter__.return_value = iter( + [ + {"Time": "1", "Value1": "10", "Value2": "20"}, + {"Time": "2", "Value1": "15", "Value2": "25"}, + ] + ) + + with patch("builtins.open", mock_open(read_data=mock_csv_content)): + with patch("csv.DictReader", return_value=mock_reader): + data, headers, matched_headers = process_csv("mock_file.csv") + + assert headers == ["Time", "Value1", "Value2"] + assert len(matched_headers) == 1 + assert matched_headers[0][1] == "Time" + assert len(data) == 2 + assert data[0]["Time"] == "1" and data[0]["Value1"] == "10" + + +def test_plot_data(): + # Test successful plot generation + data_success = [ + {"Time": "1", "Value1": "10", "Value2": "20"}, + {"Time": "2", "Value1": "15", "Value2": "25"}, + ] + headers = ["Time", "Value1", "Value2"] + matched_headers = [(0, "Time")] + + with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch( + "matplotlib.pyplot.xlabel" + ), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch( + "matplotlib.pyplot.savefig" + ), patch( + "matplotlib.pyplot.close" + ): + created_plots = plot_data(data_success, headers, matched_headers) + assert "time_vs_value1.png" in created_plots + assert "time_vs_value2.png" in created_plots + + # Test failure due to non-numeric data + data_failure = [ + {"Time": "1", "Value1": "A", "Value2": "B"}, + {"Time": "2", "Value1": "C", "Value2": "D"}, + ] + + with pytest.raises(Exception) as excinfo: + plot_data(data_failure, headers, matched_headers) + assert "All plots failed due to non-numeric data." in str(excinfo.value) + + +@pytest.mark.skip(reason="molrender is not pip installable") +def test_run_molrender(path_to_cif, vis_fxns): + result = vis_fxns.run_molrender(path_to_cif) + assert result == "Visualization created" + + +def test_create_notebook(path_to_cif, vis_fxns, get_registry): + result = vis_fxns.create_notebook(path_to_cif, get_registry) + assert result == "Visualization Complete" + + +def test_getpdb(fibronectin, get_registry): + name, _ = get_pdb(fibronectin, get_registry) + assert name.endswith(".pdb") + diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..9fb09f9d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,155 @@ +import json +import warnings +from unittest.mock import mock_open, patch + +import pytest + +from mdagent.utils import FileType, PathRegistry + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") + + +@pytest.fixture +def path_registry(): + registry = PathRegistry() + registry.get_timestamp = lambda: "20240109" + return registry + + +def test_write_to_file(): + path_registry = PathRegistry() + + with patch("builtins.open", mock_open()): + file_name = path_registry.write_file_name( + FileType.PROTEIN, + protein_name="1XYZ", + description="testing", + file_format="pdb", + ) + # assert file name starts and ends correctly + assert file_name.startswith("1XYZ") + assert file_name.endswith(".pdb") + + +def test_write_file_name_protein(path_registry): + file_name = path_registry.write_file_name( + FileType.PROTEIN, protein_name="1XYZ", description="testing", file_format="pdb" + ) + assert file_name == "1XYZ_testing_20240109.pdb" + + +def test_write_file_name_simulation_with_conditions(path_registry): + file_name = path_registry.write_file_name( + FileType.SIMULATION, + type_of_sim="MD", + protein_file_id="1XYZ", + conditions="pH7", + time_stamp="20240109", + ) + assert file_name == "MD_1XYZ_pH7_20240109.py" + + +def test_write_file_name_simulation_modified(path_registry): + file_name = path_registry.write_file_name( + FileType.SIMULATION, Sim_id="SIM456", modified=True, time_stamp="20240109" + ) + assert file_name == "SIM456_MOD_20240109.py" + + +def test_write_file_name_simulation_default(path_registry): + file_name = path_registry.write_file_name( + FileType.SIMULATION, + type_of_sim="MD", + protein_file_id="123", + time_stamp="20240109", + ) + assert file_name == "MD_123_20240109.py" + + +def test_write_file_name_record(path_registry): + file_name = path_registry.write_file_name( + FileType.RECORD, + record_type="REC", + protein_file_id="123", + Sim_id="SIM456", + term="dcd", + time_stamp="20240109", + ) + assert file_name == "REC_SIM456_123_20240109.dcd" + + +def test_map_path(): + mock_json_data = { + "existing_name": { + "path": "existing/path", + "name": "path", + "description": "Existing description", + } + } + new_path_dict = { + "new_name": { + "path": "new/path", + "name": "path", + "description": "New description", + } + } + updated_json_data = {**mock_json_data, **new_path_dict} + + path_registry = PathRegistry() + path_registry.json_file_path = "dummy_json_file.json" + + # Mocking os.path.exists to simulate the JSON file existence + with patch("os.path.exists", return_value=True): + # Mocking open for both reading and writing the JSON file + with patch( + "builtins.open", mock_open(read_data=json.dumps(mock_json_data)) + ) as mocked_file: + # Optionally, you can mock internal methods if needed + with patch.object( + path_registry, "_check_for_json", return_value=True + ), patch.object( + path_registry, "_check_json_content", return_value=True + ), patch.object( + path_registry, "_get_full_path", return_value="new/path" + ): # Mocking _get_full_path + result = path_registry.map_path( + "new_name", "new/path", "New description" + ) + # Aggregating all calls to write into a single string + written_data = "".join( + call.args[0] for call in mocked_file().write.call_args_list + ) + + # Comparing the aggregated data with the expected JSON data + assert json.loads(written_data) == updated_json_data + + # Check the result message + assert result == "Path successfully mapped to name: new_name" + +mocked_files = {"files/solvents": ["water.pdb"]} + + +def mock_exists(path): + return path in mocked_files + + +def mock_listdir(path): + return mocked_files.get(path, []) + + +@pytest.fixture +def path_registry_with_mocked_fs(): + with patch("os.path.exists", side_effect=mock_exists): + with patch("os.listdir", side_effect=mock_listdir): + registry = PathRegistry() + registry.get_timestamp = lambda: "20240109" + return registry + + +def test_init_path_registry(path_registry_with_mocked_fs): + # This test will run with the mocked file system + # Here, you can assert if 'water.pdb' under 'solvents' is registered correctly + # Depending on how your PathRegistry class stores the registry, + # you may need to check the internal state or the contents of the JSON file. + # For example: + assert "water_000000" in path_registry_with_mocked_fs.list_path_names() From f73a6623f7a5074ee7a1748edd08384cacc64f54 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 21 Feb 2024 14:28:46 -0800 Subject: [PATCH 02/12] added pqa literature search tool --- .env.example | 3 - .github/workflows/build.yml | 4 +- .github/workflows/tests.yml | 1 - .secrets.baseline | 3 - README.md | 4 +- dev-requirements.txt | 1 + .../base_tools/util_tools/search_tools.py | 83 ++++++++++++++----- mdagent/tools/maketools.py | 10 +-- setup.py | 9 +- tests/test_sims_and_clean.py | 7 +- tests/test_tools.py | 28 +++++-- tests/test_utils.py | 1 + 12 files changed, 98 insertions(+), 56 deletions(-) diff --git a/.env.example b/.env.example index 1f150faf..e4767a97 100644 --- a/.env.example +++ b/.env.example @@ -4,8 +4,5 @@ # OpenAI API Key OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret -# PQA API Key -PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret - # Serp API key SERP_API_KEY=YOUR_SERP_API_KEY_GOES_HERE # pragma: allowlist secret diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 48bcd37d..78ba8d5a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,10 +13,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python "3.9" + - name: Set up Python "3.11" uses: actions/setup-python@v2 with: - python-version: "3.9" + python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c3ac46bc..18a4f5cc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -45,6 +45,5 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} SEMANTIC_SCHOLAR_API_KEY: ${{ secrets.SEMANTIC_SCHOLAR_API_KEY }} - PQA_API_KEY : ${{ secrets.PQA_API_TOKEN }} run: | pytest -m "not skip" tests diff --git a/.secrets.baseline b/.secrets.baseline index b1809030..56e5786e 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -3,8 +3,5 @@ # Rule for detecting OpenAI API keys OpenAI API Key: \b[secrets]{3}_[a-zA-Z0-9]{32}\b -# Rule for detecting pqa API keys -PQA API Key: "pqa[a-zA-Z0-9-._]+" - # Rule for detecting serp API keys # Serp API Key: "[a-zA-Z0-9]{64}" diff --git a/README.md b/README.md index e87e97e1..7a3373b4 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ To use the OpenMM features in the agent, please set up a conda environment, foll - Create conda environment: `conda env create -n mdagent -f environment.yaml` - Activate your environment: `conda activate mdagent` -If you already have a conda environment, you can install the necessary dependencies with the following steps. -- Install the necessary conda dependencies: `conda install -c conda-forge openmm pdbfixer mdanalysis` +If you already have a conda environment, you can install, pdbfixer, a necessary dependency with the following steps. +- Install the necessary conda dependencies: `conda install -c conda-forge pdbfixer` ## Installation diff --git a/dev-requirements.txt b/dev-requirements.txt index 51f1982a..bfd6bc64 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,2 +1,3 @@ pre-commit pytest +pytest-mock diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index 3c6e32d3..e12ec8ba 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -1,26 +1,69 @@ -import pqapi -from langchain.tools import BaseTool +import os +import re +from langchain.base_language import BaseLanguageModel +import langchain +import paperqa +import paperscraper +from pypdf.errors import PdfReadError -class Scholar2ResultLLM(BaseTool): - name = "LiteratureSearch" - description = """Input a specific question, - returns an answer from literature search.""" +def paper_scraper(search:str, pdir:str="query") -> dict: + try: + return paperscraper.search_papers(search, pdir=pdir) + except KeyError: + return {} + +def paper_search(llm, query): + prompt = langchain.prompts.PromptTemplate( + input_variables=["question"], + template=""" + I would like to find scholarly papers to answer + this question: {question}. + 'A search query that would bring up papers that can answer + this question would be: '""",) + + query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt) + if not os.path.isdir("./query"): #todo: move to ckpt + os.mkdir("query/") - pqa_key: str = "" + search = query_chain.run(query) + print("\nSearch:", search) + papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}") + return papers - def __init__(self, pqa_key: str): - super().__init__() - self.pqa_key = pqa_key - def _run(self, question: str) -> str: - """Use the tool""" +def scholar2result_llm(llm, query): + """Useful to answer questions that require + technical knowledge. Ask a specific question.""" + papers = paper_search(llm, query) + if len(papers) == 0: + return "Not enough papers found" + docs = paperqa.Docs(llm=llm) + not_loaded = 0 + for path, data in papers.items(): try: - response = pqapi.agent_query("default", question) - return response.answer - except Exception: - return "Literature search failed." - - async def _arun(self, question: str) -> str: - """Use the tool asynchronously""" - raise NotImplementedError + docs.add(path, data["citation"]) + except (ValueError, FileNotFoundError, PdfReadError): + not_loaded += 1 + + print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}") + return docs.query(query).formatted_answer + + +class Scholar2ResultLLM: + name = "Literature Search" + description = ( + "Useful to answer questions that require technical ", + "knowledge. Ask a specific question.", + ) + llm: BaseLanguageModel + + def __init__(self, llm): + self.llm = llm + + def _run(self, query) -> str: + return scholar2result_llm(self.llm, query) + + async def _arun(self, query) -> str: + """Use the tool asynchronously.""" + raise NotImplementedError("this tool does not support async") \ No newline at end of file diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index 15933aed..f1d890a9 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -29,6 +29,7 @@ SimulationOutputFigures, SmallMolPDB, VisualizeProtein, + Scholar2ResultLLM, ) from .subagent_tools import RetryExecuteSkill, SkillRetrieval, WorkflowPlan @@ -78,6 +79,7 @@ def make_all_tools( # add base tools base_tools = [ + Scholar2ResultLLM(llm=llm), CleaningToolFunction(path_registry=path_instance), CheckDirectoryFiles(), ListRegistryPaths(path_registry=path_instance), @@ -113,14 +115,6 @@ def make_all_tools( learned_tools = get_learned_tools(subagent_settings.ckpt_dir) all_tools += base_tools + subagents_tools + learned_tools - - # add other tools depending on api keys - os.getenv("SERP_API_KEY") - pqa_key = os.getenv("PQA_API_KEY") - # if serp_key: - # all_tools.append(SerpGitTool(serp_key)) # github issues search - if pqa_key: - all_tools.append(Scholar2ResultLLM(pqa_key)) # literature search return all_tools diff --git a/setup.py b/setup.py index 474d02e8..6564c056 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,6 @@ license="MIT", packages=find_packages(), install_requires=[ - "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", "chromadb==0.3.29", "google-search-results", "langchain==0.0.336", @@ -25,14 +24,14 @@ "matplotlib", "nbformat", "openai", - "paper-qa", - "python-dotenv", - "pqapi", "requests", - "rmrkl", "tiktoken", "rdkit", "streamlit", + "paper-qa", + "openmm", + "MDAnalysis", + "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", ], test_suite="tests", long_description=long_description, diff --git a/tests/test_sims_and_clean.py b/tests/test_sims_and_clean.py index d4d55f0c..7b9db991 100644 --- a/tests/test_sims_and_clean.py +++ b/tests/test_sims_and_clean.py @@ -5,10 +5,7 @@ import pytest -from mdagent.tools.base_tools import ( - CleaningTools, - SimulationFunctions, -) +from mdagent.tools.base_tools import CleaningTools, SimulationFunctions from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool from mdagent.utils import PathRegistry @@ -58,7 +55,6 @@ def packmol(get_registry): return PackMolTool(get_registry) - def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry): result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif, get_registry) assert "Cleaned File" in result # just want to make sur the function ran @@ -183,4 +179,3 @@ def test_packmol_download_only_once(packmol): assert water_time == water_time_after # Clean up os.remove("files/pdb/water.pdb") - diff --git a/tests/test_tools.py b/tests/test_tools.py index 2347b332..1e9259fe 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,13 +1,10 @@ import os import warnings from unittest.mock import MagicMock, mock_open, patch - +from langchain.chat_models import ChatOpenAI import pytest - -from mdagent.tools.base_tools import ( - VisFunctions, - get_pdb, -) +from mdagent.tools.base_tools import Scholar2ResultLLM +from mdagent.tools.base_tools import VisFunctions, get_pdb from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv from mdagent.utils import PathRegistry @@ -116,3 +113,22 @@ def test_getpdb(fibronectin, get_registry): name, _ = get_pdb(fibronectin, get_registry) assert name.endswith(".pdb") +@pytest.fixture +def questions(): + qs = [ + "What are the effects of norhalichondrin B in mammals?", + ] + return qs[0] + +@pytest.mark.skip(reason="This requires an API call") +def test_litsearch(questions): + llm = ChatOpenAI() + + searchtool = Scholar2ResultLLM(llm=llm) + for q in questions: + ans = searchtool._run(q) + assert isinstance(ans, str) + assert len(ans) > 0 + #then if query folder exists one step back, delete it + if os.path.exists("../query"): + os.rmdir("../query") \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 9fb09f9d..400d64fd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -126,6 +126,7 @@ def test_map_path(): # Check the result message assert result == "Path successfully mapped to name: new_name" + mocked_files = {"files/solvents": ["water.pdb"]} From dfac7447f2bb0f0c44bd690ba94bcd51295cf72f Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 21 Feb 2024 14:45:12 -0800 Subject: [PATCH 03/12] few updates to lit search tool --- .../base_tools/util_tools/search_tools.py | 19 +++++++++++-------- mdagent/tools/maketools.py | 1 - tests/test_tools.py | 12 +++++++----- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index e12ec8ba..fd442755 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -1,18 +1,20 @@ import os import re -from langchain.base_language import BaseLanguageModel + import langchain import paperqa import paperscraper +from langchain.base_language import BaseLanguageModel from pypdf.errors import PdfReadError -def paper_scraper(search:str, pdir:str="query") -> dict: +def paper_scraper(search: str, pdir: str = "query") -> dict: try: return paperscraper.search_papers(search, pdir=pdir) except KeyError: return {} - + + def paper_search(llm, query): prompt = langchain.prompts.PromptTemplate( input_variables=["question"], @@ -20,10 +22,11 @@ def paper_search(llm, query): I would like to find scholarly papers to answer this question: {question}. 'A search query that would bring up papers that can answer - this question would be: '""",) - + this question would be: '""", + ) + query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt) - if not os.path.isdir("./query"): #todo: move to ckpt + if not os.path.isdir("./query"): # todo: move to ckpt os.mkdir("query/") search = query_chain.run(query) @@ -51,7 +54,7 @@ def scholar2result_llm(llm, query): class Scholar2ResultLLM: - name = "Literature Search" + name = "LiteratureSearch" description = ( "Useful to answer questions that require technical ", "knowledge. Ask a specific question.", @@ -66,4 +69,4 @@ def _run(self, query) -> str: async def _arun(self, query) -> str: """Use the tool asynchronously.""" - raise NotImplementedError("this tool does not support async") \ No newline at end of file + raise NotImplementedError("this tool does not support async") diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index f1d890a9..5d37b6ba 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -29,7 +29,6 @@ SimulationOutputFigures, SmallMolPDB, VisualizeProtein, - Scholar2ResultLLM, ) from .subagent_tools import RetryExecuteSkill, SkillRetrieval, WorkflowPlan diff --git a/tests/test_tools.py b/tests/test_tools.py index 1e9259fe..04ed518f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,10 +1,11 @@ import os import warnings from unittest.mock import MagicMock, mock_open, patch -from langchain.chat_models import ChatOpenAI + import pytest -from mdagent.tools.base_tools import Scholar2ResultLLM -from mdagent.tools.base_tools import VisFunctions, get_pdb +from langchain.chat_models import ChatOpenAI + +from mdagent.tools.base_tools import Scholar2ResultLLM, VisFunctions, get_pdb from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv from mdagent.utils import PathRegistry @@ -113,6 +114,7 @@ def test_getpdb(fibronectin, get_registry): name, _ = get_pdb(fibronectin, get_registry) assert name.endswith(".pdb") + @pytest.fixture def questions(): qs = [ @@ -120,6 +122,7 @@ def questions(): ] return qs[0] + @pytest.mark.skip(reason="This requires an API call") def test_litsearch(questions): llm = ChatOpenAI() @@ -129,6 +132,5 @@ def test_litsearch(questions): ans = searchtool._run(q) assert isinstance(ans, str) assert len(ans) > 0 - #then if query folder exists one step back, delete it if os.path.exists("../query"): - os.rmdir("../query") \ No newline at end of file + os.rmdir("../query") From 00eacea2f1099c36f59a0a4583684598606ab2b0 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 21 Feb 2024 15:20:58 -0800 Subject: [PATCH 04/12] added notebook for testing lit search --- .../base_tools/util_tools/search_tools.py | 19 +-- notebooks/lit_search.ipynb | 121 ++++++++++++++++++ 2 files changed, 132 insertions(+), 8 deletions(-) create mode 100644 notebooks/lit_search.ipynb diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index fd442755..2f1fca62 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -5,6 +5,7 @@ import paperqa import paperscraper from langchain.base_language import BaseLanguageModel +from langchain.tools import BaseTool from pypdf.errors import PdfReadError @@ -20,7 +21,8 @@ def paper_search(llm, query): input_variables=["question"], template=""" I would like to find scholarly papers to answer - this question: {question}. + this question: {question}. Your response must be at + most 10 words long. 'A search query that would bring up papers that can answer this question would be: '""", ) @@ -28,14 +30,13 @@ def paper_search(llm, query): query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt) if not os.path.isdir("./query"): # todo: move to ckpt os.mkdir("query/") - search = query_chain.run(query) print("\nSearch:", search) papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}") return papers -def scholar2result_llm(llm, query): +def scholar2result_llm(llm, query, k=5, max_sources=2): """Useful to answer questions that require technical knowledge. Ask a specific question.""" papers = paper_search(llm, query) @@ -50,18 +51,20 @@ def scholar2result_llm(llm, query): not_loaded += 1 print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}") - return docs.query(query).formatted_answer + answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer + return answer -class Scholar2ResultLLM: +class Scholar2ResultLLM(BaseTool): name = "LiteratureSearch" description = ( - "Useful to answer questions that require technical ", - "knowledge. Ask a specific question.", + "Useful to answer questions that require technical " + "knowledge. Ask a specific question." ) - llm: BaseLanguageModel + llm: BaseLanguageModel = None def __init__(self, llm): + super().__init__() self.llm = llm def _run(self, query) -> str: diff --git a/notebooks/lit_search.ipynb b/notebooks/lit_search.ipynb new file mode 100644 index 00000000..185fd52d --- /dev/null +++ b/notebooks/lit_search.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/samcox/anaconda3/envs/mda_feb21/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from mdagent import MDAgent" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "#until we update to new version\n", + "import nest_asyncio\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "mda = MDAgent()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"Are there any studies that show that the use of a mask can reduce the spread of COVID-19?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\"Masks COVID-19 transmission reduction studies\"\n", + "Search: \"Masks COVID-19 transmission reduction studies\"\n", + "\n", + "Found 14 papers but couldn't load 0\n", + "Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies." + ] + } + ], + "source": [ + "answer = mda.run(prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies.'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "answer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mdagent", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 57bc29ac37fe72dbd2809b33f59ede15a1c32a6b Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 21 Feb 2024 15:22:42 -0800 Subject: [PATCH 05/12] moved openmm and mdanalysis to pip installs --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 18a4f5cc..959ba591 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,10 +23,10 @@ jobs: environment-file: environment.yaml python-version: ${{ matrix.python-version }} auto-activate-base: true - - name: Install openmm pdbfixer mdanalysis with conda + - name: Install pdbfixer with conda shell: bash -l {0} run: | - conda install -c conda-forge openmm pdbfixer mdanalysis + conda install -c conda-forge pdbfixer - name: Install dependencies shell: bash -l {0} run: | From c62153abfac3beee18c80676ef06d31b77918c7d Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 21 Feb 2024 17:51:41 -0800 Subject: [PATCH 06/12] adjusted tests --- tests/test_sims_and_clean.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_sims_and_clean.py b/tests/test_sims_and_clean.py index 7b9db991..1f466677 100644 --- a/tests/test_sims_and_clean.py +++ b/tests/test_sims_and_clean.py @@ -118,12 +118,10 @@ def test_small_molecule_pdb(molpdb, get_registry): # test with valid molecule name valid_name = "water" - expected_output = ( - "PDB file for water successfully created and " "saved to files/pdb/water.pdb." - ) - assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output - assert os.path.exists("files/pdb/water.pdb") - os.remove("files/pdb/water.pdb") # Clean up + assert "successfully" in molpdb.small_molecule_pdb(valid_name, get_registry) + # assert os.path.exists("files/pdb/water.pdb") + if os.path.exists("files/pdb/water.pdb"): + os.remove("files/pdb/water.pdb") def test_packmol_sm_download_called(packmol): @@ -151,18 +149,22 @@ def test_packmol_sm_download_called(packmol): mock_get_sm_pdbs.assert_called_with(["water", "benzene"]) +@pytest.mark.skip(reason="Resume this test when ckpt is implemented") def test_packmol_download_only(packmol): path_registry = PathRegistry() path_registry._remove_path_from_json("water") path_registry._remove_path_from_json("benzene") small_molecules = ["water", "benzene"] packmol._get_sm_pdbs(small_molecules) - assert os.path.exists("files/pdb/water.pdb") - assert os.path.exists("files/pdb/benzene.pdb") - os.remove("files/pdb/water.pdb") - os.remove("files/pdb/benzene.pdb") + # assert os.path.exists("files/pdb/water.pdb") + # assert os.path.exists("files/pdb/benzene.pdb") + if os.path.exists("files/pdb/water.pdb"): + os.remove("files/pdb/water.pdb") + if os.path.exists("files/pdb/benzene.pdb"): + os.remove("files/pdb/benzene.pdb") +@pytest.mark.skip(reason="Resume this test when ckpt is implemented") def test_packmol_download_only_once(packmol): path_registry = PathRegistry() path_registry._remove_path_from_json("water") From 3e5dacc2339f13a825236b856133b893a9853d07 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Thu, 22 Feb 2024 13:04:19 -0800 Subject: [PATCH 07/12] deleted extra test file --- tests/test_fxns.py | 416 --------------------------------------------- 1 file changed, 416 deletions(-) delete mode 100644 tests/test_fxns.py diff --git a/tests/test_fxns.py b/tests/test_fxns.py deleted file mode 100644 index 19b528e9..00000000 --- a/tests/test_fxns.py +++ /dev/null @@ -1,416 +0,0 @@ -import json -import os -import time -import warnings -from unittest.mock import MagicMock, mock_open, patch - -import pytest - -from mdagent.tools.base_tools import ( - CleaningTools, - SimulationFunctions, - VisFunctions, - get_pdb, -) -from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv -from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool -from mdagent.utils import FileType, PathRegistry - -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") - - -@pytest.fixture -def path_to_cif(): - # Save original working directory - original_cwd = os.getcwd() - - # Change current working directory to the directory where the CIF file is located - tests_dir = os.path.dirname(os.path.abspath(__file__)) - os.chdir(tests_dir) - - # Yield the filename only - filename_only = "3pqr.cif" - yield filename_only - - # Restore original working directory after the test is done - os.chdir(original_cwd) - - -@pytest.fixture -def cleaning_fxns(): - return CleaningTools() - - -@pytest.fixture -def molpdb(): - return MolPDB() - - -# Test simulation tools -@pytest.fixture -def sim_fxns(): - return SimulationFunctions() - - -# Test visualization tools -@pytest.fixture -def vis_fxns(): - return VisFunctions() - - -# Test MD utility tools -@pytest.fixture -def fibronectin(): - return "fibronectin pdb" - - -@pytest.fixture -def get_registry(): - return PathRegistry() - - -@pytest.fixture -def packmol(get_registry): - return PackMolTool(get_registry) - - -def test_process_csv(): - mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" - mock_reader = MagicMock() - mock_reader.fieldnames = ["Time", "Value1", "Value2"] - mock_reader.__iter__.return_value = iter( - [ - {"Time": "1", "Value1": "10", "Value2": "20"}, - {"Time": "2", "Value1": "15", "Value2": "25"}, - ] - ) - - with patch("builtins.open", mock_open(read_data=mock_csv_content)): - with patch("csv.DictReader", return_value=mock_reader): - data, headers, matched_headers = process_csv("mock_file.csv") - - assert headers == ["Time", "Value1", "Value2"] - assert len(matched_headers) == 1 - assert matched_headers[0][1] == "Time" - assert len(data) == 2 - assert data[0]["Time"] == "1" and data[0]["Value1"] == "10" - - -def test_plot_data(): - # Test successful plot generation - data_success = [ - {"Time": "1", "Value1": "10", "Value2": "20"}, - {"Time": "2", "Value1": "15", "Value2": "25"}, - ] - headers = ["Time", "Value1", "Value2"] - matched_headers = [(0, "Time")] - - with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch( - "matplotlib.pyplot.xlabel" - ), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch( - "matplotlib.pyplot.savefig" - ), patch( - "matplotlib.pyplot.close" - ): - created_plots = plot_data(data_success, headers, matched_headers) - assert "time_vs_value1.png" in created_plots - assert "time_vs_value2.png" in created_plots - - # Test failure due to non-numeric data - data_failure = [ - {"Time": "1", "Value1": "A", "Value2": "B"}, - {"Time": "2", "Value1": "C", "Value2": "D"}, - ] - - with pytest.raises(Exception) as excinfo: - plot_data(data_failure, headers, matched_headers) - assert "All plots failed due to non-numeric data." in str(excinfo.value) - - -@pytest.mark.skip(reason="molrender is not pip installable") -def test_run_molrender(path_to_cif, vis_fxns): - result = vis_fxns.run_molrender(path_to_cif) - assert result == "Visualization created" - - -def test_create_notebook(path_to_cif, vis_fxns, get_registry): - result = vis_fxns.create_notebook(path_to_cif, get_registry) - assert result == "Visualization Complete" - - -def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry): - result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif, get_registry) - assert "Cleaned File" in result # just want to make sur the function ran - - -@patch("os.path.exists") -@patch("os.listdir") -def test_extract_parameters_path(mock_listdir, mock_exists, sim_fxns): - # Test when parameters.json exists - mock_exists.return_value = True - assert sim_fxns._extract_parameters_path() == "simulation_parameters_summary.json" - mock_exists.assert_called_once_with("simulation_parameters_summary.json") - mock_exists.reset_mock() # Reset the mock for the next scenario - - # Test when parameters.json does not exist, but some_parameters.json does - mock_exists.return_value = False - mock_listdir.return_value = ["some_parameters.json", "other_file.txt"] - assert sim_fxns._extract_parameters_path() == "some_parameters.json" - - # Test when no appropriate file exists - mock_listdir.return_value = ["other_file.json", "other_file.txt"] - with pytest.raises(ValueError) as e: - sim_fxns._extract_parameters_path() - assert str(e.value) == "No parameters.json file found in directory." - - -@patch( - "builtins.open", - new_callable=mock_open, - read_data='{"param1": "value1", "param2": "value2"}', -) -@patch("json.load") -def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns): - # Define the mock behavior for json.load - mock_json_load.return_value = {"param1": "value1", "param2": "value2"} - params = sim_fxns._setup_simulation_from_json("test_file.json") - mock_file_open.assert_called_once_with("test_file.json", "r") - mock_json_load.assert_called_once() - assert params == {"param1": "value1", "param2": "value2"} - - -def test_getpdb(fibronectin, get_registry): - name, _ = get_pdb(fibronectin, get_registry) - assert name.endswith(".pdb") - - -@pytest.fixture -def path_registry(): - registry = PathRegistry() - registry.get_timestamp = lambda: "20240109" - return registry - - -def test_write_to_file(): - path_registry = PathRegistry() - - with patch("builtins.open", mock_open()): - file_name = path_registry.write_file_name( - FileType.PROTEIN, - protein_name="1XYZ", - description="testing", - file_format="pdb", - ) - # assert file name starts and ends correctly - assert file_name.startswith("1XYZ") - assert file_name.endswith(".pdb") - - -def test_write_file_name_protein(path_registry): - file_name = path_registry.write_file_name( - FileType.PROTEIN, protein_name="1XYZ", description="testing", file_format="pdb" - ) - assert file_name == "1XYZ_testing_20240109.pdb" - - -def test_write_file_name_simulation_with_conditions(path_registry): - file_name = path_registry.write_file_name( - FileType.SIMULATION, - type_of_sim="MD", - protein_file_id="1XYZ", - conditions="pH7", - time_stamp="20240109", - ) - assert file_name == "MD_1XYZ_pH7_20240109.py" - - -def test_write_file_name_simulation_modified(path_registry): - file_name = path_registry.write_file_name( - FileType.SIMULATION, Sim_id="SIM456", modified=True, time_stamp="20240109" - ) - assert file_name == "SIM456_MOD_20240109.py" - - -def test_write_file_name_simulation_default(path_registry): - file_name = path_registry.write_file_name( - FileType.SIMULATION, - type_of_sim="MD", - protein_file_id="123", - time_stamp="20240109", - ) - assert file_name == "MD_123_20240109.py" - - -def test_write_file_name_record(path_registry): - file_name = path_registry.write_file_name( - FileType.RECORD, - record_type="REC", - protein_file_id="123", - Sim_id="SIM456", - term="dcd", - time_stamp="20240109", - ) - assert file_name == "REC_SIM456_123_20240109.dcd" - - -def test_map_path(): - mock_json_data = { - "existing_name": { - "path": "existing/path", - "name": "path", - "description": "Existing description", - } - } - new_path_dict = { - "new_name": { - "path": "new/path", - "name": "path", - "description": "New description", - } - } - updated_json_data = {**mock_json_data, **new_path_dict} - - path_registry = PathRegistry() - path_registry.json_file_path = "dummy_json_file.json" - - # Mocking os.path.exists to simulate the JSON file existence - with patch("os.path.exists", return_value=True): - # Mocking open for both reading and writing the JSON file - with patch( - "builtins.open", mock_open(read_data=json.dumps(mock_json_data)) - ) as mocked_file: - # Optionally, you can mock internal methods if needed - with patch.object( - path_registry, "_check_for_json", return_value=True - ), patch.object( - path_registry, "_check_json_content", return_value=True - ), patch.object( - path_registry, "_get_full_path", return_value="new/path" - ): # Mocking _get_full_path - result = path_registry.map_path( - "new_name", "new/path", "New description" - ) - # Aggregating all calls to write into a single string - written_data = "".join( - call.args[0] for call in mocked_file().write.call_args_list - ) - - # Comparing the aggregated data with the expected JSON data - assert json.loads(written_data) == updated_json_data - - # Check the result message - assert result == "Path successfully mapped to name: new_name" - - -def test_small_molecule_pdb(molpdb, get_registry): - # Test with a valid SMILES string - valid_smiles = "C1=CC=CC=C1" # Benzene - expected_output = ( - "PDB file for C1=CC=CC=C1 successfully created and saved to " - "files/pdb/benzene.pdb." - ) - assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output - assert os.path.exists("files/pdb/benzene.pdb") - os.remove("files/pdb/benzene.pdb") # Clean up - - # test with invalid SMILES string and invalid molecule name - invalid_smiles = "C1=CC=CC=C1X" - invalid_name = "NotAMolecule" - expected_output = ( - "There was an error getting pdb. Please input a single molecule name." - ) - assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output - assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output - - # test with valid molecule name - valid_name = "water" - expected_output = ( - "PDB file for water successfully created and " "saved to files/pdb/water.pdb." - ) - assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output - assert os.path.exists("files/pdb/water.pdb") - os.remove("files/pdb/water.pdb") # Clean up - - -def test_packmol_sm_download_called(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - path_registry._remove_path_from_json("benzene") - path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") - with patch( - "mdagent.tools.base_tools.preprocess_tools.pdb_tools.PackMolTool._get_sm_pdbs", - new=MagicMock(), - ) as mock_get_sm_pdbs: - test_values = { - "pdbfiles_id": ["1A3N_144150"], - "small_molecules": ["water", "benzene"], - "number_of_molecules": [1, 10, 10], - "instructions": [ - ["inside box 0. 0. 0. 100. 100. 100."], - ["inside box 0. 0. 0. 100. 100. 100."], - ["inside box 0. 0. 0. 100. 100. 100."], - ], - } - - packmol._run(**test_values) - - mock_get_sm_pdbs.assert_called_with(["water", "benzene"]) - - -def test_packmol_download_only(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - path_registry._remove_path_from_json("benzene") - small_molecules = ["water", "benzene"] - packmol._get_sm_pdbs(small_molecules) - assert os.path.exists("files/pdb/water.pdb") - assert os.path.exists("files/pdb/benzene.pdb") - os.remove("files/pdb/water.pdb") - os.remove("files/pdb/benzene.pdb") - - -def test_packmol_download_only_once(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - small_molecules = ["water"] - packmol._get_sm_pdbs(small_molecules) - assert os.path.exists("files/pdb/water.pdb") - water_time = os.path.getmtime("files/pdb/water.pdb") - time.sleep(5) - - # Call the function again with the same molecule - packmol._get_sm_pdbs(small_molecules) - water_time_after = os.path.getmtime("files/pdb/water.pdb") - - assert water_time == water_time_after - # Clean up - os.remove("files/pdb/water.pdb") - - -mocked_files = {"files/solvents": ["water.pdb"]} - - -def mock_exists(path): - return path in mocked_files - - -def mock_listdir(path): - return mocked_files.get(path, []) - - -@pytest.fixture -def path_registry_with_mocked_fs(): - with patch("os.path.exists", side_effect=mock_exists): - with patch("os.listdir", side_effect=mock_listdir): - registry = PathRegistry() - registry.get_timestamp = lambda: "20240109" - return registry - - -def test_init_path_registry(path_registry_with_mocked_fs): - # This test will run with the mocked file system - # Here, you can assert if 'water.pdb' under 'solvents' is registered correctly - # Depending on how your PathRegistry class stores the registry, - # you may need to check the internal state or the contents of the JSON file. - # For example: - assert "water_000000" in path_registry_with_mocked_fs.list_path_names() From 40761c40ab14a6f3df2310dac2214ad4c586435f Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Thu, 22 Feb 2024 13:29:22 -0800 Subject: [PATCH 08/12] fixed tests with new file_path function & unit test --- mdagent/utils/__init__.py | 3 +- mdagent/utils/general_utils.py | 22 +++++++++++++++ tests/test_sims_and_clean.py | 51 +++++++++++++++++++--------------- tests/test_utils.py | 15 +++++++++- 4 files changed, 66 insertions(+), 25 deletions(-) create mode 100644 mdagent/utils/general_utils.py diff --git a/mdagent/utils/__init__.py b/mdagent/utils/__init__.py index ef0fa47b..ad59b1e4 100644 --- a/mdagent/utils/__init__.py +++ b/mdagent/utils/__init__.py @@ -1,4 +1,5 @@ +from .general_utils import find_file_path from .makellm import _make_llm from .path_registry import FileType, PathRegistry -__all__ = ["_make_llm", "PathRegistry", "FileType"] +__all__ = ["_make_llm", "PathRegistry", "FileType", "find_file_path"] diff --git a/mdagent/utils/general_utils.py b/mdagent/utils/general_utils.py new file mode 100644 index 00000000..e22764c0 --- /dev/null +++ b/mdagent/utils/general_utils.py @@ -0,0 +1,22 @@ +import os + + +def find_file_path(file_name: str, exact_match: bool = True): + """get the path of a file, if it exists in repo""" + setup_dir = None + for dirpath, dirnames, filenames in os.walk("."): + if "setup.py" in filenames: + setup_dir = dirpath + break + + if setup_dir is None: + raise FileNotFoundError("Unable to find root directory.") + + for dirpath, dirnames, filenames in os.walk(setup_dir): + for filename in filenames: + if (exact_match and filename == file_name) or ( + not exact_match and file_name in filename + ): + return os.path.join(dirpath, filename) + + return None diff --git a/tests/test_sims_and_clean.py b/tests/test_sims_and_clean.py index 1f466677..15d5c29c 100644 --- a/tests/test_sims_and_clean.py +++ b/tests/test_sims_and_clean.py @@ -7,7 +7,7 @@ from mdagent.tools.base_tools import CleaningTools, SimulationFunctions from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool -from mdagent.utils import PathRegistry +from mdagent.utils import PathRegistry, find_file_path warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -99,13 +99,13 @@ def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns): def test_small_molecule_pdb(molpdb, get_registry): # Test with a valid SMILES string valid_smiles = "C1=CC=CC=C1" # Benzene - expected_output = ( - "PDB file for C1=CC=CC=C1 successfully created and saved to " - "files/pdb/benzene.pdb." + expected_output_success = "successfully created and saved to " + assert expected_output_success in molpdb.small_molecule_pdb( + valid_smiles, get_registry ) - assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output - assert os.path.exists("files/pdb/benzene.pdb") - os.remove("files/pdb/benzene.pdb") # Clean up + file_path = find_file_path("benzene", exact_match=False) + assert file_path is not None # assert file was found + os.remove(file_path) # Clean up # test with invalid SMILES string and invalid molecule name invalid_smiles = "C1=CC=CC=C1X" @@ -113,15 +113,17 @@ def test_small_molecule_pdb(molpdb, get_registry): expected_output = ( "There was an error getting pdb. Please input a single molecule name." ) - assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output - assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output + assert expected_output in molpdb.small_molecule_pdb(invalid_smiles, get_registry) + assert expected_output in molpdb.small_molecule_pdb(invalid_name, get_registry) # test with valid molecule name valid_name = "water" - assert "successfully" in molpdb.small_molecule_pdb(valid_name, get_registry) - # assert os.path.exists("files/pdb/water.pdb") - if os.path.exists("files/pdb/water.pdb"): - os.remove("files/pdb/water.pdb") + assert expected_output_success in molpdb.small_molecule_pdb( + valid_name, get_registry + ) + file_path = find_file_path("water", exact_match=False) + assert file_path is not None # assert file was found + os.remove(file_path) # Clean up def test_packmol_sm_download_called(packmol): @@ -156,12 +158,14 @@ def test_packmol_download_only(packmol): path_registry._remove_path_from_json("benzene") small_molecules = ["water", "benzene"] packmol._get_sm_pdbs(small_molecules) - # assert os.path.exists("files/pdb/water.pdb") - # assert os.path.exists("files/pdb/benzene.pdb") - if os.path.exists("files/pdb/water.pdb"): - os.remove("files/pdb/water.pdb") - if os.path.exists("files/pdb/benzene.pdb"): - os.remove("files/pdb/benzene.pdb") + + water_path = find_file_path("water", exact_match=False) + assert water_path is not None + os.remove(water_path) + + benzene_path = find_file_path("benzene", exact_match=False) + assert benzene_path is not None + os.remove(benzene_path) @pytest.mark.skip(reason="Resume this test when ckpt is implemented") @@ -170,14 +174,15 @@ def test_packmol_download_only_once(packmol): path_registry._remove_path_from_json("water") small_molecules = ["water"] packmol._get_sm_pdbs(small_molecules) - assert os.path.exists("files/pdb/water.pdb") - water_time = os.path.getmtime("files/pdb/water.pdb") + path_name = find_file_path("water", exact_match=False) + assert path_name is not None + water_time = os.path.getmtime(path_name) time.sleep(5) # Call the function again with the same molecule packmol._get_sm_pdbs(small_molecules) - water_time_after = os.path.getmtime("files/pdb/water.pdb") + water_time_after = os.path.getmtime(path_name) assert water_time == water_time_after # Clean up - os.remove("files/pdb/water.pdb") + os.remove(path_name) diff --git a/tests/test_utils.py b/tests/test_utils.py index 400d64fd..c193b207 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,11 @@ import json +import os import warnings from unittest.mock import mock_open, patch import pytest -from mdagent.utils import FileType, PathRegistry +from mdagent.utils import FileType, PathRegistry, find_file_path warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -154,3 +155,15 @@ def test_init_path_registry(path_registry_with_mocked_fs): # you may need to check the internal state or the contents of the JSON file. # For example: assert "water_000000" in path_registry_with_mocked_fs.list_path_names() + + +def test_find_file_path(): + file_name = "test_utils.py" + file_path_current = os.path.abspath(file_name) + file_path_test = find_file_path(file_name, exact_match=True) + assert file_path_current == file_path_test + + file_name_short = file_name[-4] + file_path_current_short = os.path.abspath(file_name_short) + file_path_test_short = find_file_path(file_name_short, exact_match=False) + assert file_path_current_short == file_path_test_short From e05111c8701ad04b3a080bd6391525d66e7e0232 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Thu, 22 Feb 2024 14:02:59 -0800 Subject: [PATCH 09/12] fixed unit test for file_path --- mdagent/tools/base_tools/preprocess_tools/pdb_tools.py | 9 +++++++++ mdagent/utils/general_utils.py | 4 +++- tests/test_utils.py | 9 +++++---- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index 7c9a5b2c..191de14e 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -1406,6 +1406,15 @@ def molname2smiles( "One possible cause is that the input is incorrect, " "input one molecule at a time." ) + except IndexError: + print("The smiles property was not found for this molecule.") + except TypeError: + print( + "The information from the PubChem database is " + "not in the expected format." + ) + except ValueError: + print("The requested value is not in the expected format.") # remove salts return Chem.CanonSmiles(self.largest_mol(smi)) diff --git a/mdagent/utils/general_utils.py b/mdagent/utils/general_utils.py index e22764c0..6a1a6652 100644 --- a/mdagent/utils/general_utils.py +++ b/mdagent/utils/general_utils.py @@ -17,6 +17,8 @@ def find_file_path(file_name: str, exact_match: bool = True): if (exact_match and filename == file_name) or ( not exact_match and file_name in filename ): - return os.path.join(dirpath, filename) + path_full = os.path.join(dirpath, filename) + # make sure its absolute + return os.path.abspath(path_full) return None diff --git a/tests/test_utils.py b/tests/test_utils.py index c193b207..3bf37ebd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -159,11 +159,12 @@ def test_init_path_registry(path_registry_with_mocked_fs): def test_find_file_path(): file_name = "test_utils.py" - file_path_current = os.path.abspath(file_name) + # current directory won't include folders + cwd = os.getcwd() + file_path_current = os.path.join(cwd, "tests", file_name) file_path_test = find_file_path(file_name, exact_match=True) assert file_path_current == file_path_test - file_name_short = file_name[-4] - file_path_current_short = os.path.abspath(file_name_short) + file_name_short = "test_util" file_path_test_short = find_file_path(file_name_short, exact_match=False) - assert file_path_current_short == file_path_test_short + assert file_path_current == file_path_test_short From ff0bcb81c76d29cb5919a7309b98c05847908171 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Thu, 22 Feb 2024 14:20:57 -0800 Subject: [PATCH 10/12] idk what i did tbh --- mdagent/tools/base_tools/preprocess_tools/pdb_tools.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index 191de14e..d82e506a 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -1416,6 +1416,8 @@ def molname2smiles( except ValueError: print("The requested value is not in the expected format.") # remove salts + except Exception as e: + print(f"An error occurred: {e}") return Chem.CanonSmiles(self.largest_mol(smi)) def smiles2name(self, smi: str) -> str: @@ -1452,7 +1454,7 @@ def small_molecule_pdb(self, mol_str: str, path_registry) -> str: mol_name = mol_str try: # only if needed m = Chem.AddHs(m) - except Exception: # TODO: we should be more specific here + except Exception: pass Chem.AllChem.EmbedMolecule(m) file_name = f"files/pdb/{mol_name}.pdb" @@ -1465,11 +1467,7 @@ def small_molecule_pdb(self, mol_str: str, path_registry) -> str: return ( f"PDB file for {mol_str} successfully created and saved to {file_name}." ) - except Exception: # TODO: we should be more specific here - print( - "There was an error getting pdb. Please input a single molecule name." - f"{mol_str},{mol_name}, {smi}" - ) + except Exception: return ( "There was an error getting pdb. Please input a single molecule name." ) From eb47cc9688265345083904312fed3c843aa750d7 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Mon, 26 Feb 2024 12:50:34 -0800 Subject: [PATCH 11/12] moved unit tests back --- tests/test_fxns.py | 439 +++++++++++++++++++++++++++++++++++ tests/test_sims_and_clean.py | 188 --------------- tests/test_tools.py | 136 ----------- tests/test_utils.py | 170 -------------- 4 files changed, 439 insertions(+), 494 deletions(-) create mode 100644 tests/test_fxns.py delete mode 100644 tests/test_sims_and_clean.py delete mode 100644 tests/test_tools.py delete mode 100644 tests/test_utils.py diff --git a/tests/test_fxns.py b/tests/test_fxns.py new file mode 100644 index 00000000..47d03289 --- /dev/null +++ b/tests/test_fxns.py @@ -0,0 +1,439 @@ +import json +import os +import time +import warnings +from unittest.mock import MagicMock, mock_open, patch + +import pytest +from langchain.chat_models import ChatOpenAI + +from mdagent.tools.base_tools import ( + CleaningTools, + Scholar2ResultLLM, + SimulationFunctions, + VisFunctions, + get_pdb, +) +from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv +from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool +from mdagent.utils import FileType, PathRegistry + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") + + +@pytest.fixture +def path_to_cif(): + # Save original working directory + original_cwd = os.getcwd() + + # Change current working directory to the directory where the CIF file is located + tests_dir = os.path.dirname(os.path.abspath(__file__)) + os.chdir(tests_dir) + + # Yield the filename only + filename_only = "3pqr.cif" + yield filename_only + + # Restore original working directory after the test is done + os.chdir(original_cwd) + + +@pytest.fixture +def cleaning_fxns(): + return CleaningTools() + + +@pytest.fixture +def molpdb(): + return MolPDB() + + +# Test simulation tools +@pytest.fixture +def sim_fxns(): + return SimulationFunctions() + + +# Test visualization tools +@pytest.fixture +def vis_fxns(): + return VisFunctions() + + +# Test MD utility tools +@pytest.fixture +def fibronectin(): + return "fibronectin pdb" + + +@pytest.fixture +def get_registry(): + return PathRegistry() + + +@pytest.fixture +def packmol(get_registry): + return PackMolTool(get_registry) + + +def test_process_csv(): + mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" + mock_reader = MagicMock() + mock_reader.fieldnames = ["Time", "Value1", "Value2"] + mock_reader.__iter__.return_value = iter( + [ + {"Time": "1", "Value1": "10", "Value2": "20"}, + {"Time": "2", "Value1": "15", "Value2": "25"}, + ] + ) + + with patch("builtins.open", mock_open(read_data=mock_csv_content)): + with patch("csv.DictReader", return_value=mock_reader): + data, headers, matched_headers = process_csv("mock_file.csv") + + assert headers == ["Time", "Value1", "Value2"] + assert len(matched_headers) == 1 + assert matched_headers[0][1] == "Time" + assert len(data) == 2 + assert data[0]["Time"] == "1" and data[0]["Value1"] == "10" + + +def test_plot_data(): + # Test successful plot generation + data_success = [ + {"Time": "1", "Value1": "10", "Value2": "20"}, + {"Time": "2", "Value1": "15", "Value2": "25"}, + ] + headers = ["Time", "Value1", "Value2"] + matched_headers = [(0, "Time")] + + with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch( + "matplotlib.pyplot.xlabel" + ), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch( + "matplotlib.pyplot.savefig" + ), patch( + "matplotlib.pyplot.close" + ): + created_plots = plot_data(data_success, headers, matched_headers) + assert "time_vs_value1.png" in created_plots + assert "time_vs_value2.png" in created_plots + + # Test failure due to non-numeric data + data_failure = [ + {"Time": "1", "Value1": "A", "Value2": "B"}, + {"Time": "2", "Value1": "C", "Value2": "D"}, + ] + + with pytest.raises(Exception) as excinfo: + plot_data(data_failure, headers, matched_headers) + assert "All plots failed due to non-numeric data." in str(excinfo.value) + + +@pytest.mark.skip(reason="molrender is not pip installable") +def test_run_molrender(path_to_cif, vis_fxns): + result = vis_fxns.run_molrender(path_to_cif) + assert result == "Visualization created" + + +def test_create_notebook(path_to_cif, vis_fxns, get_registry): + result = vis_fxns.create_notebook(path_to_cif, get_registry) + assert result == "Visualization Complete" + + +def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry): + result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif, get_registry) + assert "Cleaned File" in result # just want to make sur the function ran + + +@patch("os.path.exists") +@patch("os.listdir") +def test_extract_parameters_path(mock_listdir, mock_exists, sim_fxns): + # Test when parameters.json exists + mock_exists.return_value = True + assert sim_fxns._extract_parameters_path() == "simulation_parameters_summary.json" + mock_exists.assert_called_once_with("simulation_parameters_summary.json") + mock_exists.reset_mock() # Reset the mock for the next scenario + + # Test when parameters.json does not exist, but some_parameters.json does + mock_exists.return_value = False + mock_listdir.return_value = ["some_parameters.json", "other_file.txt"] + assert sim_fxns._extract_parameters_path() == "some_parameters.json" + + # Test when no appropriate file exists + mock_listdir.return_value = ["other_file.json", "other_file.txt"] + with pytest.raises(ValueError) as e: + sim_fxns._extract_parameters_path() + assert str(e.value) == "No parameters.json file found in directory." + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data='{"param1": "value1", "param2": "value2"}', +) +@patch("json.load") +def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns): + # Define the mock behavior for json.load + mock_json_load.return_value = {"param1": "value1", "param2": "value2"} + params = sim_fxns._setup_simulation_from_json("test_file.json") + mock_file_open.assert_called_once_with("test_file.json", "r") + mock_json_load.assert_called_once() + assert params == {"param1": "value1", "param2": "value2"} + + +def test_getpdb(fibronectin, get_registry): + name, _ = get_pdb(fibronectin, get_registry) + assert name.endswith(".pdb") + + +@pytest.fixture +def path_registry(): + registry = PathRegistry() + registry.get_timestamp = lambda: "20240109" + return registry + + +def test_write_to_file(): + path_registry = PathRegistry() + + with patch("builtins.open", mock_open()): + file_name = path_registry.write_file_name( + FileType.PROTEIN, + protein_name="1XYZ", + description="testing", + file_format="pdb", + ) + # assert file name starts and ends correctly + assert file_name.startswith("1XYZ") + assert file_name.endswith(".pdb") + + +def test_write_file_name_protein(path_registry): + file_name = path_registry.write_file_name( + FileType.PROTEIN, protein_name="1XYZ", description="testing", file_format="pdb" + ) + assert file_name == "1XYZ_testing_20240109.pdb" + + +def test_write_file_name_simulation_with_conditions(path_registry): + file_name = path_registry.write_file_name( + FileType.SIMULATION, + type_of_sim="MD", + protein_file_id="1XYZ", + conditions="pH7", + time_stamp="20240109", + ) + assert file_name == "MD_1XYZ_pH7_20240109.py" + + +def test_write_file_name_simulation_modified(path_registry): + file_name = path_registry.write_file_name( + FileType.SIMULATION, Sim_id="SIM456", modified=True, time_stamp="20240109" + ) + assert file_name == "SIM456_MOD_20240109.py" + + +def test_write_file_name_simulation_default(path_registry): + file_name = path_registry.write_file_name( + FileType.SIMULATION, + type_of_sim="MD", + protein_file_id="123", + time_stamp="20240109", + ) + assert file_name == "MD_123_20240109.py" + + +def test_write_file_name_record(path_registry): + file_name = path_registry.write_file_name( + FileType.RECORD, + record_type="REC", + protein_file_id="123", + Sim_id="SIM456", + term="dcd", + time_stamp="20240109", + ) + assert file_name == "REC_SIM456_123_20240109.dcd" + + +def test_map_path(): + mock_json_data = { + "existing_name": { + "path": "existing/path", + "name": "path", + "description": "Existing description", + } + } + new_path_dict = { + "new_name": { + "path": "new/path", + "name": "path", + "description": "New description", + } + } + updated_json_data = {**mock_json_data, **new_path_dict} + + path_registry = PathRegistry() + path_registry.json_file_path = "dummy_json_file.json" + + # Mocking os.path.exists to simulate the JSON file existence + with patch("os.path.exists", return_value=True): + # Mocking open for both reading and writing the JSON file + with patch( + "builtins.open", mock_open(read_data=json.dumps(mock_json_data)) + ) as mocked_file: + # Optionally, you can mock internal methods if needed + with patch.object( + path_registry, "_check_for_json", return_value=True + ), patch.object( + path_registry, "_check_json_content", return_value=True + ), patch.object( + path_registry, "_get_full_path", return_value="new/path" + ): # Mocking _get_full_path + result = path_registry.map_path( + "new_name", "new/path", "New description" + ) + # Aggregating all calls to write into a single string + written_data = "".join( + call.args[0] for call in mocked_file().write.call_args_list + ) + + # Comparing the aggregated data with the expected JSON data + assert json.loads(written_data) == updated_json_data + + # Check the result message + assert result == "Path successfully mapped to name: new_name" + + +def test_small_molecule_pdb(molpdb, get_registry): + # Test with a valid SMILES string + valid_smiles = "C1=CC=CC=C1" # Benzene + expected_output = ( + "PDB file for C1=CC=CC=C1 successfully created and saved to " + "files/pdb/benzene.pdb." + ) + assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output + assert os.path.exists("files/pdb/benzene.pdb") + os.remove("files/pdb/benzene.pdb") # Clean up + + # test with invalid SMILES string and invalid molecule name + invalid_smiles = "C1=CC=CC=C1X" + invalid_name = "NotAMolecule" + expected_output = ( + "There was an error getting pdb. Please input a single molecule name." + ) + assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output + assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output + + # test with valid molecule name + valid_name = "water" + expected_output = ( + "PDB file for water successfully created and " "saved to files/pdb/water.pdb." + ) + assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output + assert os.path.exists("files/pdb/water.pdb") + os.remove("files/pdb/water.pdb") # Clean up + + +def test_packmol_sm_download_called(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + path_registry._remove_path_from_json("benzene") + path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") + with patch( + "mdagent.tools.base_tools.preprocess_tools.pdb_tools.PackMolTool._get_sm_pdbs", + new=MagicMock(), + ) as mock_get_sm_pdbs: + test_values = { + "pdbfiles_id": ["1A3N_144150"], + "small_molecules": ["water", "benzene"], + "number_of_molecules": [1, 10, 10], + "instructions": [ + ["inside box 0. 0. 0. 100. 100. 100."], + ["inside box 0. 0. 0. 100. 100. 100."], + ["inside box 0. 0. 0. 100. 100. 100."], + ], + } + + packmol._run(**test_values) + + mock_get_sm_pdbs.assert_called_with(["water", "benzene"]) + + +def test_packmol_download_only(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + path_registry._remove_path_from_json("benzene") + small_molecules = ["water", "benzene"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("files/pdb/water.pdb") + assert os.path.exists("files/pdb/benzene.pdb") + os.remove("files/pdb/water.pdb") + os.remove("files/pdb/benzene.pdb") + + +def test_packmol_download_only_once(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + small_molecules = ["water"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("files/pdb/water.pdb") + water_time = os.path.getmtime("files/pdb/water.pdb") + time.sleep(5) + + # Call the function again with the same molecule + packmol._get_sm_pdbs(small_molecules) + water_time_after = os.path.getmtime("files/pdb/water.pdb") + + assert water_time == water_time_after + # Clean up + os.remove("files/pdb/water.pdb") + + +mocked_files = {"files/solvents": ["water.pdb"]} + + +def mock_exists(path): + return path in mocked_files + + +def mock_listdir(path): + return mocked_files.get(path, []) + + +@pytest.fixture +def path_registry_with_mocked_fs(): + with patch("os.path.exists", side_effect=mock_exists): + with patch("os.listdir", side_effect=mock_listdir): + registry = PathRegistry() + registry.get_timestamp = lambda: "20240109" + return registry + + +def test_init_path_registry(path_registry_with_mocked_fs): + # This test will run with the mocked file system + # Here, you can assert if 'water.pdb' under 'solvents' is registered correctly + # Depending on how your PathRegistry class stores the registry, + # you may need to check the internal state or the contents of the JSON file. + # For example: + assert "water_000000" in path_registry_with_mocked_fs.list_path_names() + + +@pytest.fixture +def questions(): + qs = [ + "What are the effects of norhalichondrin B in mammals?", + ] + return qs[0] + + +@pytest.mark.skip(reason="This requires an API call") +def test_litsearch(questions): + llm = ChatOpenAI() + + searchtool = Scholar2ResultLLM(llm=llm) + for q in questions: + ans = searchtool._run(q) + assert isinstance(ans, str) + assert len(ans) > 0 + if os.path.exists("../query"): + os.rmdir("../query") diff --git a/tests/test_sims_and_clean.py b/tests/test_sims_and_clean.py deleted file mode 100644 index 15d5c29c..00000000 --- a/tests/test_sims_and_clean.py +++ /dev/null @@ -1,188 +0,0 @@ -import os -import time -import warnings -from unittest.mock import MagicMock, mock_open, patch - -import pytest - -from mdagent.tools.base_tools import CleaningTools, SimulationFunctions -from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool -from mdagent.utils import PathRegistry, find_file_path - -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") - - -@pytest.fixture -def path_to_cif(): - # Save original working directory - original_cwd = os.getcwd() - - # Change current working directory to the directory where the CIF file is located - tests_dir = os.path.dirname(os.path.abspath(__file__)) - os.chdir(tests_dir) - - # Yield the filename only - filename_only = "3pqr.cif" - yield filename_only - - # Restore original working directory after the test is done - os.chdir(original_cwd) - - -@pytest.fixture -def cleaning_fxns(): - return CleaningTools() - - -@pytest.fixture -def molpdb(): - return MolPDB() - - -# Test simulation tools -@pytest.fixture -def sim_fxns(): - return SimulationFunctions() - - -@pytest.fixture -def get_registry(): - return PathRegistry() - - -@pytest.fixture -def packmol(get_registry): - return PackMolTool(get_registry) - - -def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry): - result = cleaning_fxns._add_hydrogens_and_remove_water(path_to_cif, get_registry) - assert "Cleaned File" in result # just want to make sur the function ran - - -@patch("os.path.exists") -@patch("os.listdir") -def test_extract_parameters_path(mock_listdir, mock_exists, sim_fxns): - # Test when parameters.json exists - mock_exists.return_value = True - assert sim_fxns._extract_parameters_path() == "simulation_parameters_summary.json" - mock_exists.assert_called_once_with("simulation_parameters_summary.json") - mock_exists.reset_mock() # Reset the mock for the next scenario - - # Test when parameters.json does not exist, but some_parameters.json does - mock_exists.return_value = False - mock_listdir.return_value = ["some_parameters.json", "other_file.txt"] - assert sim_fxns._extract_parameters_path() == "some_parameters.json" - - # Test when no appropriate file exists - mock_listdir.return_value = ["other_file.json", "other_file.txt"] - with pytest.raises(ValueError) as e: - sim_fxns._extract_parameters_path() - assert str(e.value) == "No parameters.json file found in directory." - - -@patch( - "builtins.open", - new_callable=mock_open, - read_data='{"param1": "value1", "param2": "value2"}', -) -@patch("json.load") -def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns): - # Define the mock behavior for json.load - mock_json_load.return_value = {"param1": "value1", "param2": "value2"} - params = sim_fxns._setup_simulation_from_json("test_file.json") - mock_file_open.assert_called_once_with("test_file.json", "r") - mock_json_load.assert_called_once() - assert params == {"param1": "value1", "param2": "value2"} - - -def test_small_molecule_pdb(molpdb, get_registry): - # Test with a valid SMILES string - valid_smiles = "C1=CC=CC=C1" # Benzene - expected_output_success = "successfully created and saved to " - assert expected_output_success in molpdb.small_molecule_pdb( - valid_smiles, get_registry - ) - file_path = find_file_path("benzene", exact_match=False) - assert file_path is not None # assert file was found - os.remove(file_path) # Clean up - - # test with invalid SMILES string and invalid molecule name - invalid_smiles = "C1=CC=CC=C1X" - invalid_name = "NotAMolecule" - expected_output = ( - "There was an error getting pdb. Please input a single molecule name." - ) - assert expected_output in molpdb.small_molecule_pdb(invalid_smiles, get_registry) - assert expected_output in molpdb.small_molecule_pdb(invalid_name, get_registry) - - # test with valid molecule name - valid_name = "water" - assert expected_output_success in molpdb.small_molecule_pdb( - valid_name, get_registry - ) - file_path = find_file_path("water", exact_match=False) - assert file_path is not None # assert file was found - os.remove(file_path) # Clean up - - -def test_packmol_sm_download_called(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - path_registry._remove_path_from_json("benzene") - path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") - with patch( - "mdagent.tools.base_tools.preprocess_tools.pdb_tools.PackMolTool._get_sm_pdbs", - new=MagicMock(), - ) as mock_get_sm_pdbs: - test_values = { - "pdbfiles_id": ["1A3N_144150"], - "small_molecules": ["water", "benzene"], - "number_of_molecules": [1, 10, 10], - "instructions": [ - ["inside box 0. 0. 0. 100. 100. 100."], - ["inside box 0. 0. 0. 100. 100. 100."], - ["inside box 0. 0. 0. 100. 100. 100."], - ], - } - - packmol._run(**test_values) - - mock_get_sm_pdbs.assert_called_with(["water", "benzene"]) - - -@pytest.mark.skip(reason="Resume this test when ckpt is implemented") -def test_packmol_download_only(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - path_registry._remove_path_from_json("benzene") - small_molecules = ["water", "benzene"] - packmol._get_sm_pdbs(small_molecules) - - water_path = find_file_path("water", exact_match=False) - assert water_path is not None - os.remove(water_path) - - benzene_path = find_file_path("benzene", exact_match=False) - assert benzene_path is not None - os.remove(benzene_path) - - -@pytest.mark.skip(reason="Resume this test when ckpt is implemented") -def test_packmol_download_only_once(packmol): - path_registry = PathRegistry() - path_registry._remove_path_from_json("water") - small_molecules = ["water"] - packmol._get_sm_pdbs(small_molecules) - path_name = find_file_path("water", exact_match=False) - assert path_name is not None - water_time = os.path.getmtime(path_name) - time.sleep(5) - - # Call the function again with the same molecule - packmol._get_sm_pdbs(small_molecules) - water_time_after = os.path.getmtime(path_name) - - assert water_time == water_time_after - # Clean up - os.remove(path_name) diff --git a/tests/test_tools.py b/tests/test_tools.py deleted file mode 100644 index 04ed518f..00000000 --- a/tests/test_tools.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import warnings -from unittest.mock import MagicMock, mock_open, patch - -import pytest -from langchain.chat_models import ChatOpenAI - -from mdagent.tools.base_tools import Scholar2ResultLLM, VisFunctions, get_pdb -from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv -from mdagent.utils import PathRegistry - -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") - - -@pytest.fixture -def path_to_cif(): - # Save original working directory - original_cwd = os.getcwd() - - # Change current working directory to the directory where the CIF file is located - tests_dir = os.path.dirname(os.path.abspath(__file__)) - os.chdir(tests_dir) - - # Yield the filename only - filename_only = "3pqr.cif" - yield filename_only - - # Restore original working directory after the test is done - os.chdir(original_cwd) - - -# Test visualization tools -@pytest.fixture -def vis_fxns(): - return VisFunctions() - - -# Test MD utility tools -@pytest.fixture -def fibronectin(): - return "fibronectin pdb" - - -@pytest.fixture -def get_registry(): - return PathRegistry() - - -def test_process_csv(): - mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" - mock_reader = MagicMock() - mock_reader.fieldnames = ["Time", "Value1", "Value2"] - mock_reader.__iter__.return_value = iter( - [ - {"Time": "1", "Value1": "10", "Value2": "20"}, - {"Time": "2", "Value1": "15", "Value2": "25"}, - ] - ) - - with patch("builtins.open", mock_open(read_data=mock_csv_content)): - with patch("csv.DictReader", return_value=mock_reader): - data, headers, matched_headers = process_csv("mock_file.csv") - - assert headers == ["Time", "Value1", "Value2"] - assert len(matched_headers) == 1 - assert matched_headers[0][1] == "Time" - assert len(data) == 2 - assert data[0]["Time"] == "1" and data[0]["Value1"] == "10" - - -def test_plot_data(): - # Test successful plot generation - data_success = [ - {"Time": "1", "Value1": "10", "Value2": "20"}, - {"Time": "2", "Value1": "15", "Value2": "25"}, - ] - headers = ["Time", "Value1", "Value2"] - matched_headers = [(0, "Time")] - - with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch( - "matplotlib.pyplot.xlabel" - ), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch( - "matplotlib.pyplot.savefig" - ), patch( - "matplotlib.pyplot.close" - ): - created_plots = plot_data(data_success, headers, matched_headers) - assert "time_vs_value1.png" in created_plots - assert "time_vs_value2.png" in created_plots - - # Test failure due to non-numeric data - data_failure = [ - {"Time": "1", "Value1": "A", "Value2": "B"}, - {"Time": "2", "Value1": "C", "Value2": "D"}, - ] - - with pytest.raises(Exception) as excinfo: - plot_data(data_failure, headers, matched_headers) - assert "All plots failed due to non-numeric data." in str(excinfo.value) - - -@pytest.mark.skip(reason="molrender is not pip installable") -def test_run_molrender(path_to_cif, vis_fxns): - result = vis_fxns.run_molrender(path_to_cif) - assert result == "Visualization created" - - -def test_create_notebook(path_to_cif, vis_fxns, get_registry): - result = vis_fxns.create_notebook(path_to_cif, get_registry) - assert result == "Visualization Complete" - - -def test_getpdb(fibronectin, get_registry): - name, _ = get_pdb(fibronectin, get_registry) - assert name.endswith(".pdb") - - -@pytest.fixture -def questions(): - qs = [ - "What are the effects of norhalichondrin B in mammals?", - ] - return qs[0] - - -@pytest.mark.skip(reason="This requires an API call") -def test_litsearch(questions): - llm = ChatOpenAI() - - searchtool = Scholar2ResultLLM(llm=llm) - for q in questions: - ans = searchtool._run(q) - assert isinstance(ans, str) - assert len(ans) > 0 - if os.path.exists("../query"): - os.rmdir("../query") diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 3bf37ebd..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,170 +0,0 @@ -import json -import os -import warnings -from unittest.mock import mock_open, patch - -import pytest - -from mdagent.utils import FileType, PathRegistry, find_file_path - -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") - - -@pytest.fixture -def path_registry(): - registry = PathRegistry() - registry.get_timestamp = lambda: "20240109" - return registry - - -def test_write_to_file(): - path_registry = PathRegistry() - - with patch("builtins.open", mock_open()): - file_name = path_registry.write_file_name( - FileType.PROTEIN, - protein_name="1XYZ", - description="testing", - file_format="pdb", - ) - # assert file name starts and ends correctly - assert file_name.startswith("1XYZ") - assert file_name.endswith(".pdb") - - -def test_write_file_name_protein(path_registry): - file_name = path_registry.write_file_name( - FileType.PROTEIN, protein_name="1XYZ", description="testing", file_format="pdb" - ) - assert file_name == "1XYZ_testing_20240109.pdb" - - -def test_write_file_name_simulation_with_conditions(path_registry): - file_name = path_registry.write_file_name( - FileType.SIMULATION, - type_of_sim="MD", - protein_file_id="1XYZ", - conditions="pH7", - time_stamp="20240109", - ) - assert file_name == "MD_1XYZ_pH7_20240109.py" - - -def test_write_file_name_simulation_modified(path_registry): - file_name = path_registry.write_file_name( - FileType.SIMULATION, Sim_id="SIM456", modified=True, time_stamp="20240109" - ) - assert file_name == "SIM456_MOD_20240109.py" - - -def test_write_file_name_simulation_default(path_registry): - file_name = path_registry.write_file_name( - FileType.SIMULATION, - type_of_sim="MD", - protein_file_id="123", - time_stamp="20240109", - ) - assert file_name == "MD_123_20240109.py" - - -def test_write_file_name_record(path_registry): - file_name = path_registry.write_file_name( - FileType.RECORD, - record_type="REC", - protein_file_id="123", - Sim_id="SIM456", - term="dcd", - time_stamp="20240109", - ) - assert file_name == "REC_SIM456_123_20240109.dcd" - - -def test_map_path(): - mock_json_data = { - "existing_name": { - "path": "existing/path", - "name": "path", - "description": "Existing description", - } - } - new_path_dict = { - "new_name": { - "path": "new/path", - "name": "path", - "description": "New description", - } - } - updated_json_data = {**mock_json_data, **new_path_dict} - - path_registry = PathRegistry() - path_registry.json_file_path = "dummy_json_file.json" - - # Mocking os.path.exists to simulate the JSON file existence - with patch("os.path.exists", return_value=True): - # Mocking open for both reading and writing the JSON file - with patch( - "builtins.open", mock_open(read_data=json.dumps(mock_json_data)) - ) as mocked_file: - # Optionally, you can mock internal methods if needed - with patch.object( - path_registry, "_check_for_json", return_value=True - ), patch.object( - path_registry, "_check_json_content", return_value=True - ), patch.object( - path_registry, "_get_full_path", return_value="new/path" - ): # Mocking _get_full_path - result = path_registry.map_path( - "new_name", "new/path", "New description" - ) - # Aggregating all calls to write into a single string - written_data = "".join( - call.args[0] for call in mocked_file().write.call_args_list - ) - - # Comparing the aggregated data with the expected JSON data - assert json.loads(written_data) == updated_json_data - - # Check the result message - assert result == "Path successfully mapped to name: new_name" - - -mocked_files = {"files/solvents": ["water.pdb"]} - - -def mock_exists(path): - return path in mocked_files - - -def mock_listdir(path): - return mocked_files.get(path, []) - - -@pytest.fixture -def path_registry_with_mocked_fs(): - with patch("os.path.exists", side_effect=mock_exists): - with patch("os.listdir", side_effect=mock_listdir): - registry = PathRegistry() - registry.get_timestamp = lambda: "20240109" - return registry - - -def test_init_path_registry(path_registry_with_mocked_fs): - # This test will run with the mocked file system - # Here, you can assert if 'water.pdb' under 'solvents' is registered correctly - # Depending on how your PathRegistry class stores the registry, - # you may need to check the internal state or the contents of the JSON file. - # For example: - assert "water_000000" in path_registry_with_mocked_fs.list_path_names() - - -def test_find_file_path(): - file_name = "test_utils.py" - # current directory won't include folders - cwd = os.getcwd() - file_path_current = os.path.join(cwd, "tests", file_name) - file_path_test = find_file_path(file_name, exact_match=True) - assert file_path_current == file_path_test - - file_name_short = "test_util" - file_path_test_short = find_file_path(file_name_short, exact_match=False) - assert file_path_current == file_path_test_short From 998e8564bda2ff7af29ed6723e508af354d648b9 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Tue, 27 Feb 2024 10:06:15 -0800 Subject: [PATCH 12/12] removed unused function --- mdagent/utils/__init__.py | 3 +-- mdagent/utils/general_utils.py | 24 ------------------------ 2 files changed, 1 insertion(+), 26 deletions(-) delete mode 100644 mdagent/utils/general_utils.py diff --git a/mdagent/utils/__init__.py b/mdagent/utils/__init__.py index ad59b1e4..ef0fa47b 100644 --- a/mdagent/utils/__init__.py +++ b/mdagent/utils/__init__.py @@ -1,5 +1,4 @@ -from .general_utils import find_file_path from .makellm import _make_llm from .path_registry import FileType, PathRegistry -__all__ = ["_make_llm", "PathRegistry", "FileType", "find_file_path"] +__all__ = ["_make_llm", "PathRegistry", "FileType"] diff --git a/mdagent/utils/general_utils.py b/mdagent/utils/general_utils.py deleted file mode 100644 index 6a1a6652..00000000 --- a/mdagent/utils/general_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -import os - - -def find_file_path(file_name: str, exact_match: bool = True): - """get the path of a file, if it exists in repo""" - setup_dir = None - for dirpath, dirnames, filenames in os.walk("."): - if "setup.py" in filenames: - setup_dir = dirpath - break - - if setup_dir is None: - raise FileNotFoundError("Unable to find root directory.") - - for dirpath, dirnames, filenames in os.walk(setup_dir): - for filename in filenames: - if (exact_match and filename == file_name) or ( - not exact_match and file_name in filename - ): - path_full = os.path.join(dirpath, filename) - # make sure its absolute - return os.path.abspath(path_full) - - return None