Skip to content

Commit

Permalink
added tests for all possible param parsing, fixed casing
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Mar 18, 2024
1 parent d9fda15 commit f5f2017
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 33 deletions.
8 changes: 4 additions & 4 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,7 @@ def _parse_parameter(self, parameter, default_unit, possible_units):
unit_part = "/" + unit_part
elif "^-1" in parameter_str:
parameter_str = parameter_str.replace("^-1", "")
match = re.match(r'^(\d+(?:\.\d+)?)([a-zA-Z]+)$', parameter_str)
match = re.match(r"^(\d+(?:\.\d+)?)([a-zA-Z]+)$", parameter_str)
num_value = float(match.group(1))
unit_part = "/" + match.group(2)
else:
Expand All @@ -1419,8 +1419,8 @@ def _parse_parameter(self, parameter, default_unit, possible_units):
error_msg += f"Invalid format for parameter: '{parameter_str}'."

# Convert the unit part to an OpenMM unit
if unit_part in possible_units:
return num_value * possible_units[unit_part], error_msg
if unit_part.lower() in possible_units:
return num_value * possible_units[unit_part.lower()], error_msg
else:
# If the unit is not recognized, raise an error
error_msg += f"""Unknown unit '{unit_part}' for parameter.
Expand Down Expand Up @@ -1469,7 +1469,7 @@ def parse_pressure(self, pressure):
"atmosphere": unit.atmospheres,
"pascal": unit.pascals,
"pascals": unit.pascals,
"Pa": unit.pascals,
"pa": unit.pascals,
"poundforce/inch^2": unit.psi,
"psi": unit.psi,
}
Expand Down
101 changes: 72 additions & 29 deletions tests/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,16 @@ def setupandrun(get_registry):
return SetUpandRunFunction(get_registry)


def test_parse_cutoff(setupandrun):
cutoff = unit.Quantity(1.0, unit.nanometers)
result = setupandrun._parse_cutoff(cutoff)
assert cutoff == result

cutoff = 3.0
result = setupandrun._parse_cutoff(cutoff)
expected_result = unit.Quantity(cutoff, unit.nanometers)
assert expected_result == result

cutoff = "2angstroms"
result = setupandrun._parse_cutoff(cutoff)
expected_result = unit.Quantity(2.0, unit.angstroms)
@pytest.mark.parametrize(
"input_cutoff, expected_result",
[
(unit.Quantity(1.0, unit.nanometers), unit.Quantity(1.0, unit.nanometers)),
(3.0, unit.Quantity(3.0, unit.nanometers)),
("2angstroms", unit.Quantity(2.0, unit.angstroms)),
],
)
def test_parse_cutoff(setupandrun, input_cutoff, expected_result):
result = setupandrun._parse_cutoff(input_cutoff)
assert expected_result == result


Expand All @@ -40,26 +37,72 @@ def test_parse_cutoff_unknown_unit(setupandrun):

def test_parse_temperature(setupandrun):
result = setupandrun.parse_temperature("300k")
result2 = setupandrun.parse_temperature("300kelvin")
expected_result = unit.Quantity(300, unit.kelvin)
assert expected_result == result[0] == result2[0]


@pytest.mark.parametrize(
"input_friction, expected_friction_result",
[
("1/ps", unit.Quantity(1, 1 / unit.picoseconds)),
("1/picosecond", unit.Quantity(1, 1 / unit.picosecond)),
("1/picoseconds", unit.Quantity(1, 1 / unit.picosecond)),
("1picosecond^-1", unit.Quantity(1, 1 / unit.picosecond)),
("1picoseconds^-1", unit.Quantity(1, 1 / unit.picoseconds)),
("1/ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
("1ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
("1*ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
],
)
def test_parse_friction(setupandrun, input_friction, expected_friction_result):
result = setupandrun.parse_friction(input_friction)
assert (
expected_friction_result == result[0]
), f"Expected {expected_friction_result} for {input_friction}, got {result[0]}"


@pytest.mark.parametrize(
"input_time, expected_time_unit",
[
("1ps", unit.picoseconds),
("1picosecond", unit.picoseconds),
("1picoseconds", unit.picoseconds),
("1fs", unit.femtoseconds),
("1femtosecond", unit.femtoseconds),
("1femtoseconds", unit.femtoseconds),
("1ns", unit.nanoseconds),
("1nanosecond", unit.nanoseconds),
("1nanoseconds", unit.nanoseconds),
],
)
def test_parse_time(setupandrun, input_time, expected_time_unit):
result = setupandrun.parse_timestep(input_time)
expected_result = unit.Quantity(1, expected_time_unit)
assert expected_result == result[0]


def parse_friction(setupandrun):
result = setupandrun.parse_friction("1/ps")
expected_result = unit.Quantity(1, unit.picoseconds)
assert expected_result == result[0]


def test_parse_time(setupandrun):
result = setupandrun.parse_timestep("1ns")
expected_result = unit.Quantity(1, unit.nanoseconds)
assert expected_result == result[0]


def test_parse_pressure(setupandrun):
result = setupandrun.parse_pressure("1bar")
expected_result = unit.Quantity(1, unit.bar)
assert expected_result == result[0]
@pytest.mark.parametrize(
"input_pressure, expected_pressure_unit",
[
("1bar", unit.bar),
("1atm", unit.atmospheres),
("1atmosphere", unit.atmospheres),
("1pascal", unit.pascals),
("1pascals", unit.pascals),
("1Pa", unit.pascals),
("1poundforce/inch^2", unit.psi),
("1psi", unit.psi),
],
)
def test_parse_pressure(setupandrun, input_pressure, expected_pressure_unit):
result = setupandrun.parse_pressure(input_pressure)
expected_result = unit.Quantity(1, expected_pressure_unit)
# assert expected_result == result[0]
if expected_result != result[0]:
raise AssertionError(
f"Expected {expected_result} for {input_pressure}, got {result[0]}"
)


def test_process_parameters(setupandrun):
Expand Down

0 comments on commit f5f2017

Please sign in to comment.