Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using a custom list of lambda values #42

Merged
merged 5 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
__all__ = ["Config"]


from collections.abc import Iterable as _Iterable
from typing import Iterable as _Iterable
from openmm import Platform as _Platform
from pathlib import Path as _Path

Expand Down Expand Up @@ -72,6 +72,12 @@ class Config:
"log_level": [level.lower() for level in _logger._core.levels],
}

# A dictionary of nargs for the various options.
_nargs = {
"lambda_values": "+",
"lambda_energy": "+",
}

def __init__(
self,
log_level="info",
Expand All @@ -86,6 +92,8 @@ def __init__(
cutoff="7.5A",
h_mass_factor=1.5,
num_lambda=11,
lambda_values=None,
lambda_energy=None,
lambda_schedule="standard_morph",
charge_scale_factor=0.2,
swap_end_states=False,
Expand Down Expand Up @@ -153,6 +161,15 @@ def __init__(
num_lambda: int
Number of lambda windows to use.

lambda_values: [float]
A list of lambda values. When specified, this takes precedence over
the 'num_lambda' option.

lambda_energy: [float]
A list of lambda values at which to output energy data. If not set,
then this will be set to the same as 'lambda_values', or the values
defined by 'num_lambda' if 'lambda_values' is not set.

lambda_schedule: str
Lambda schedule to use for alchemical free energy simulations.

Expand Down Expand Up @@ -281,6 +298,8 @@ def __init__(
self.h_mass_factor = h_mass_factor
self.timestep = timestep
self.num_lambda = num_lambda
self.lambda_values = lambda_values
self.lambda_energy = lambda_energy
self.lambda_schedule = lambda_schedule
self.charge_scale_factor = charge_scale_factor
self.swap_end_states = swap_end_states
Expand Down Expand Up @@ -616,6 +635,56 @@ def num_lambda(self, num_lambda):
raise ValueError("'num_lambda' must be an integer")
self._num_lambda = num_lambda

@property
def lambda_values(self):
    """
    Return the custom list of lambda values, or None if none was set.
    """
    return self._lambda_values

@lambda_values.setter
def lambda_values(self, lambda_values):
    """
    Validate and store a custom list of lambda values.

    Parameters
    ----------

    lambda_values: [float]
        An iterable of values convertible to float, each between 0 and 1
        inclusive, or None to leave the custom list unset. When a list is
        given, the stored number of lambda windows is updated to match
        its length.

    Raises
    ------

    ValueError
        If the input is not an iterable, is a string/bytes object, is
        empty, contains entries that cannot be converted to float, or
        contains entries outside of the range [0, 1].
    """
    if lambda_values is not None:
        if not isinstance(lambda_values, _Iterable):
            raise ValueError("'lambda_values' must be an iterable")

        # Strings and bytes are iterable element by element, so "01"
        # would silently be interpreted as [0.0, 1.0]. Reject them.
        if isinstance(lambda_values, (str, bytes)):
            raise ValueError("'lambda_values' must be an iterable of floats")

        # Only catch conversion failures so that KeyboardInterrupt and
        # SystemExit are never masked, and keep the original cause.
        try:
            lambda_values = [float(x) for x in lambda_values]
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError(
                "'lambda_values' must be an iterable of floats"
            ) from e

        # An empty list would set the number of lambda windows to zero.
        if not lambda_values:
            raise ValueError("'lambda_values' must contain at least one value")

        if not all(0 <= x <= 1 for x in lambda_values):
            raise ValueError(
                "All entries in 'lambda_values' must be between 0 and 1"
            )

        # Round to 5dp.
        lambda_values = [round(x, 5) for x in lambda_values]

        # The custom list takes precedence over 'num_lambda'.
        self._num_lambda = len(lambda_values)

    self._lambda_values = lambda_values

@property
def lambda_energy(self):
    """
    Return the list of lambda values at which energies are sampled, or
    None if none was set.
    """
    return self._lambda_energy

@lambda_energy.setter
def lambda_energy(self, lambda_energy):
    """
    Validate and store the lambda values at which to output energy data.

    Parameters
    ----------

    lambda_energy: [float]
        An iterable of values convertible to float, each between 0 and 1
        inclusive, or None to leave the list unset.

    Raises
    ------

    ValueError
        If the input is not an iterable, is a string/bytes object,
        contains entries that cannot be converted to float, or contains
        entries outside of the range [0, 1].
    """
    if lambda_energy is not None:
        if not isinstance(lambda_energy, _Iterable):
            raise ValueError("'lambda_energy' must be an iterable")

        # Strings and bytes are iterable element by element, so "01"
        # would silently be interpreted as [0.0, 1.0]. Reject them.
        if isinstance(lambda_energy, (str, bytes)):
            raise ValueError("'lambda_energy' must be an iterable of floats")

        # Only catch conversion failures so that KeyboardInterrupt and
        # SystemExit are never masked, and keep the original cause.
        try:
            lambda_energy = [float(x) for x in lambda_energy]
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError(
                "'lambda_energy' must be an iterable of floats"
            ) from e

        if not all(0 <= x <= 1 for x in lambda_energy):
            raise ValueError(
                "All entries in 'lambda_energy' must be between 0 and 1"
            )

        # Round to 5dp.
        lambda_energy = [round(x, 5) for x in lambda_energy]

    self._lambda_energy = lambda_energy

@property
def lambda_schedule(self):
return self._lambda_schedule
Expand Down Expand Up @@ -1273,6 +1342,12 @@ def _create_parser(cls):
# Get the type of the parameter. If None, then use str.
typ = str if params[param].default is None else type(params[param].default)

# Get the nargs for the parameter.
if param in cls._nargs:
nargs = cls._nargs[param]
else:
nargs = None

# This parameter has choices.
if param in cls._choices:
parser.add_argument(
Expand All @@ -1297,6 +1372,7 @@ def _create_parser(cls):
parser.add_argument(
f"--{cli_param}",
type=typ,
nargs=nargs,
default=params[param].default,
help=help[param],
required=False,
Expand Down
30 changes: 20 additions & 10 deletions src/somd2/runner/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
system,
lambda_val,
lambda_array,
lambda_energy,
config,
increment=0.001,
device=None,
Expand All @@ -61,8 +62,11 @@ def __init__(
Lambda value for the simulation

lambda_array : list
List of lambda values to be used for perturbation, if none won't return
reduced perturbed energies
List of lambda values to be used for simulation.

lambda_energy: list
List of lambda values to be used for sampling energies. If None, then we
won't return reduced perturbed energies.

increment : float
Increment of lambda value - used for calculating the gradient
Expand Down Expand Up @@ -105,8 +109,14 @@ def __init__(
else:
self._current_block = 0

lambda_energy = lambda_energy.copy()
if not lambda_val in lambda_energy:
lambda_energy.append(lambda_val)
lambda_energy = sorted(lambda_energy)

self._lambda_val = lambda_val
self._lambda_array = lambda_array
self._lambda_energy = lambda_energy
self._increment = increment
self._device = device
self._has_space = has_space
Expand Down Expand Up @@ -136,13 +146,13 @@ def increment_filename(base_filename, suffix):

if lambda_value not in lambda_array:
raise ValueError("lambda_value not in lambda_array")
lam = f"{lambda_value:.5f}"
filenames = {}
index = lambda_array.index(lambda_value)
filenames["topology"] = "system.prm7"
filenames["checkpoint"] = f"checkpoint_{index}.s3"
filenames["energy_traj"] = f"energy_traj_{index}.parquet"
filenames["trajectory"] = f"traj_{index}.dcd"
filenames["trajectory_chunk"] = f"traj_{index}_"
filenames["checkpoint"] = f"checkpoint_{lam}.s3"
filenames["energy_traj"] = f"energy_traj_{lam}.parquet"
filenames["trajectory"] = f"traj_{lam}.dcd"
filenames["trajectory_chunk"] = f"traj_{lam}_"
if restart:
filenames["config"] = increment_filename("config", "yaml")
else:
Expand Down Expand Up @@ -348,10 +358,10 @@ def generate_lam_vals(lambda_base, increment):
# Work out the lambda values for finite-difference gradient analysis.
self._lambda_grad = generate_lam_vals(self._lambda_val, self._increment)

if self._lambda_array is None:
if self._lambda_energy is None:
lam_arr = self._lambda_grad
else:
lam_arr = self._lambda_array + self._lambda_grad
lam_arr = self._lambda_energy + self._lambda_grad

_logger.info(f"Running dynamics at {_lam_sym} = {self._lambda_val}")

Expand Down Expand Up @@ -436,7 +446,7 @@ def generate_lam_vals(lambda_base, increment):
metadata={
"attrs": df.attrs,
"lambda": str(self._lambda_val),
"lambda_array": lam_arr,
"lambda_array": self._lambda_energy,
"lambda_grad": self._lambda_grad,
"temperature": str(self._config.temperature.value()),
},
Expand Down
37 changes: 25 additions & 12 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ def __init__(self, system, config):
The perturbable system to be simulated. This can be either a path
to a stream file, or a Sire system object.

num_lambda: int
The number of lambda windows to be simulated.

platform: str
The platform to be used for simulations.
config: :class: `Config <somd2.config.Config>`
The configuration options for the simulation.
"""

if not isinstance(system, (str, _System)):
Expand Down Expand Up @@ -169,10 +166,19 @@ def __init__(self, system, config):
self._check_end_state_constraints()

# Set the lambda values.
self._lambda_values = [
round(i / (self._config.num_lambda - 1), 5)
for i in range(0, self._config.num_lambda)
]
if self._config.lambda_values:
self._lambda_values = self._config.lambda_values
else:
self._lambda_values = [
round(i / (self._config.num_lambda - 1), 5)
for i in range(0, self._config.num_lambda)
]

# Set the lambda energy list.
if self._config.lambda_energy is not None:
self._lambda_energy = self._config.lambda_energy
else:
self._lambda_energy = self._lambda_values

# Work out the current hydrogen mass factor.
h_mass_factor, has_hydrogen = self._get_h_mass_factor(self._system)
Expand Down Expand Up @@ -421,10 +427,10 @@ def get_last_config(output_directory):
)
try:
system_temp = _stream.load(
str(self._config.output_directory / "checkpoint_0.s3")
str(self._config.output_directory / "checkpoint_0.00000.s3")
)
except:
expdir = self._config.output_directory / "checkpoint_0.s3"
expdir = self._config.output_directory / "checkpoint_0.00000.s3"
_logger.error(f"Unable to load checkpoint file from {expdir}.")
raise
else:
Expand Down Expand Up @@ -679,6 +685,7 @@ def _initialise_simulation(self, system, lambda_value, device=None):
system,
lambda_val=lambda_value,
lambda_array=self._lambda_values,
lambda_energy=self._lambda_energy,
config=self._config,
device=device,
has_space=self._has_space,
Expand Down Expand Up @@ -939,6 +946,12 @@ def _run(sim, is_restart=False):

from somd2 import __version__, _sire_version, _sire_revisionid

# Add the current lambda value to the list of lambda values and sort.
lambda_array = self._lambda_energy.copy()
if lambda_value not in lambda_array:
lambda_array.append(lambda_value)
lambda_array = sorted(lambda_array)

# Write final dataframe for the system to the energy trajectory file.
# Note that sire s3 checkpoint files contain energy trajectory data, so this works even for restarts.
_ = _dataframe_to_parquet(
Expand All @@ -948,7 +961,7 @@ def _run(sim, is_restart=False):
"somd2 version": __version__,
"sire version": f"{_sire_version}+{_sire_revisionid}",
"lambda": str(lambda_value),
"lambda_array": self._lambda_values,
"lambda_array": lambda_array,
"lambda_grad": lambda_grad,
"speed": speed,
"temperature": str(self._config.temperature.value()),
Expand Down
94 changes: 94 additions & 0 deletions tests/runner/test_lambda_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pathlib import Path

import tempfile
import pytest

import sire as sr

from somd2.runner import Runner
from somd2.config import Config
from somd2.io import *


def test_lambda_values(ethane_methanol):
    """
    Check that a simulation driven by an explicit list of lambda values
    runs and records that list in the energy trajectory metadata.
    """

    with tempfile.TemporaryDirectory() as workdir:
        system = ethane_methanol.clone()

        # Short CPU-only run using three user-specified lambda windows.
        options = {
            "runtime": "12fs",
            "restart": False,
            "output_directory": workdir,
            "energy_frequency": "4fs",
            "checkpoint_frequency": "4fs",
            "frame_frequency": "4fs",
            "platform": "CPU",
            "max_threads": 1,
            "lambda_values": [0.0, 0.5, 1.0],
        }

        # Build the runner from the options above and run it.
        runner = Runner(system, Config(**options))
        runner.run()

        # Read back the energy trajectory written for the first window.
        traj_file = Path(workdir) / "energy_traj_0.00000.parquet"
        energy_traj, meta = parquet_to_dataframe(traj_file)

        # The metadata should hold the lambda_values list from the config.
        assert meta["lambda_array"] == [0.0, 0.5, 1.0]

        # Expect four columns: the current lambda value, its gradient, and
        # the two remaining entries of the lambda_values list.
        num_columns = energy_traj.shape[1]
        assert num_columns == 4


def test_lambda_energy(ethane_methanol):
    """
    Check that energies can be sampled at a set of lambda values that
    differs from the set of simulated lambda values.
    """

    with tempfile.TemporaryDirectory() as workdir:
        system = ethane_methanol.clone()

        # Simulate the end states only, but sample energies at the midpoint.
        options = {
            "runtime": "12fs",
            "restart": False,
            "output_directory": workdir,
            "energy_frequency": "4fs",
            "checkpoint_frequency": "4fs",
            "frame_frequency": "4fs",
            "platform": "CPU",
            "max_threads": 1,
            "lambda_values": [0.0, 1.0],
            "lambda_energy": [0.5],
        }

        # Build the runner from the options above and run it.
        runner = Runner(system, Config(**options))
        runner.run()

        # Read back the energy trajectory written for the first window.
        traj_file = Path(workdir) / "energy_traj_0.00000.parquet"
        energy_traj, meta = parquet_to_dataframe(traj_file)

        # The metadata should hold the sampled lambda together with the
        # lambda_energy values from the config.
        assert meta["lambda_array"] == [0.0, 0.5]

        # Expect three columns: the current lambda value, its gradient, and
        # the single entry of lambda_energy.
        num_columns = energy_traj.shape[1]
        assert num_columns == 3
Loading
Loading