Skip to content

Commit

Permalink
Preprocess tests (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 authored Mar 18, 2024
1 parent b359793 commit 656d7e2
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 4 deletions.
17 changes: 13 additions & 4 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,8 @@ def _parse_cutoff(self, cutoff):

# Convert to string in case it's not (e.g., int or float)
cutoff = str(cutoff)
if cutoff[-1] == "s":
cutoff = cutoff[:-1]

# Remove spaces and convert to lowercase for easier parsing
cutoff = cutoff.replace(" ", "").lower()
Expand Down Expand Up @@ -1388,12 +1390,20 @@ def _parse_parameter(self, parameter, default_unit, possible_units):
if "*" in parameter_str:
num_part, unit_part = parameter_str.split("*")
num_value = float(num_part)
elif "poundforce/inch^2" in parameter_str:
num_value = float(parameter_str.replace("poundforce/inch^2", ""))
unit_part = "poundforce/inch^2"
# Check for division symbol and split if necessary
# e.g. "1/ps" or "1/ps^-1"
elif "/" in parameter_str:
num_part, unit_part = parameter_str.split("/")
num_value = float(num_part)
unit_part = "/" + unit_part
elif "^-1" in parameter_str:
parameter_str = parameter_str.replace("^-1", "")
match = re.match(r"^(\d+(?:\.\d+)?)([a-zA-Z]+)$", parameter_str)
num_value = float(match.group(1))
unit_part = "/" + match.group(2)
else:
# Attempt to convert directly to float; if it fails,
# it must have a unit like "K", "ps", etc.
Expand All @@ -1409,16 +1419,15 @@ def _parse_parameter(self, parameter, default_unit, possible_units):
error_msg += f"Invalid format for parameter: '{parameter_str}'."

# Convert the unit part to an OpenMM unit
if unit_part in possible_units:
return num_value * possible_units[unit_part], error_msg
if unit_part.lower() in possible_units:
return num_value * possible_units[unit_part.lower()], error_msg
else:
# If the unit is not recognized, raise an error
error_msg += f"""Unknown unit '{unit_part}' for parameter.
Valid units include: {list(possible_units.keys())}."""

return parameter, error_msg

# Example method to use _parse_parameter for specific parameter
def parse_temperature(self, temperature):
possible_units = {
"k": unit.kelvin,
Expand Down Expand Up @@ -1460,7 +1469,7 @@ def parse_pressure(self, pressure):
"atmosphere": unit.atmospheres,
"pascal": unit.pascals,
"pascals": unit.pascals,
"Pa": unit.pascals,
"pa": unit.pascals,
"poundforce/inch^2": unit.psi,
"psi": unit.psi,
}
Expand Down
121 changes: 121 additions & 0 deletions tests/test_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
from openmm import unit
from openmm.app import PME, HBonds

from mdagent.tools.base_tools.simulation_tools import SetUpandRunFunction
from mdagent.utils import PathRegistry


@pytest.fixture
def get_registry():
return PathRegistry()


@pytest.fixture
def setupandrun(get_registry):
return SetUpandRunFunction(get_registry)


@pytest.mark.parametrize(
"input_cutoff, expected_result",
[
(unit.Quantity(1.0, unit.nanometers), unit.Quantity(1.0, unit.nanometers)),
(3.0, unit.Quantity(3.0, unit.nanometers)),
("2angstroms", unit.Quantity(2.0, unit.angstroms)),
],
)
def test_parse_cutoff(setupandrun, input_cutoff, expected_result):
result = setupandrun._parse_cutoff(input_cutoff)
assert expected_result == result


def test_parse_cutoff_unknown_unit(setupandrun):
with pytest.raises(ValueError) as e:
setupandrun._parse_cutoff("2pc")
assert "Unknown unit" in str(e.value)


def test_parse_temperature(setupandrun):
result = setupandrun.parse_temperature("300k")
result2 = setupandrun.parse_temperature("300kelvin")
expected_result = unit.Quantity(300, unit.kelvin)
assert expected_result == result[0] == result2[0]


@pytest.mark.parametrize(
"input_friction, expected_friction_result",
[
("1/ps", unit.Quantity(1, 1 / unit.picoseconds)),
("1/picosecond", unit.Quantity(1, 1 / unit.picosecond)),
("1/picoseconds", unit.Quantity(1, 1 / unit.picosecond)),
("1picosecond^-1", unit.Quantity(1, 1 / unit.picosecond)),
("1picoseconds^-1", unit.Quantity(1, 1 / unit.picoseconds)),
("1/ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
("1ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
("1*ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
],
)
def test_parse_friction(setupandrun, input_friction, expected_friction_result):
result = setupandrun.parse_friction(input_friction)
assert (
expected_friction_result == result[0]
), f"Expected {expected_friction_result} for {input_friction}, got {result[0]}"


@pytest.mark.parametrize(
"input_time, expected_time_unit",
[
("1ps", unit.picoseconds),
("1picosecond", unit.picoseconds),
("1picoseconds", unit.picoseconds),
("1fs", unit.femtoseconds),
("1femtosecond", unit.femtoseconds),
("1femtoseconds", unit.femtoseconds),
("1ns", unit.nanoseconds),
("1nanosecond", unit.nanoseconds),
("1nanoseconds", unit.nanoseconds),
],
)
def test_parse_time(setupandrun, input_time, expected_time_unit):
result = setupandrun.parse_timestep(input_time)
expected_result = unit.Quantity(1, expected_time_unit)
assert expected_result == result[0]


@pytest.mark.parametrize(
"input_pressure, expected_pressure_unit",
[
("1bar", unit.bar),
("1atm", unit.atmospheres),
("1atmosphere", unit.atmospheres),
("1pascal", unit.pascals),
("1pascals", unit.pascals),
("1Pa", unit.pascals),
("1poundforce/inch^2", unit.psi),
("1psi", unit.psi),
],
)
def test_parse_pressure(setupandrun, input_pressure, expected_pressure_unit):
result = setupandrun.parse_pressure(input_pressure)
expected_result = unit.Quantity(1, expected_pressure_unit)
# assert expected_result == result[0]
if expected_result != result[0]:
raise AssertionError(
f"Expected {expected_result} for {input_pressure}, got {result[0]}"
)


def test_process_parameters(setupandrun):
parameters = {
"nonbondedMethod": "PME",
"constraints": "HBonds",
"rigidWater": True,
}
result = setupandrun._process_parameters(parameters)
expected_result = {
"nonbondedMethod": PME,
"constraints": HBonds,
"rigidWater": True,
}
for key in expected_result:
assert result[0][key] == expected_result[key]

0 comments on commit 656d7e2

Please sign in to comment.