From 7e5b1b6acf48d3007a1e4220966eb10f447fcf31 Mon Sep 17 00:00:00 2001
From: Lester Hedges
Date: Mon, 15 Jun 2026 21:12:40 +0100
Subject: [PATCH] Serialise OpenMM state as NumPy arrays, not XML.

---
 src/somd2/runner/_repex.py | 63 ++++++++++++++++++++++++++++++++------
 1 file changed, 54 insertions(+), 9 deletions(-)

diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py
index c4e1fab..3039e09 100644
--- a/src/somd2/runner/_repex.py
+++ b/src/somd2/runner/_repex.py
@@ -468,8 +468,43 @@ def save_openmm_state(self, index):
             .getState(getPositions=True, getVelocities=True)
         )
 
-        # Store the state.
-        self._openmm_states[index] = state
+        # Store positions, velocities, and box vectors as compact numpy arrays
+        # rather than the OpenMM State object, which serialises to XML when
+        # pickled and is orders of magnitude larger.
+        self._openmm_states[index] = {
+            "positions": state.getPositions(asNumpy=True),
+            "velocities": state.getVelocities(asNumpy=True),
+            "box": state.getPeriodicBoxVectors(asNumpy=True),
+        }
+
+    @staticmethod
+    def _apply_openmm_state(context, state):
+        """
+        Apply a saved OpenMM state to a context.
+
+        Parameters
+        ----------
+
+        context: openmm.Context
+            The OpenMM context to update.
+
+        state: dict or openmm.State
+            The state to apply. Dicts (new format) contain "positions",
+            "velocities", and "box" numpy arrays. A bare openmm.State is
+            accepted for backwards compatibility with old checkpoint files.
+        """
+        if isinstance(state, dict):
+            # Set the box before the coordinates (the same order that a full
+            # Context.setState restore uses) so that positions are never
+            # paired with a stale box left over from a different replica.
+            if state["box"] is not None:
+                context.setPeriodicBoxVectors(*state["box"])
+            context.setPositions(state["positions"])
+            context.setVelocities(state["velocities"])
+        else:
+            # Legacy openmm.State from checkpoint files written before this
+            # format change.
+            context.setState(state)
 
     def save_gcmc_state(self, index):
         """
@@ -520,7 +555,9 @@ def mix_states(self):
             # The state has changed.
             if i != state:
                 _logger.debug(f"Replica {i} seeded from state {state}")
-                self._dynamics[i].context().setState(self._openmm_states[state])
+                self._apply_openmm_state(
+                    self._dynamics[i].context(), self._openmm_states[state]
+                )
 
                 # Swap the water state in the GCMCSamplers.
                 if self._gcmc_samplers[i] is not None:
@@ -821,7 +858,9 @@ def __init__(self, system, config):
                 # Reset the OpenMM state, applying the last replica exchange
                 # mixing so the correct post-mix state is restored.
                 state = self._dynamics_cache._states[i]
-                dynamics.context().setState(self._dynamics_cache._openmm_states[state])
+                DynamicsCache._apply_openmm_state(
+                    dynamics.context(), self._dynamics_cache._openmm_states[state]
+                )
 
                 # Reset the GCMC water state and restore statistics.
                 if gcmc_sampler is not None:
@@ -1222,9 +1261,11 @@ def run(self):
             # Snapshot the pre-run state for crash recovery.
             if self._config.auto_fix_minimise:
                 for i, state in enumerate(self._dynamics_cache.get_states()):
-                    self._dynamics_cache._dynamics[
-                        i
-                    ]._d._pre_run_state = self._dynamics_cache._openmm_states[state]
+                    self._dynamics_cache._dynamics[i]._d._pre_run_state = (
+                        self._dynamics_cache._dynamics[i]
+                        .context()
+                        .getState(getPositions=True, getVelocities=True)
+                    )
 
             # This is a checkpoint cycle.
             if is_checkpoint:
@@ -1734,14 +1775,18 @@ def _compute_energies(self, index):
         # Loop over the states.
         for i in range(self._config.num_lambda):
             # Set the state.
-            dynamics.context().setState(self._dynamics_cache._openmm_states[i])
+            DynamicsCache._apply_openmm_state(
+                dynamics.context(), self._dynamics_cache._openmm_states[i]
+            )
             dynamics._d._clear_state()
 
             # Compute and store the energy for this state.
             energies[i] = dynamics.current_potential_energy().value()
 
         # Reset the state.
-        dynamics.context().setState(self._dynamics_cache._openmm_states[index])
+        DynamicsCache._apply_openmm_state(
+            dynamics.context(), self._dynamics_cache._openmm_states[index]
+        )
 
         return index, energies
 