Skip to content

Commit f5f2017

Browse files
committed
added tests for all possible param parsing, fixed casing
1 parent d9fda15 commit f5f2017

File tree

2 files changed

+76
-33
lines changed

2 files changed

+76
-33
lines changed

mdagent/tools/base_tools/simulation_tools/setup_and_run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,7 @@ def _parse_parameter(self, parameter, default_unit, possible_units):
14011401
unit_part = "/" + unit_part
14021402
elif "^-1" in parameter_str:
14031403
parameter_str = parameter_str.replace("^-1", "")
1404-
match = re.match(r'^(\d+(?:\.\d+)?)([a-zA-Z]+)$', parameter_str)
1404+
match = re.match(r"^(\d+(?:\.\d+)?)([a-zA-Z]+)$", parameter_str)
14051405
num_value = float(match.group(1))
14061406
unit_part = "/" + match.group(2)
14071407
else:
@@ -1419,8 +1419,8 @@ def _parse_parameter(self, parameter, default_unit, possible_units):
14191419
error_msg += f"Invalid format for parameter: '{parameter_str}'."
14201420

14211421
# Convert the unit part to an OpenMM unit
1422-
if unit_part in possible_units:
1423-
return num_value * possible_units[unit_part], error_msg
1422+
if unit_part.lower() in possible_units:
1423+
return num_value * possible_units[unit_part.lower()], error_msg
14241424
else:
14251425
# If the unit is not recognized, raise an error
14261426
error_msg += f"""Unknown unit '{unit_part}' for parameter.
@@ -1469,7 +1469,7 @@ def parse_pressure(self, pressure):
14691469
"atmosphere": unit.atmospheres,
14701470
"pascal": unit.pascals,
14711471
"pascals": unit.pascals,
1472-
"Pa": unit.pascals,
1472+
"pa": unit.pascals,
14731473
"poundforce/inch^2": unit.psi,
14741474
"psi": unit.psi,
14751475
}

tests/test_setup.py

Lines changed: 72 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,16 @@ def setupandrun(get_registry):
1616
return SetUpandRunFunction(get_registry)
1717

1818

19-
def test_parse_cutoff(setupandrun):
20-
cutoff = unit.Quantity(1.0, unit.nanometers)
21-
result = setupandrun._parse_cutoff(cutoff)
22-
assert cutoff == result
23-
24-
cutoff = 3.0
25-
result = setupandrun._parse_cutoff(cutoff)
26-
expected_result = unit.Quantity(cutoff, unit.nanometers)
27-
assert expected_result == result
28-
29-
cutoff = "2angstroms"
30-
result = setupandrun._parse_cutoff(cutoff)
31-
expected_result = unit.Quantity(2.0, unit.angstroms)
19+
@pytest.mark.parametrize(
20+
"input_cutoff, expected_result",
21+
[
22+
(unit.Quantity(1.0, unit.nanometers), unit.Quantity(1.0, unit.nanometers)),
23+
(3.0, unit.Quantity(3.0, unit.nanometers)),
24+
("2angstroms", unit.Quantity(2.0, unit.angstroms)),
25+
],
26+
)
27+
def test_parse_cutoff(setupandrun, input_cutoff, expected_result):
28+
result = setupandrun._parse_cutoff(input_cutoff)
3229
assert expected_result == result
3330

3431

@@ -40,26 +37,72 @@ def test_parse_cutoff_unknown_unit(setupandrun):
4037

4138
def test_parse_temperature(setupandrun):
4239
result = setupandrun.parse_temperature("300k")
40+
result2 = setupandrun.parse_temperature("300kelvin")
4341
expected_result = unit.Quantity(300, unit.kelvin)
42+
assert expected_result == result[0] == result2[0]
43+
44+
45+
@pytest.mark.parametrize(
46+
"input_friction, expected_friction_result",
47+
[
48+
("1/ps", unit.Quantity(1, 1 / unit.picoseconds)),
49+
("1/picosecond", unit.Quantity(1, 1 / unit.picosecond)),
50+
("1/picoseconds", unit.Quantity(1, 1 / unit.picosecond)),
51+
("1picosecond^-1", unit.Quantity(1, 1 / unit.picosecond)),
52+
("1picoseconds^-1", unit.Quantity(1, 1 / unit.picoseconds)),
53+
("1/ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
54+
("1ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
55+
("1*ps^-1", unit.Quantity(1, 1 / unit.picoseconds)),
56+
],
57+
)
58+
def test_parse_friction(setupandrun, input_friction, expected_friction_result):
59+
result = setupandrun.parse_friction(input_friction)
60+
assert (
61+
expected_friction_result == result[0]
62+
), f"Expected {expected_friction_result} for {input_friction}, got {result[0]}"
63+
64+
65+
@pytest.mark.parametrize(
66+
"input_time, expected_time_unit",
67+
[
68+
("1ps", unit.picoseconds),
69+
("1picosecond", unit.picoseconds),
70+
("1picoseconds", unit.picoseconds),
71+
("1fs", unit.femtoseconds),
72+
("1femtosecond", unit.femtoseconds),
73+
("1femtoseconds", unit.femtoseconds),
74+
("1ns", unit.nanoseconds),
75+
("1nanosecond", unit.nanoseconds),
76+
("1nanoseconds", unit.nanoseconds),
77+
],
78+
)
79+
def test_parse_time(setupandrun, input_time, expected_time_unit):
80+
result = setupandrun.parse_timestep(input_time)
81+
expected_result = unit.Quantity(1, expected_time_unit)
4482
assert expected_result == result[0]
4583

4684

47-
def parse_friction(setupandrun):
48-
result = setupandrun.parse_friction("1/ps")
49-
expected_result = unit.Quantity(1, unit.picoseconds)
50-
assert expected_result == result[0]
51-
52-
53-
def test_parse_time(setupandrun):
54-
result = setupandrun.parse_timestep("1ns")
55-
expected_result = unit.Quantity(1, unit.nanoseconds)
56-
assert expected_result == result[0]
57-
58-
59-
def test_parse_pressure(setupandrun):
60-
result = setupandrun.parse_pressure("1bar")
61-
expected_result = unit.Quantity(1, unit.bar)
62-
assert expected_result == result[0]
85+
@pytest.mark.parametrize(
86+
"input_pressure, expected_pressure_unit",
87+
[
88+
("1bar", unit.bar),
89+
("1atm", unit.atmospheres),
90+
("1atmosphere", unit.atmospheres),
91+
("1pascal", unit.pascals),
92+
("1pascals", unit.pascals),
93+
("1Pa", unit.pascals),
94+
("1poundforce/inch^2", unit.psi),
95+
("1psi", unit.psi),
96+
],
97+
)
98+
def test_parse_pressure(setupandrun, input_pressure, expected_pressure_unit):
99+
result = setupandrun.parse_pressure(input_pressure)
100+
expected_result = unit.Quantity(1, expected_pressure_unit)
101+
# assert expected_result == result[0]
102+
if expected_result != result[0]:
103+
raise AssertionError(
104+
f"Expected {expected_result} for {input_pressure}, got {result[0]}"
105+
)
63106

64107

65108
def test_process_parameters(setupandrun):

0 commit comments

Comments
 (0)