Skip to content

Commit 13366ac

Browse files
committed
Fix handling of REST2 scale factors when subsampling. [closes #69]
1 parent 40e4254 commit 13366ac

File tree

4 files changed

+35
-12
lines changed

4 files changed

+35
-12
lines changed

src/somd2/config/_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,9 @@ def __init__(
305305
the rest of the system. This can either be a single scaling factor, or a list of
306306
scale factors for each lambda window. When a single scaling factor is used, then
307307
the scale factor will be interpolated between a value of 1.0 in the end states,
308-
and the value of 'rest2_scale' in intermediate lambda = 0.5 state.
308+
and the value of 'rest2_scale' in intermediate lambda = 0.5 state. When multiple
309+
values are used, then the number should match the number of lambda windows at which
310+
energies are sampled.
309311
310312
rest2_selection: str
311313
A sire selection string for atoms to include in the REST2 region in

src/somd2/runner/_base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,25 @@ def __init__(self, system, config):
240240
else:
241241
self._lambda_energy = self._lambda_values
242242

243+
# Make sure the lambda values are in the lambda energy list.
244+
is_missing = False
245+
for lambda_value in self._lambda_values:
246+
if lambda_value not in self._lambda_energy:
247+
self._lambda_energy.append(lambda_value)
248+
is_missing = True
249+
250+
# Make sure the lambda_values entries are unique.
251+
if not len(self._lambda_values) == len(set(self._lambda_values)):
252+
msg = "Duplicate entries in 'lambda_values' list."
253+
_logger.error(msg)
254+
raise ValueError(msg)
255+
256+
# Make sure the lambda_energy entries are unique.
257+
if not len(self._lambda_energy) == len(set(self._lambda_energy)):
258+
msg = "Duplicate entries in 'lambda_energy' list."
259+
_logger.error(msg)
260+
raise ValueError(msg)
261+
243262
from math import isclose
244263

245264
# Set the REST2 scale factors.
@@ -258,6 +277,9 @@ def __init__(self, system, config):
258277
else:
259278
if len(self._config.rest2_scale) != len(self._lambda_energy):
260279
msg = f"Length of 'rest2_scale' must match the number of {_lam_sym} values."
280+
if is_missing:
281+
msg += f"If you have omitted some 'lambda_values` from `lambda_energy`, please "
282+
f"add them to `lambda_energy`, along with the corresponding `rest2_scale` values."
261283
_logger.error(msg)
262284
raise ValueError(msg)
263285
# Make sure the end states are close to 1.0.

src/somd2/runner/_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,11 @@ def _run(
317317
# Get the lambda value.
318318
lambda_value = self._lambda_values[index]
319319

320+
# Get the index in the lambda_energy array.
321+
nrg_index = self._lambda_energy.index(lambda_value)
322+
320323
# Get the REST2 scaling factor.
321-
rest2_scale = self._rest2_scale_factors[index]
324+
rest2_scale = self._rest2_scale_factors[nrg_index]
322325

323326
# Check for completion if this is a restart.
324327
if is_restart:
@@ -445,10 +448,6 @@ def generate_lam_vals(lambda_base, increment=0.001):
445448
# Create the array of lambda values for energy sampling.
446449
lambda_energy = self._lambda_energy.copy()
447450

448-
# If missing, add the lambda value.
449-
if lambda_value not in self._lambda_energy:
450-
lambda_energy.append(lambda_value)
451-
452451
# Sort the lambda values.
453452
lambda_energy = sorted(lambda_energy)
454453

tests/runner/test_lambda_values.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ def test_lambda_energy(ethane_methanol):
8585
)
8686

8787
# Make sure the lambda_array in the metadata is correct. This is the
88-
# sampled lambda plus the lambda_energy values in the config.
89-
assert meta["lambda_array"] == [0.0, 0.5]
88+
# sampled lambda_values plus the lambda_energy values in the config.
89+
assert meta["lambda_array"] == [0.0, 0.5, 1.0]
9090

91-
# Make sure the second dimension of the energy trajectory is the correct
92-
# size. This is one for the current lambda value, one for its gradient,
93-
# and one for the length of lambda_energy.
94-
assert energy_traj.shape[1] == 3
91+
# Make sure the second dimension of the energy trajectory is the correct.
92+
# This is the sampled lambda values, i.e. unique entries from lambda_values
93+
# and lambda_energy, plus the gradient for TI.
94+
assert energy_traj.shape[1] == 4

0 commit comments

Comments
 (0)