From e96316aa75a6c636697e21246505e58d3dc17800 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 15 Jun 2026 20:51:15 +0100 Subject: [PATCH 1/2] Handle restarts from crashes during final checkpoint. --- src/somd2/runner/_repex.py | 80 +++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index c4e1fab..91cb1b5 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -716,6 +716,11 @@ def __init__(self, system, config): # Store the name of the replica exchange swap acceptance matrix. self._repex_matrix = self._config.output_directory / "repex_matrix.txt" + # Sentinel file written only after a fully successful run (dynamics + + # trajectory consolidation + backup cleanup). Used to distinguish + # "truly complete" from "complete dynamics but killed during cleanup". + self._done_file = self._config.output_directory / "simulation.done" + # Flag that we haven't equilibrated. self._is_equilibration = False @@ -756,6 +761,11 @@ def __init__(self, system, config): } ) + # On a fresh (non-restart) run, remove any leftover sentinel so that + # a repeated run with --overwrite doesn't immediately exit as complete. + if not self._is_restart and self._done_file.exists(): + self._done_file.unlink() + # Create the dynamics cache. if not self._is_restart: xml_filenames = ( @@ -777,10 +787,33 @@ def __init__(self, system, config): else: _logger.debug("Restarting from file") - # Check to see if the simulation is already complete. time = self._system[0].time() + + # Check to see if the simulation is already complete. + if self._done_file.exists(): + # The runtime may have been extended beyond the previous run. + # If so, clear the sentinel and continue. + if time < self._config.runtime - self._config.timestep: + _logger.info( + "Runtime has been extended. Clearing completion sentinel." 
+                    )
+                    self._done_file.unlink()
+                else:
+                    _logger.success("Simulation already complete. Exiting.")
+                    _sys.exit(0)
+
             if time > self._config.runtime - self._config.timestep:
-                _logger.success("Simulation already complete. Exiting.")
+                # Dynamics finished but the process was killed before cleanup
+                # completed (e.g. during DCD consolidation or backup removal).
+                # Consolidate any remaining trajectory chunks and tidy up.
+                _logger.warning(
+                    "Simulation dynamics are complete but post-run cleanup was "
+                    "not finished. Completing cleanup now."
+                )
+                self._consolidate_trajectories()
+                self._cleanup()
+                self._done_file.touch()
+                _logger.success("Cleanup complete. Exiting.")
                 _sys.exit(0)
             else:
                 _logger.info(
@@ -1300,6 +1333,10 @@ def run(self):
         # Delete all backup files from the working directory.
         self._cleanup()
 
+        # Write the sentinel file to signal that the run completed fully,
+        # including trajectory consolidation and cleanup.
+        self._done_file.touch()
+
     def _run_block(
         self,
         index,
@@ -1872,6 +1909,45 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
         except Exception as e:
             return index, e
 
+    def _consolidate_trajectories(self):
+        """
+        Consolidate any remaining trajectory chunk files into the final DCD.
+
+        Called when a restart detects that dynamics completed but the process
+        was killed before post-run cleanup finished. Replicas with no chunk
+        files left are already consolidated and are skipped, leaving their
+        final DCD untouched.
+        """
+        from glob import glob as _glob_local
+        from pathlib import Path as _Path_local
+        from shutil import copyfile as _copyfile_local
+
+        if not self._config.save_trajectories:
+            return
+
+        for i in range(len(self._lambda_values)):
+            traj_filename = self._filenames[i]["trajectory"]
+            chunk_pattern = f"{self._filenames[i]['trajectory_chunk']}*"
+            traj_chunks = sorted(_glob_local(chunk_pattern))
+
+            # No chunks left: already consolidated, leave the final DCD alone.
+            if not traj_chunks:
+                continue
+
+            # Prepend the existing final DCD (as .prev) to keep earlier frames.
+            path = _Path_local(traj_filename)
+            if path.exists() and path.stat().st_size > 0:
+                prev = f"{traj_filename}.prev"
+                _copyfile_local(traj_filename, prev)
+                traj_chunks = [prev] + traj_chunks
+
+            topology0 = self._filenames["topology0"]
+            mols = _sr.load([topology0] + traj_chunks)
+            _sr.save(mols.trajectory(), traj_filename, format=["DCD"])
+
+            for chunk in traj_chunks:
+                _Path_local(chunk).unlink()
+
     @staticmethod
     @_njit
     def _mix_replicas(num_replicas, energy_matrix, proposed, accepted):

From e9f1cc3106eef47033d79de6b7b524f72b65c464 Mon Sep 17 00:00:00 2001
From: Lester Hedges
Date: Mon, 15 Jun 2026 20:21:45 +0100
Subject: [PATCH 2/2] Remove redundant old_states attribute.
--- src/somd2/runner/_repex.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 91cb1b5..1d7caf4 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -102,7 +102,6 @@ def __init__( self._lambdas = lambdas self._rest2_scale_factors = rest2_scale_factors self._states = _np.array(range(len(lambdas))) - self._old_states = _np.array(range(len(lambdas))) self._openmm_states = [None] * len(lambdas) self._gcmc_samplers = [None] * len(lambdas) self._gcmc_states = [None] * len(lambdas) @@ -150,7 +149,6 @@ def __getstate__(self): "_lambdas": self._lambdas, "_rest2_scale_factors": self._rest2_scale_factors, "_states": self._states, - "_old_states": self._old_states, "_openmm_states": self._openmm_states, # Don't pickle the GCMC samplers since they need to be recreated. "_gcmc_samplers": len(self._gcmc_samplers) * [None], @@ -511,9 +509,14 @@ def set_states(self, states): """ self._states = states - def mix_states(self): + def mix_states(self, old_states): """ Mix the states of the dynamics objects. + + Parameters + ---------- + old_states : numpy.ndarray + The state indices from before the last replica mix. """ # Mix the states. for i, state in enumerate(self._states): @@ -541,11 +544,7 @@ def mix_states(self): self._gcmc_samplers[i].pop() # Update the swap matrix. - old_state = self._old_states[i] - self._num_swaps[old_state, state] += 1 - - # Store the current states. - self._old_states = self._states.copy() + self._num_swaps[old_states[i], state] += 1 def get_proposed(self): """ @@ -1242,6 +1241,7 @@ def run(self): # Mix the replicas. 
_logger.info("Mixing replicas") + old_states = self._dynamics_cache.get_states() self._dynamics_cache.set_states( self._mix_replicas( self._config.num_lambda, @@ -1250,7 +1250,7 @@ def run(self): self._dynamics_cache.get_accepted(), ) ) - self._dynamics_cache.mix_states() + self._dynamics_cache.mix_states(old_states) # Snapshot the pre-run state for crash recovery. if self._config.auto_fix_minimise: