From 4874f37ac00e4ffa7169d0a2e37acf2c72667861 Mon Sep 17 00:00:00 2001 From: Jorge Date: Thu, 29 Feb 2024 23:15:08 -0500 Subject: [PATCH] Adding tests for Setupand Run OpenMMSimulation class. 1 Split Pathregistry, setupandrunfunctions and conflists tests --- .../preprocess_tools/clean_tools.py | 2 +- .../base_tools/preprocess_tools/pdb_fix.py | 2 +- tests/conftest.py | 99 +++++++++++++++ tests/test_pathregistry.py | 42 ++++++ tests/test_setupandrun.py | 120 ++++++++++++++++++ 5 files changed, 263 insertions(+), 2 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_pathregistry.py create mode 100644 tests/test_setupandrun.py diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py index 589a4294..d664dd64 100644 --- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py @@ -227,7 +227,7 @@ class CleaningToolFunctionInput(BaseModel): ) add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.") - @root_validator + @root_validator(skip_on_failure=True) def validate_query(cls, values) -> Dict: """Check that the input is valid.""" diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py index 4cef4ef0..b63af3f3 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py @@ -660,7 +660,7 @@ class PDBFilesFixInp(BaseModel): ), ) - @root_validator + @root_validator(skip_on_failure=True) def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: if isinstance(values, str): print("values is a string", values) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..7cbb6c94 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,99 @@ +import os +import shutil +from pathlib import Path + +import pytest + +from mdagent.utils import PathRegistry + + +@pytest.fixture(scope="module") +def raw_alanine_pdb_file(request): + pdb_content = """ + ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 20.00 N + ATOM 2 CA ALA A 1 1.458 0.000 0.000 1.00 20.00 C + ATOM 3 C ALA A 1 2.175 1.395 0.000 1.00 20.00 C + ATOM 4 O ALA A 1 1.461 2.400 0.000 1.00 20.00 O + ATOM 5 CB ALA A 1 1.958 -0.735 -1.231 1.00 20.00 C + TER + END + """.strip() + with open("ALA_raw_123456.pdb", "w") as f: + f.write(pdb_content) + yield "ALA_raw_123456.pdb" + + request.addfinalizer(lambda: os.remove("ALA_raw_123456.pdb")) + + +@pytest.fixture(scope="module") +def clean_alanine_pdb_file(request): + pdb_content = """ +REMARK ACE +CRYST1 32.155 32.155 56.863 90.00 90.00 120.00 P 31 2 1 6 +ATOM 1 1HH3 ACE 1 2.000 1.000 -0.000 +ATOM 2 CH3 ACE 1 2.000 2.090 0.000 +ATOM 3 2HH3 ACE 1 1.486 2.454 0.890 +ATOM 4 3HH3 ACE 1 1.486 2.454 -0.890 +ATOM 5 C ACE 1 3.427 2.641 -0.000 +ATOM 6 O ACE 1 4.391 1.877 -0.000 +ATOM 7 N ALA 2 3.555 3.970 -0.000 +ATOM 8 H ALA 2 2.733 4.556 -0.000 +ATOM 9 CA ALA 2 4.853 4.614 -0.000 +ATOM 10 HA ALA 2 5.408 4.316 0.890 +ATOM 11 CB ALA 2 5.661 4.221 -1.232 +ATOM 12 1HB ALA 2 5.123 4.521 -2.131 +ATOM 13 2HB ALA 2 6.630 4.719 -1.206 +ATOM 14 3HB ALA 2 5.809 3.141 -1.241 +ATOM 15 C ALA 2 4.713 6.129 0.000 +ATOM 16 O ALA 2 3.601 6.653 0.000 +ATOM 17 N NME 3 5.846 6.835 0.000 +ATOM 18 H NME 3 6.737 6.359 -0.000 +ATOM 19 CH3 NME 3 5.846 8.284 0.000 +ATOM 20 1HH3 NME 3 4.819 8.648 0.000 +ATOM 21 2HH3 NME 3 6.360 8.648 0.890 +ATOM 22 3HH3 NME 3 6.360 8.648 -0.890 +TER +END + """ + with open("ALA_clean_654321.pdb", "w") as f: + f.write(pdb_content) + + yield "ALA_clean_654321.pdb" + + request.addfinalizer(lambda: os.remove("ALA_clean_654321.pdb")) + + +@pytest.fixture(scope="function") +def get_registry(raw_alanine_pdb_file, clean_alanine_pdb_file, request): + created_paths = [] # Keep track of created directories for cleanup + + def create(raw_or_clean, with_files): + base_path = "files" + if with_files: + pdb_path = Path(base_path) / "pdb" + record_path = Path(base_path) / "records" + simulation_path = Path(base_path) / "simulation" + + # Create directories + for path in [pdb_path, record_path, simulation_path]: + os.makedirs(path, exist_ok=True) + created_paths.append(path) + if raw_or_clean == "raw": + # Copy the alanine pdb file to the pdb/alanine directory + shutil.copy(raw_alanine_pdb_file, pdb_path) + elif raw_or_clean == "clean": + shutil.copy(clean_alanine_pdb_file, pdb_path) + + # Assuming PathRegistry is defined elsewhere and properly implemented + return PathRegistry() + + # Cleanup: Remove created directories and the copied pdb file + def cleanup(): + for path in reversed(created_paths): # Remove directories + shutil.rmtree(path, ignore_errors=True) + if os.path.exists("path_registry.json"): + os.remove("path_registry.json") + + request.addfinalizer(cleanup) + + return create diff --git a/tests/test_pathregistry.py b/tests/test_pathregistry.py new file mode 100644 index 00000000..d04f6fdd --- /dev/null +++ b/tests/test_pathregistry.py @@ -0,0 +1,42 @@ +import os + +import pytest + + +@pytest.mark.parametrize("with_files, raw_or_clean", [(False, "raw"), (True, "raw")]) +def test_registry_init(get_registry, with_files, raw_or_clean): + # make the test directory the cwd + # print(os.curdir) + # if os.curdir.split("/")[-1] != "tests": + # os.chdir("tests") + registry_without_files = get_registry(raw_or_clean, with_files) + print(with_files, raw_or_clean) + if not with_files: + assert registry_without_files._load_existing_registry() == {} + else: + if raw_or_clean == "raw": + absolute_path = os.path.abspath("files/pdb/ALA_raw_123456.pdb") + expected_json = { + "ALA_123456": { + "path": f"{absolute_path}", + "name": "ALA_raw_123456.pdb", + "description": ( + "Protein ALA pdb file. " + "downloaded from RCSB Protein Data Bank. " + ), + } + } + assert registry_without_files._load_existing_registry() == expected_json + elif raw_or_clean == "clean": + absolute_path = os.path.abspath("files/pdb/ALA_clean_654321.pdb") + expected_json = { + "ALA_654321": { + "path": f"{absolute_path}", + "name": "ALA_clean_654321.pdb", + "description": ( + "Protein ALA pdb file. " + "downloaded from RCSB Protein Data Bank. " + ), + } + } + assert registry_without_files._load_existing_registry() == expected_json diff --git a/tests/test_setupandrun.py b/tests/test_setupandrun.py new file mode 100644 index 00000000..53ba5d73 --- /dev/null +++ b/tests/test_setupandrun.py @@ -0,0 +1,120 @@ +import json +import os + +import pytest + +from mdagent.tools.base_tools.simulation_tools.setup_and_run import ( + OpenMMSimulation, + SetUpandRunFunction, +) + + +@pytest.fixture(scope="module") +def raw(): + return "raw" + + +@pytest.fixture(scope="module") +def clean(): + return "clean" + + +@pytest.fixture(scope="module") +def string_input(): + def create_input(raw_or_clean): + if raw_or_clean == "raw": + pdb_id = "ALA_123456" + elif raw_or_clean == "clean": + pdb_id = "ALA_654321" + return """ + {{ + "pdb_id": "{pdb_id}", + "forcefield_files": ["amber14-all.xml", "amber14/tip3pfb.xml"], + "save": true, + "system_params":{{ + "nonbondedMethod": "PME", + "nonbondedCutoff": "1 * nanometers", + "ewaldErrorTolerance": 0.0005, + "constraints": "HBonds", + "rigidWater": true, + "constraintTolerance": 0.00001, + "solvate": false + }}, + "integrator_params":{{ + "integrator_type": "LangevinMiddle", + "Temperature": "300 * kelvin", + "Friction": "1 / picosecond", + "Timestep": "2 * femtoseconds" + }}, + "simmulation_params": {{ + "Ensemble": "NVT", + "Number of Steps": 5000, + "record_interval_steps": 50, + "record_params": ["step", "potentialEnergy", "temperature"] + }} + }} + """.format( + pdb_id=pdb_id + ).strip() + + return create_input + + +# @pytest.fixture(scope="module") +# def get_tool_input(string_input,type): +# inp = string_input(type) +# return SetUpandRunFunctionInput(**json.loads(inp)) + + +def test_init_SetUpandRunFunction(get_registry): + """Test the SetUpandRunFunction class initialization.""" + registry = get_registry("raw", False) + tool = SetUpandRunFunction(path_registry=registry) + assert tool.name == "SetUpandRunFunction" + assert tool.path_registry == registry + + +def test_check_system_params(get_registry, string_input, raw, clean): + """Test the check_system_params method of the SetUpandRunFunction class.""" + + registry = get_registry(raw, False) + tool = SetUpandRunFunction(path_registry=registry) + final_values_1 = tool.check_system_params(json.loads(string_input(raw))) + assert final_values_1.get("error") is None + final_values_2 = tool.check_system_params(json.loads(string_input(clean))) + assert final_values_2.get("error") is None + + +def test_openmmsimulation_init(get_registry, string_input, raw, clean): + """Test the OpenMMSimulation class initialization.""" + # assert an openmmexception is raised + with pytest.raises(ValueError): + registry = get_registry(raw, True) + tool_input = json.loads(string_input(raw)) + print(tool_input) + Simulation = OpenMMSimulation( + input_params=tool_input, + path_registry=registry, + save=tool_input["save"], + sim_id="sim_123456", + pdb_id=tool_input["pdb_id"], + ) + + registry = get_registry(clean, True) + tool_input = json.loads(string_input(clean)) + inputs = SetUpandRunFunction(path_registry=registry).check_system_params(tool_input) + print(tool_input) + path_of_file = registry.get_mapped_path(tool_input["pdb_id"]) + Simulation = OpenMMSimulation( + input_params=inputs, + path_registry=registry, + save=tool_input["save"], + sim_id="sim_654321", + pdb_id=tool_input["pdb_id"], + ) + assert Simulation.pdb_path == path_of_file + + # remove files that start with LOG, TOP, and TRAJ + for file in os.listdir("."): + if file.startswith("LOG") or file.startswith("TOP") or file.startswith("TRAJ"): + os.remove(f"{file}")