From 5e73871556f763da4587679d42296e09c5d888f7 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Fri, 17 May 2024 10:59:13 +0100 Subject: [PATCH] Add support for saving energy contribution from each force. --- src/somd2/config/_config.py | 16 +++++++++++ src/somd2/runner/_dynamics.py | 54 ++++++++++++++++++++++++++++++++++- src/somd2/runner/_runner.py | 4 ++- 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 6f6afbc..51062f1 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -124,6 +124,7 @@ def __init__( overwrite=False, somd1_compatibility=False, pert_file=None, + save_energy_components=False, ): """ Constructor. @@ -281,6 +282,10 @@ def __init__( pert_file: str The path to a SOMD1 perturbation file to apply to the reference system. When set, this will automatically set 'somd1_compatibility' to True. + + save_energy_components: bool + Whether to save the energy contribution for each force when checkpointing. + This is useful when debugging crashes. """ # Setup logger before doing anything else @@ -327,6 +332,7 @@ def __init__( self.restart = restart self.somd1_compatibility = somd1_compatibility self.pert_file = pert_file + self.save_energy_components = save_energy_components self.write_config = write_config @@ -1201,6 +1207,16 @@ def pert_file(self, pert_file): if pert_file is not None: self._somd1_compatibility = True + @property + def save_energy_components(self): + return self._save_energy_components + + @save_energy_components.setter + def save_energy_components(self, save_energy_components): + if not isinstance(save_energy_components, bool): + raise ValueError("'save_energy_components' must be of type 'bool'") + self._save_energy_components = save_energy_components + @property def output_directory(self): return self._output_directory diff --git a/src/somd2/runner/_dynamics.py b/src/somd2/runner/_dynamics.py index 2a54d27..e7eeabb 100644 --- a/src/somd2/runner/_dynamics.py +++ b/src/somd2/runner/_dynamics.py @@ -127,6 +127,9 @@ def __init__( self._config.restart, ) + self._nrg_sample = 0 + self._nrg_file = "energy_components.txt" + @staticmethod def create_filenames(lambda_array, lambda_value, output_directory, restart=False): # Create incremental file name for current restart. @@ -153,6 +156,7 @@ def increment_filename(base_filename, suffix): filenames["energy_traj"] = f"energy_traj_{lam}.parquet" filenames["trajectory"] = f"traj_{lam}.dcd" filenames["trajectory_chunk"] = f"traj_{lam}_" + filenames["energy_components"] = f"energy_components_{lam}.txt" if restart: filenames["config"] = increment_filename("config", "yaml") else: @@ -371,7 +375,7 @@ def generate_lam_vals(lambda_base, increment): ) if self._config.checkpoint_frequency.value() > 0.0: - # Calculate the number of blocks and the remaineder time. + # Calculate the number of blocks and the remainder time. frac = ( self._config.runtime.value() / self._config.checkpoint_frequency.value() ) @@ -409,6 +413,10 @@ def generate_lam_vals(lambda_base, increment): # Checkpoint. try: + # Save the energy contribution for each force. + if self._config.save_energy_components: + self._save_energy_components() + # Set to the current block number if this is a restart. if x == 0: x = self._current_block @@ -584,3 +592,47 @@ def get_timing(self): def _cleanup(self): del self._dyn + + def _save_energy_components(self): + + from copy import deepcopy + import openmm + + # Get the current context and system. + context = self._dyn._d._omm_mols + system = deepcopy(context.getSystem()) + + # Add each force to a unique group. + for i, f in enumerate(system.getForces()): + f.setForceGroup(i) + + # Create a new context. + new_context = openmm.Context(system, deepcopy(context.getIntegrator())) + new_context.setPositions(context.getState(getPositions=True).getPositions()) + + header = f"{'# Sample':>10}" + record = f"{self._nrg_sample:>10}" + + # Process the records. + for i, f in enumerate(system.getForces()): + state = new_context.getState(getEnergy=True, groups={i}) + header += f"{f.getName():>25}" + record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>25.2f}" + + # Write to file. + if self._nrg_sample == 0: + with open( + self._config.output_directory / self._filenames["energy_components"], + "w", + ) as f: + f.write(header + "\n") + f.write(record + "\n") + else: + with open( + self._config.output_directory / self._filenames["energy_components"], + "a", + ) as f: + f.write(record + "\n") + + # Increment the sample number. + self._nrg_sample += 1 diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index d97a4ed..30e4dad 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -969,5 +969,7 @@ def _run(sim, is_restart=False): filename=self._fnames[lambda_value]["energy_traj"], ) del system - _logger.success(f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1") + _logger.success( + f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1" + ) return True