From f09fb61e5443353648641dc4c2783e4d8f580745 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 31 Mar 2026 12:31:14 +0100 Subject: [PATCH 01/14] Add support for terminal flip MC. --- src/somd2/_utils/__init__.py | 2 + src/somd2/config/_config.py | 69 ++++ src/somd2/runner/_base.py | 49 +++ src/somd2/runner/_repex.py | 64 +++- src/somd2/runner/_runner.py | 142 ++++++++- src/somd2/runner/_terminal_flip.py | 494 +++++++++++++++++++++++++++++ tests/conftest.py | 35 ++ tests/runner/test_terminal_flip.py | 406 ++++++++++++++++++++++++ 8 files changed, 1257 insertions(+), 4 deletions(-) create mode 100644 src/somd2/runner/_terminal_flip.py create mode 100644 tests/runner/test_terminal_flip.py diff --git a/src/somd2/_utils/__init__.py b/src/somd2/_utils/__init__.py index ec25a367..74897957 100644 --- a/src/somd2/_utils/__init__.py +++ b/src/somd2/_utils/__init__.py @@ -23,8 +23,10 @@ if _platform.system() == "Windows": _lam_sym = "lambda" + _delta_sym = "delta" else: _lam_sym = "λ" + _delta_sym = "ΔE" del _platform diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 77333718..ba1dba26 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -139,6 +139,8 @@ def __init__( replica_exchange=False, randomise_velocities=False, perturbed_system=None, + terminal_flip_frequency=None, + terminal_flip_angle=None, gcmc=False, gcmc_frequency=None, gcmc_selection=None, @@ -377,6 +379,17 @@ def __init__( end state (lambda = 1). This will be used as the starting conformation all lambda windows > 0.5 when performing a replica exchange simulation. + terminal_flip_frequency: str + Frequency at which to attempt terminal ring flip Monte Carlo moves. If None + (the default), no terminal flip moves will be performed. When set, terminal + ring groups in perturbable molecules are detected automatically using Sire's + native connectivity. This must be a multiple of 'energy_frequency'. + + terminal_flip_angle: str + Override the flip angle used for all terminal ring groups, e.g. + ``"180 degrees"``. If None (the default), the angle is determined + automatically for each group from its geometry. + gcmc: bool Whether to perform Grand Canonical Monte Carlo (GCMC) water insertions/deletions. @@ -559,6 +572,8 @@ def __init__( self.replica_exchange = replica_exchange self.randomise_velocities = randomise_velocities self.perturbed_system = perturbed_system + self.terminal_flip_frequency = terminal_flip_frequency + self.terminal_flip_angle = terminal_flip_angle self.gcmc = gcmc self.gcmc_frequency = gcmc_frequency self.gcmc_selection = gcmc_selection @@ -1994,6 +2009,60 @@ def perturbed_system(self, perturbed_system): self._perturbed_system = None self._perturbed_system_file = None + @property + def terminal_flip_frequency(self): + return self._terminal_flip_frequency + + @terminal_flip_frequency.setter + def terminal_flip_frequency(self, terminal_flip_frequency): + if terminal_flip_frequency is not None: + if not isinstance(terminal_flip_frequency, str): + raise TypeError("'terminal_flip_frequency' must be of type 'str'") + + from sire.units import picosecond + + try: + t = _sr.u(terminal_flip_frequency) + except Exception: + raise ValueError( + f"Unable to parse 'terminal_flip_frequency' as a Sire GeneralUnit: " + f"{terminal_flip_frequency}" + ) + + if t.value() != 0 and not t.has_same_units(picosecond): + raise ValueError("'terminal_flip_frequency' units are invalid.") + + self._terminal_flip_frequency = t + else: + self._terminal_flip_frequency = None + + @property + def terminal_flip_angle(self): + return self._terminal_flip_angle + + @terminal_flip_angle.setter + def terminal_flip_angle(self, terminal_flip_angle): + if terminal_flip_angle is not None: + if not isinstance(terminal_flip_angle, str): + raise TypeError("'terminal_flip_angle' must be of type 'str'") + + from sire.units import degrees + + try: + a = _sr.u(terminal_flip_angle) + except Exception: + raise ValueError( + f"Unable to parse 'terminal_flip_angle' as a Sire GeneralUnit: " + f"{terminal_flip_angle}" + ) + + if not a.has_same_units(degrees): + raise ValueError("'terminal_flip_angle' units are invalid.") + + self._terminal_flip_angle = a + else: + self._terminal_flip_angle = None + @property def gcmc(self): return self._gcmc diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 6982c430..18e67407 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -664,6 +664,55 @@ def __init__(self, system, config): # Store the excess chemcical potential value. self._mu_ex = self._config.gcmc_excess_chemical_potential.value() + # Terminal flip specific validation and setup. + if self._config.terminal_flip_frequency is not None: + from math import isclose + + # Make sure the terminal flip frequency is a multiple of the + # energy frequency. + ratio = ( + self._config.terminal_flip_frequency / self._config.energy_frequency + ).value() + + if not isclose(ratio, round(ratio), abs_tol=1e-4): + msg = "'terminal_flip_frequency' must be a multiple of 'energy_frequency'." + _logger.error(msg) + raise ValueError(msg) + + # Auto-detect terminal ring groups using Sire connectivity. + from ._terminal_flip import detect_terminal_groups + + if isinstance(self._system, list): + mols = self._system[0] + else: + mols = self._system + + flip_angle = ( + self._config.terminal_flip_angle.to("degrees").value() + if self._config.terminal_flip_angle is not None + else None + ) + self._terminal_groups = detect_terminal_groups(mols, flip_angle=flip_angle) + + if not self._terminal_groups: + _logger.warning( + "No terminal ring groups detected. Terminal flip moves will not " + "be performed." + ) + else: + _logger.info( + f"Detected {len(self._terminal_groups)} terminal ring group(s) " + f"for terminal flip MC." + ) + for i, (angle, indices) in enumerate(self._terminal_groups): + _logger.info( + f" Group {i}: flip angle = {angle}°, " + f"anchor = {indices[0]}, pivot = {indices[1]}, " + f"{len(indices) - 2} mobile atom(s)" + ) + else: + self._terminal_groups = [] + # Store the initial system time. if isinstance(self._system, list): self._initial_time = [] diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 83afa649..6e60b34e 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -854,6 +854,21 @@ def __init__(self, system, config): else: self._start_block = 0 + # Create the terminal flip sampler (if terminal groups were detected). + if self._terminal_groups: + from ._terminal_flip import TerminalFlipSampler + + self._terminal_flip_sampler = TerminalFlipSampler( + self._terminal_groups, + float(self._config.temperature.value()), + ) + _logger.info( + f"Terminal flip sampler ready for replica exchange " + f"({len(self._terminal_groups)} group(s))" + ) + else: + self._terminal_flip_sampler = None + from threading import Lock # Create a lock to guard the dynamics cache. @@ -1001,6 +1016,23 @@ def run(self): else: cycles_per_gcmc = cycles + 1 + # Work out the number of cycles per terminal flip move. + if ( + self._config.terminal_flip_frequency is not None + and self._terminal_flip_sampler is not None + ): + cycles_per_flip = max( + 1, + round( + ( + self._config.terminal_flip_frequency + / self._config.energy_frequency + ).value() + ), + ) + else: + cycles_per_flip = cycles + 1 + # Initialise the threshold for the next checkpoint cycle. This is a float # to handle non-integer ratios between the checkpoint and energy frequencies. next_checkpoint = cycles_per_checkpoint @@ -1028,6 +1060,9 @@ def run(self): # Whether to perform a GCMC move before the dynamics block. is_gcmc = (i + 1) % cycles_per_gcmc == 0 + # Whether to perform a terminal flip move before the dynamics block. + is_terminal_flip = (i + 1) % cycles_per_flip == 0 + # Whether a frame is saved at the end of the cycle. write_gcmc_ghosts = (i + 1) % cycles_per_frame == 0 @@ -1043,6 +1078,7 @@ def run(self): repeat(self._lambda_values), repeat(is_gcmc), repeat(write_gcmc_ghosts), + repeat(is_terminal_flip), ): if not result: _logger.error( @@ -1128,6 +1164,15 @@ def run(self): ) self._dynamics_cache.mix_states() + # Log terminal flip acceptance rate at each cycle. + if self._terminal_flip_sampler is not None: + _logger.info( + f"Terminal flip acceptance rate: " + f"{self._terminal_flip_sampler.acceptance_rate:.3f} " + f"({self._terminal_flip_sampler.num_accepted}/" + f"{self._terminal_flip_sampler.num_attempted})" + ) + # This is a checkpoint cycle. if is_checkpoint: # Update the block number. @@ -1202,6 +1247,7 @@ def _run_block( lambdas, is_gcmc=False, write_gcmc_ghosts=False, + is_terminal_flip=False, ): """ Run a dynamics block for a given replica. @@ -1225,6 +1271,10 @@ def _run_block( Whether to write the indices of GCMC ghost residues to file. + is_terminal_flip: bool + Whether a terminal flip MC move should be performed before the + dynamics block. + Returns ------- @@ -1262,6 +1312,11 @@ def _run_block( if write_gcmc_ghosts: gcmc_sampler.write_ghost_residues() + # Perform a terminal flip move before dynamics if requested. + if self._terminal_flip_sampler is not None and is_terminal_flip: + _logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}") + self._terminal_flip_sampler.move(dynamics.context()) + _logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}") # Draw new velocities from the Maxwell-Boltzmann distribution. @@ -1701,9 +1756,16 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): # Push the PyCUDA context on top of the stack. gcmc_sampler.push() try: + n_moves = gcmc_sampler._num_moves + acc_str = ( + f", acceptance rate = {gcmc_sampler.move_acceptance_probability():.3f}" + f" (ins = {gcmc_sampler.num_insertions()}, del = {gcmc_sampler.num_deletions()})" + if n_moves > 0 + else "" + ) _logger.info( f"Current number of waters in GCMC volume at {_lam_sym} = {lam:.5f} " - f"is {gcmc_sampler.num_waters()}" + f"is {gcmc_sampler.num_waters()}{acc_str}" ) finally: # Remove the PyCUDA context from the stack. diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 32905b76..314e36fb 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -465,6 +465,31 @@ def generate_lam_vals(lambda_base, increment=0.001): else: gcmc_sampler = None + # Create the terminal flip sampler (if terminal groups were detected). + if self._terminal_groups: + from ._terminal_flip import TerminalFlipSampler + + terminal_flip_sampler = TerminalFlipSampler( + self._terminal_groups, + float(self._config.temperature.value()), + ) + flip_every = max( + 1, + round( + ( + self._config.terminal_flip_frequency + / self._config.energy_frequency + ).value() + ), + ) + _logger.info( + f"Terminal flip sampler ready at {_lam_sym} = {lambda_value:.5f} " + f"(every {flip_every} energy block(s))" + ) + else: + terminal_flip_sampler = None + flip_every = None + # Minimisation. if self._config.minimise: constraint = self._config.constraint @@ -719,6 +744,7 @@ def generate_lam_vals(lambda_base, increment=0.001): runtime = _sr.u("0ps") save_frames = self._config.frame_frequency > 0 next_frame = self._config.frame_frequency + flip_counter = 0 # Loop until we reach the runtime. while runtime < checkpoint_frequency: @@ -734,6 +760,17 @@ def generate_lam_vals(lambda_base, increment=0.001): finally: gcmc_sampler.pop() + # Perform a terminal flip move at the specified frequency. + if ( + terminal_flip_sampler is not None + and flip_counter % flip_every == 0 + ): + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + terminal_flip_sampler.move(dynamics.context()) + # Write ghost residues immediately after the GCMC # move if a frame will be saved in the upcoming # dynamics block. @@ -768,8 +805,41 @@ def generate_lam_vals(lambda_base, increment=0.001): ), ) - # Update the runtime. + # Update the runtime and flip counter. runtime += self._config.energy_frequency + flip_counter += 1 + + elif terminal_flip_sampler is not None: + # Terminal flip without GCMC: perform flip moves at the + # specified frequency then run the full dynamics block. + n_flips = max( + 1, + round( + ( + checkpoint_frequency + / self._config.terminal_flip_frequency + ).value() + ), + ) + for _ in range(n_flips): + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + terminal_flip_sampler.move(dynamics.context()) + + dynamics.run( + checkpoint_frequency, + energy_frequency=self._config.energy_frequency, + frame_frequency=self._config.frame_frequency, + lambda_windows=lambda_array, + rest2_scale_factors=rest2_scale_factors, + save_velocities=self._config.save_velocities, + auto_fix_minimise=True, + num_energy_neighbours=num_energy_neighbours, + null_energy=self._config.null_energy, + save_crash_report=self._config.save_crash_report, + ) else: dynamics.run( @@ -859,13 +929,30 @@ def generate_lam_vals(lambda_base, increment=0.001): if gcmc_sampler is not None: gcmc_sampler.push() try: + n_moves = gcmc_sampler._num_moves + acc_str = ( + f", acceptance rate = {gcmc_sampler.move_acceptance_probability():.3f}" + f" (ins = {gcmc_sampler.num_insertions()}, del = {gcmc_sampler.num_deletions()})" + if n_moves > 0 + else "" + ) _logger.info( f"Current number of waters in GCMC volume at {_lam_sym} = {lambda_value:.5f} " - f"is {gcmc_sampler.num_waters()}" + f"is {gcmc_sampler.num_waters()}{acc_str}" ) finally: gcmc_sampler.pop() + # Log terminal flip acceptance rate. + if terminal_flip_sampler is not None: + _logger.info( + f"Terminal flip acceptance rate at " + f"{_lam_sym} = {lambda_value:.5f}: " + f"{terminal_flip_sampler.acceptance_rate:.3f} " + f"({terminal_flip_sampler.num_accepted}/" + f"{terminal_flip_sampler.num_attempted})" + ) + if is_final_block: _logger.success( f"{_lam_sym} = {lambda_value:.5f} complete, speed = {speed:.2f} ns day-1" @@ -880,6 +967,14 @@ def generate_lam_vals(lambda_base, increment=0.001): block += 1 block_start = _timer() try: + # Perform one terminal flip at the start of the remainder block. + if terminal_flip_sampler is not None: + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + terminal_flip_sampler.move(dynamics.context()) + dynamics.run( rem, energy_frequency=self._config.energy_frequency, @@ -952,6 +1047,7 @@ def generate_lam_vals(lambda_base, increment=0.001): runtime = _sr.u("0ps") save_frames = self._config.frame_frequency > 0 next_frame = self._config.frame_frequency + flip_counter = 0 # Loop until we reach the runtime. while runtime < time: @@ -967,6 +1063,17 @@ def generate_lam_vals(lambda_base, increment=0.001): finally: gcmc_sampler.pop() + # Perform a terminal flip move at the specified frequency. + if ( + terminal_flip_sampler is not None + and flip_counter % flip_every == 0 + ): + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + terminal_flip_sampler.move(dynamics.context()) + # Write ghost residues immediately after the GCMC # move if a frame will be saved in the upcoming # dynamics block. @@ -991,8 +1098,37 @@ def generate_lam_vals(lambda_base, increment=0.001): save_crash_report=self._config.save_crash_report, ) - # Update the runtime. + # Update the runtime and flip counter. runtime += self._config.energy_frequency + flip_counter += 1 + + elif terminal_flip_sampler is not None: + # Terminal flip without GCMC: perform flip moves at the + # start then run the full dynamics block. + n_flips = max( + 1, + round((time / self._config.terminal_flip_frequency).value()), + ) + for _ in range(n_flips): + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + terminal_flip_sampler.move(dynamics.context()) + + dynamics.run( + time, + energy_frequency=self._config.energy_frequency, + frame_frequency=self._config.frame_frequency, + lambda_windows=lambda_array, + rest2_scale_factors=rest2_scale_factors, + save_velocities=self._config.save_velocities, + auto_fix_minimise=True, + num_energy_neighbours=num_energy_neighbours, + null_energy=self._config.null_energy, + save_crash_report=self._config.save_crash_report, + ) + else: dynamics.run( time, diff --git a/src/somd2/runner/_terminal_flip.py b/src/somd2/runner/_terminal_flip.py new file mode 100644 index 00000000..67901c69 --- /dev/null +++ b/src/somd2/runner/_terminal_flip.py @@ -0,0 +1,494 @@ +###################################################################### +# SOMD2: GPU accelerated alchemical free-energy engine. +# +# Copyright: 2023-2026 +# +# Authors: The OpenBioSim Team +# +# SOMD2 is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SOMD2 is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with SOMD2. If not, see . +##################################################################### + +# Adapted from the terminal ring flip MC implemenation in GrandFEP: +# https://github.com/deGrootLab/GrandFEP +# (Released under the MIT License.) + +__all__ = ["TerminalFlipSampler", "detect_terminal_groups"] + +import numpy as _np + +import sire.legacy.Mol as _Mol + +from somd2 import _logger +from somd2._utils import _delta_sym + + +def _auto_flip_angle(mol, anchor_idx, pivot_idx, ring_neighbor_idxs): + """ + Compute the flip angle for a terminal group from the molecular geometry. + + The angle is measured between the two ring neighbours of the pivot, + projected onto the plane perpendicular to the rotation axis (anchor → + pivot). For a planar C₂-symmetric ring this is 180°; for higher-symmetry + rings it will be smaller. + + Parameters + ---------- + + mol : sire.legacy.Mol.Molecule + The perturbable molecule. + + anchor_idx : int + Molecule-local index of the anchor atom. + + pivot_idx : int + Molecule-local index of the pivot atom. + + ring_neighbor_idxs : list of int + Molecule-local indices of the two ring atoms directly bonded to the + pivot (i.e. the ortho atoms for a benzene ring). + + Returns + ------- + + float + Raw angle in degrees between the projected ring-neighbour vectors. + """ + + def _coords(idx): + v = mol.atom(_Mol.AtomIdx(idx)).property("coordinates") + return _np.array([v.x().value(), v.y().value(), v.z().value()]) + + anchor = _coords(anchor_idx) + pivot = _coords(pivot_idx) + n1 = _coords(ring_neighbor_idxs[0]) + n2 = _coords(ring_neighbor_idxs[1]) + + # Unit rotation axis from anchor to pivot. + k = pivot - anchor + k = k / _np.linalg.norm(k) + + # Project each ring-neighbour displacement onto the plane perp to k. + v1 = n1 - pivot + v1_perp = v1 - _np.dot(v1, k) * k + + v2 = n2 - pivot + v2_perp = v2 - _np.dot(v2, k) * k + + # Angle between the two projected vectors. + cos_angle = _np.dot(v1_perp, v2_perp) / ( + _np.linalg.norm(v1_perp) * _np.linalg.norm(v2_perp) + ) + return float(_np.degrees(_np.arccos(_np.clip(cos_angle, -1.0, 1.0)))) + + +def _round_to_symmetry_angle(raw_angle, tolerance=10.0): + """ + Round ``raw_angle`` to the nearest crystallographic symmetry angle + (360°/n for n = 2 … 12). Returns ``None`` if the closest match is more + than ``tolerance`` degrees away, indicating that the ring has no useful + rotational symmetry. + + Parameters + ---------- + + raw_angle : float + Measured angle in degrees. + + tolerance : float + Maximum deviation (degrees) from a symmetry angle. Default is 10.0. + + Returns + ------- + + float or None + The nearest symmetry angle in degrees, or None if none is close enough. + """ + symmetry_angles = [360.0 / n for n in range(2, 13)] + diffs = [abs(raw_angle - a) for a in symmetry_angles] + min_idx = int(_np.argmin(diffs)) + if diffs[min_idx] > tolerance: + return None + return symmetry_angles[min_idx] + + +def detect_terminal_groups(system, flip_angle=None): + """ + Detect terminal ring groups in perturbable molecules using Sire's native + connectivity. + + A terminal ring group is identified by a bond between a non-ring atom + (the anchor) and a ring atom (the pivot), where the ring side of the bond + is connected to the rest of the molecule only through that single bond. + The mobile atoms are all atoms reachable from the pivot when the + anchor-pivot bond is cut. + + Parameters + ---------- + + system : sire system or molecule group + The Sire system containing perturbable molecules. + + flip_angle : float or None + The flip angle in degrees. If None (the default), the angle is + determined automatically from the geometry of each terminal group + by measuring the angle between the two ring neighbours of the pivot + projected perpendicular to the rotation axis, then rounding to the + nearest crystallographic symmetry angle (360°/n for n = 2..12). If + a float is given it overrides the geometric measurement for all + groups. + + Returns + ------- + + list of tuple + Each entry is (angle, [anchor_idx, pivot_idx, mobile_idx_0, ...]) + where all indices are absolute atom indices corresponding to OpenMM + atom ordering. + """ + terminal_groups = [] + + # Get the perturbable molecules. + try: + pert_mols = system.molecules("property is_perturbable") + except Exception: + _logger.warning( + "No perturbable molecules found. Terminal flip detection skipped." + ) + return terminal_groups + + # All atoms in the system, used to obtain absolute (OpenMM) atom indices. + all_atoms = system.atoms() + + for mol in pert_mols: + try: + connectivity = mol.property("connectivity") + except Exception: + _logger.warning(f"Molecule {mol} has no 'connectivity' property. Skipping.") + continue + + # Skip molecules whose connectivity changes between end states (e.g. + # ring-breaking/growing perturbations). Terminal groups detected from + # the lambda=0 connectivity would be invalid at lambda=1. + try: + conn0 = mol.property("connectivity0") + conn1 = mol.property("connectivity1") + if conn0 != conn1: + _logger.warning( + f"Molecule {mol} has different connectivity at lambda=0 and " + "lambda=1 (ring-breaking/growing perturbation). Skipping " + "terminal flip detection for this molecule." + ) + continue + except Exception: + pass + + num_atoms = mol.num_atoms() + seen_bonds = set() + + for i in range(num_atoms): + atom_i_idx = _Mol.AtomIdx(i) + + # Only consider non-ring atoms as anchors. + if connectivity.in_ring(atom_i_idx): + continue + + # Skip dead-end atoms (e.g. hydrogen bonded only to a ring + # carbon): a valid anchor must be part of a chain, so it needs + # at least two connections (one to the pivot, one elsewhere). + if len(connectivity.connections_to(atom_i_idx)) < 2: + continue + + for neighbor_idx in connectivity.connections_to(atom_i_idx): + j = neighbor_idx.value() + + # Only consider ring atoms as pivots. + if not connectivity.in_ring(_Mol.AtomIdx(j)): + continue + + # Avoid processing the same bond twice. + bond_key = (min(i, j), max(i, j)) + if bond_key in seen_bonds: + continue + seen_bonds.add(bond_key) + + # Collect mobile atoms via BFS from the pivot, not crossing + # the anchor. The pivot itself does not move (it is the + # rotation centre), so it is excluded from the mobile list. + mobile = _bfs_mobile(connectivity, i, j, num_atoms) + + if not mobile: + continue + + # Determine the flip angle for this group. + if flip_angle is not None: + group_angle = flip_angle + else: + # Find the two ring neighbours of the pivot (mobile atoms + # directly bonded to the pivot that are in the ring). + mobile_set = set(mobile) + pivot_idx_obj = _Mol.AtomIdx(j) + ring_neighbors = [ + n.value() + for n in connectivity.connections_to(pivot_idx_obj) + if n.value() in mobile_set + and connectivity.in_ring(_Mol.AtomIdx(n.value())) + ] + + if len(ring_neighbors) != 2: + _logger.warning( + f"Expected 2 ring neighbours for pivot atom {j}, " + f"found {len(ring_neighbors)}. Skipping group." + ) + continue + + raw = _auto_flip_angle(mol, i, j, ring_neighbors) + group_angle = _round_to_symmetry_angle(raw) + + if group_angle is None: + _logger.warning( + f"Terminal group at pivot atom {j} has no recognised " + f"rotational symmetry (raw angle = {raw:.1f}°). " + "Skipping group." + ) + continue + + _logger.debug( + f"Terminal group at pivot atom {j}: auto-detected flip " + f"angle = {group_angle}° (raw = {raw:.1f}°)." + ) + + # Map molecule-local indices to absolute system indices. + anchor_abs = all_atoms.find(mol.atom(atom_i_idx)) + pivot_abs = all_atoms.find(mol.atom(_Mol.AtomIdx(j))) + mobile_abs = [all_atoms.find(mol.atom(_Mol.AtomIdx(k))) for k in mobile] + + terminal_groups.append( + (group_angle, [anchor_abs, pivot_abs] + mobile_abs) + ) + + return terminal_groups + + +def _bfs_mobile(connectivity, anchor_idx, pivot_idx, num_atoms): + """ + Breadth-first search from ``pivot_idx``, not crossing ``anchor_idx``. + + Returns a sorted list of atom indices for atoms that will be rotated + (all reachable atoms except the anchor and the pivot itself, since the + pivot is the fixed rotation centre). + + Parameters + ---------- + + connectivity : sire.legacy.Mol.Connectivity + The molecular connectivity object. + + anchor_idx : int + Index of the anchor atom (defines the rotation axis start; fixed). + + pivot_idx : int + Index of the pivot atom (rotation centre; fixed). + + num_atoms : int + Total number of atoms in the molecule. + + Returns + ------- + + list of int + Sorted list of mobile atom indices. + """ + visited = {anchor_idx, pivot_idx} + queue = [pivot_idx] + + while queue: + current = queue.pop(0) + for neighbor in connectivity.connections_to(_Mol.AtomIdx(current)): + n = neighbor.value() + if n not in visited: + visited.add(n) + queue.append(n) + + # Exclude the anchor and pivot; only mobile atoms are rotated. + return sorted(visited - {anchor_idx, pivot_idx}) + + +class TerminalFlipSampler: + """ + Monte Carlo sampler for terminal ring flip moves. + + Each move selects one terminal group at random and attempts to rotate + its mobile atoms by ±``flip_angle`` degrees around the bond axis from + the anchor atom to the pivot atom. The move is accepted or rejected + according to the Metropolis criterion. + + The rotation uses Rodrigues' rotation formula:: + + v_rot = v·cos θ + (k × v)·sin θ + k·(k·v)·(1 − cos θ) + + where ``k`` is the unit vector along the rotation axis (anchor → pivot) + and ``v`` is the displacement of a mobile atom from the pivot. + + The sign of ``flip_angle`` is chosen uniformly at random so that the + proposal is symmetric, satisfying detailed balance for any angle. + """ + + def __init__(self, terminal_groups, temperature): + """ + Parameters + ---------- + + terminal_groups : list of tuple + Each entry is (angle, [anchor_idx, pivot_idx, mobile_idx_0, ...]) + where indices are absolute OpenMM atom indices. + + temperature : float + Simulation temperature in Kelvin. + """ + self._terminal_groups = terminal_groups + + # kBT in kJ/mol (R = 8.314462618e-3 kJ mol-1 K-1). + self._kBT = 8.314462618e-3 * temperature + + self._num_attempted = 0 + self._num_accepted = 0 + + def _rotate(self, context, group_idx, angle): + """ + Rotate the mobile atoms of a terminal group by ``angle`` degrees + around the anchor-to-pivot axis, updating the context in place. + + Parameters + ---------- + + context : openmm.Context + The active OpenMM context. + + group_idx : int + Index into ``self._terminal_groups`` selecting the group to rotate. + + angle : float + Rotation angle in degrees. + """ + from openmm import unit as _omm_unit + + _, atom_indices = self._terminal_groups[group_idx] + + positions = ( + context.getState(getPositions=True) + .getPositions(asNumpy=True) + .value_in_unit(_omm_unit.nanometer) + ) + + theta = _np.deg2rad(angle) + cos_t = _np.cos(theta) + sin_t = _np.sin(theta) + + # Anchor (axis start, fixed) and pivot (rotation centre, fixed). + p0 = positions[atom_indices[0]] + p1 = positions[atom_indices[1]] + + # Unit rotation axis from anchor to pivot. + axis = p1 - p0 + axis = axis / _np.linalg.norm(axis) + + # Rotate mobile atoms using Rodrigues' formula. + new_positions = positions.copy() + for atom_idx in atom_indices[2:]: + v = positions[atom_idx] - p1 + new_positions[atom_idx] = ( + p1 + + v * cos_t + + _np.cross(axis, v) * sin_t + + axis * _np.dot(axis, v) * (1.0 - cos_t) + ) + + context.setPositions(new_positions * _omm_unit.nanometer) + + def move(self, context): + """ + Attempt one terminal flip Monte Carlo move. + + A terminal group is chosen at random. The mobile atoms are rotated + by ±``flip_angle`` around the anchor-to-pivot axis. The move is + accepted with Metropolis probability ``min(1, exp(-ΔE / kBT))``. + + Parameters + ---------- + + context : openmm.Context + The active OpenMM context. + """ + from openmm import unit as _omm_unit + + if not self._terminal_groups: + return + + self._num_attempted += 1 + + # Randomly select one terminal group. + group_idx = _np.random.randint(len(self._terminal_groups)) + angle, _ = self._terminal_groups[group_idx] + + # Retrieve current positions and energy before the move. + state = context.getState(getPositions=True, getEnergy=True) + old_positions = state.getPositions(asNumpy=True).value_in_unit( + _omm_unit.nanometer + ) + e_old = state.getPotentialEnergy().value_in_unit(_omm_unit.kilojoule_per_mole) + + # Random sign gives a symmetric proposal (detailed balance). + signed_angle = float(_np.random.choice([-1, 1])) * angle + self._rotate(context, group_idx, signed_angle) + + # Evaluate the energy of the proposed configuration. + e_new = ( + context.getState(getEnergy=True) + .getPotentialEnergy() + .value_in_unit(_omm_unit.kilojoule_per_mole) + ) + + # Metropolis acceptance criterion. + delta_e = (e_new - e_old) / self._kBT + if delta_e <= 0.0 or _np.random.random() < _np.exp(-delta_e): + self._num_accepted += 1 + _logger.debug( + f"Terminal flip accepted (group {group_idx}, " + f"{_delta_sym} = {e_new - e_old:.2f} kJ/mol, " + f"acc = {min(1.0, _np.exp(-delta_e)):.3f})" + ) + else: + context.setPositions(old_positions * _omm_unit.nanometer) + _logger.debug( + f"Terminal flip rejected (group {group_idx}, " + f"{_delta_sym} = {e_new - e_old:.2f} kJ/mol, " + f"acc = {_np.exp(-delta_e):.3f})" + ) + + @property + def num_attempted(self): + """Total number of terminal flip moves attempted.""" + return self._num_attempted + + @property + def num_accepted(self): + """Total number of terminal flip moves accepted.""" + return self._num_accepted + + @property + def acceptance_rate(self): + """Fraction of attempted moves that were accepted.""" + if self._num_attempted == 0: + return 0.0 + return self._num_accepted / self._num_attempted diff --git a/tests/conftest.py b/tests/conftest.py index 02dbb2a8..de64927d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,45 @@ import os + import pytest import sire as sr has_cuda = True if "CUDA_VISIBLE_DEVICES" in os.environ else False +@pytest.fixture(scope="session") +def diphenylethane_mols(): + """ + Load a merged perturbable system built from 1,2-diphenylethane (reference, + lambda = 0) and 1,2-diphenylethanol (perturbed, lambda = 1). + + SMILES: + reference : c1ccccc1CCc1ccccc1 + perturbed : OC(Cc1ccccc1)c1ccccc1 + + Both phenyl rings are terminal, so two terminal ring groups should be + detected. + """ + mols = sr.load_test_files("12diphenylethane_12diphenylethanol.s3") + return sr.morph.link_to_reference(mols) + + +@pytest.fixture(scope="session") +def phenethyl_mols(): + """ + Load a merged perturbable system built from phenethylamine (reference, + lambda = 0) and 2-phenylethanol (perturbed, lambda = 1). + + SMILES: + reference : NCCc1ccccc1 + perturbed : OCCc1ccccc1 + + The phenyl ring is terminal — attached to the aliphatic chain by a single + exocyclic bond — making it the only detectable terminal ring group. + """ + mols = sr.load_test_files("phenethylamine_2phenylethanol.s3") + return sr.morph.link_to_reference(mols) + + @pytest.fixture(scope="session") def ethane_methanol(): mols = sr.load(sr.expand(sr.tutorial_url, "merged_molecule.s3")) diff --git a/tests/runner/test_terminal_flip.py b/tests/runner/test_terminal_flip.py new file mode 100644 index 00000000..ea6dea9b --- /dev/null +++ b/tests/runner/test_terminal_flip.py @@ -0,0 +1,406 @@ +""" +Tests for terminal ring flip Monte Carlo functionality. + +Two fixtures are used (both defined in conftest.py): + +``phenethyl_mols`` + Merged system from phenethylamine (NCCc1ccccc1) and 2-phenylethanol + (OCCc1ccccc1) via ``sr.load_test_files("phenethylamine_2phenylethanol.s3")``. + Contains one terminal phenyl ring. + +``diphenylethane_mols`` + Merged system from 1,2-diphenylethane (c1ccccc1CCc1ccccc1) and + 1,2-diphenylethanol (OC(Cc1ccccc1)c1ccccc1) via + ``sr.load_test_files("12diphenylethane_12diphenylethanol.s3")``. + Contains two terminal phenyl rings. +""" + +import pytest +import tempfile + +import numpy as np + +from somd2.config import Config +from somd2.runner import Runner +from somd2.runner._terminal_flip import TerminalFlipSampler, detect_terminal_groups + +# --------------------------------------------------------------------------- +# detect_terminal_groups +# --------------------------------------------------------------------------- + + +def test_no_terminal_groups(ethane_methanol): + """ + The ethane → methanol perturbation contains no rings, so no terminal + ring groups should be detected. + """ + groups = detect_terminal_groups(ethane_methanol) + assert groups == [] + + +def test_detect_one_terminal_group(phenethyl_mols): + """ + The phenethyl system has exactly one terminal ring (the phenyl group + attached via the –CH2– chain). H atoms bonded to ring carbons must not + be reported as separate groups. + """ + groups = detect_terminal_groups(phenethyl_mols) + assert len(groups) == 1 + + +def test_terminal_group_flip_angle(phenethyl_mols): + """ + The default flip angle should be 180°. + """ + groups = detect_terminal_groups(phenethyl_mols) + angle, _ = groups[0] + assert angle == pytest.approx(180.0) + + +def test_terminal_group_atom_count(phenethyl_mols): + """ + For a mono-substituted benzene ring: + - 1 anchor atom (aliphatic C adjacent to ring) + - 1 pivot atom (ipso ring C) + - 5 mobile ring carbons + - 5 mobile ring hydrogens + Total indices list length = 12. + """ + groups = detect_terminal_groups(phenethyl_mols) + _, indices = groups[0] + # anchor + pivot + 10 mobile atoms + assert len(indices) == 12 + + +def test_anchor_not_in_mobile(phenethyl_mols): + """ + The anchor index must not appear in the mobile atom list. + """ + groups = detect_terminal_groups(phenethyl_mols) + _, indices = groups[0] + anchor_idx = indices[0] + mobile_indices = indices[2:] + assert anchor_idx not in mobile_indices + + +def test_pivot_not_in_mobile(phenethyl_mols): + """ + The pivot index must not appear in the mobile atom list (the pivot is the + fixed rotation centre). + """ + groups = detect_terminal_groups(phenethyl_mols) + _, indices = groups[0] + pivot_idx = indices[1] + mobile_indices = indices[2:] + assert pivot_idx not in mobile_indices + + +def test_auto_flip_angle_phenethyl(phenethyl_mols): + """ + With no flip_angle override, the angle for a monosubstituted benzene ring + should be auto-detected as 180° (C2 symmetry). + """ + groups = detect_terminal_groups(phenethyl_mols) + angle, _ = groups[0] + assert angle == pytest.approx(180.0) + + +def test_auto_flip_angle_diphenylethane(diphenylethane_mols): + """ + Both terminal phenyl groups in the diphenylethane system should + auto-detect as 180°. + """ + groups = detect_terminal_groups(diphenylethane_mols) + assert len(groups) == 2 + for angle, _ in groups: + assert angle == pytest.approx(180.0) + + +def test_custom_flip_angle(phenethyl_mols): + """ + An explicit flip_angle override should be stored and returned as-is, + bypassing the geometric auto-detection. + """ + groups = detect_terminal_groups(phenethyl_mols, flip_angle=90.0) + angle, _ = groups[0] + assert angle == pytest.approx(90.0) + + +def test_detect_two_terminal_groups(diphenylethane_mols): + """ + 1,2-diphenylethane → 1,2-diphenylethanol has two terminal phenyl rings, + each attached via a non-ring CH2/CH anchor, so exactly two groups should + be detected. + """ + groups = detect_terminal_groups(diphenylethane_mols) + assert len(groups) == 2 + + +def test_multiple_groups_unique_pivots(diphenylethane_mols): + """ + The two terminal groups must have distinct pivot atoms (each ring has its + own ipso carbon). + """ + groups = detect_terminal_groups(diphenylethane_mols) + pivot_indices = [indices[1] for _, indices in groups] + assert len(set(pivot_indices)) == 2 + + +def test_multiple_groups_disjoint_mobile(diphenylethane_mols): + """ + The mobile atom sets of the two terminal groups must be disjoint — each + group owns its own ring atoms. + """ + groups = detect_terminal_groups(diphenylethane_mols) + mobile_0 = set(groups[0][1][2:]) + mobile_1 = set(groups[1][1][2:]) + assert mobile_0.isdisjoint(mobile_1) + + +# --------------------------------------------------------------------------- +# Config validation +# --------------------------------------------------------------------------- + + +def test_config_terminal_flip_frequency_none(): + """terminal_flip_frequency defaults to None (disabled).""" + config = Config() + assert config.terminal_flip_frequency is None + + +def test_config_terminal_flip_frequency_valid(): + """A valid time string is parsed to a Sire GeneralUnit.""" + config = Config(terminal_flip_frequency="1 ps") + assert config.terminal_flip_frequency is not None + assert str(config.terminal_flip_frequency).startswith("1") + + +def test_config_terminal_flip_frequency_bad_units(): + """Non-time units should raise ValueError.""" + with pytest.raises(ValueError, match="units are invalid"): + Config(terminal_flip_frequency="5 A") + + +def test_config_terminal_flip_frequency_bad_type(): + """A non-string value should raise TypeError.""" + config = Config() + with pytest.raises(TypeError, match="must be of type 'str'"): + config.terminal_flip_frequency = 5 + + +def test_config_terminal_flip_angle_none(): + """terminal_flip_angle defaults to None (auto-detect).""" + config = Config() + assert config.terminal_flip_angle is None + + +def test_config_terminal_flip_angle_valid(): + """A valid angle string is parsed to a Sire GeneralUnit.""" + config = Config(terminal_flip_angle="180 degrees") + assert config.terminal_flip_angle is not None + + +def test_config_terminal_flip_angle_bad_units(): + """Non-angle units should raise ValueError.""" + with pytest.raises(ValueError, match="units are invalid"): + Config(terminal_flip_angle="5 A") + + +def test_config_terminal_flip_angle_bad_type(): + """A non-string value should raise TypeError.""" + config = Config() + with pytest.raises(TypeError, match="must be of type 'str'"): + config.terminal_flip_angle = 180 + + +# --------------------------------------------------------------------------- +# TerminalFlipSampler +# --------------------------------------------------------------------------- + + +def test_sampler_initial_state(phenethyl_mols): + """ + A freshly constructed sampler should report zero attempts and zero + accepted moves. + """ + groups = detect_terminal_groups(phenethyl_mols) + sampler = TerminalFlipSampler(groups, 300.0) + assert sampler.num_attempted == 0 + assert sampler.num_accepted == 0 + assert sampler.acceptance_rate == 0.0 + + +def test_sampler_move(phenethyl_mols): + """ + After one call to move(), num_attempted should be 1 and the statistics + should be internally consistent. The outcome (accepted or rejected) + depends on the torsional energy around the exocyclic bond and is not + deterministic for an arbitrary starting configuration. + """ + with tempfile.TemporaryDirectory() as tmpdir: + config = Config( + platform="cpu", + output_directory=tmpdir, + num_lambda=1, + lambda_values=[0.0], + terminal_flip_frequency="4fs", + energy_frequency="4fs", + checkpoint_frequency="4fs", + frame_frequency="4fs", + ) + runner = Runner(phenethyl_mols, config) + + # Create a dynamics object to obtain an OpenMM context. + dynamics_kwargs = runner._dynamics_kwargs.copy() + dynamics = runner._system.dynamics(**dynamics_kwargs) + + groups = detect_terminal_groups(phenethyl_mols) + sampler = TerminalFlipSampler(groups, 300.0) + + sampler.move(dynamics.context()) + + assert sampler.num_attempted == 1 + assert sampler.num_accepted in (0, 1) + assert 0.0 <= sampler.acceptance_rate <= 1.0 + + +def test_rotate(phenethyl_mols): + """ + _rotate() must: + - leave the anchor and pivot atoms stationary, + - move all mobile atoms, + - restore all mobile atom positions after two consecutive 180° flips. + """ + from openmm import unit as omm_unit + + with tempfile.TemporaryDirectory() as tmpdir: + config = Config( + platform="cpu", + output_directory=tmpdir, + num_lambda=1, + lambda_values=[0.0], + terminal_flip_frequency="4fs", + energy_frequency="4fs", + checkpoint_frequency="4fs", + frame_frequency="4fs", + ) + runner = Runner(phenethyl_mols, config) + + dynamics_kwargs = runner._dynamics_kwargs.copy() + dynamics = runner._system.dynamics(**dynamics_kwargs) + context = dynamics.context() + + groups = detect_terminal_groups(phenethyl_mols) + sampler = TerminalFlipSampler(groups, 300.0) + + _, indices = groups[0] + anchor_idx = indices[0] + pivot_idx = indices[1] + mobile_indices = indices[2:] + + pos_before = ( + context.getState(getPositions=True) + .getPositions(asNumpy=True) + .value_in_unit(omm_unit.nanometer) + ) + + sampler._rotate(context, 0, 180.0) + + pos_after = ( + context.getState(getPositions=True) + .getPositions(asNumpy=True) + .value_in_unit(omm_unit.nanometer) + ) + + # Anchor and pivot must not move. + np.testing.assert_allclose( + pos_after[anchor_idx], pos_before[anchor_idx], atol=1e-5 + ) + np.testing.assert_allclose( + pos_after[pivot_idx], pos_before[pivot_idx], atol=1e-5 + ) + + # All mobile atoms must have moved. + for idx in mobile_indices: + assert not np.allclose(pos_after[idx], pos_before[idx], atol=1e-5), ( + f"Mobile atom {idx} did not move after 180° rotation" + ) + + # A second 180° flip must restore all mobile atom positions. + sampler._rotate(context, 0, 180.0) + pos_restored = ( + context.getState(getPositions=True) + .getPositions(asNumpy=True) + .value_in_unit(omm_unit.nanometer) + ) + np.testing.assert_allclose( + pos_restored[mobile_indices], pos_before[mobile_indices], atol=1e-5 + ) + + +# --------------------------------------------------------------------------- +# Runner integration +# --------------------------------------------------------------------------- + + +def test_runner_no_terminal_groups(ethane_methanol): + """ + Setting terminal_flip_frequency on a ring-free molecule should succeed + (0 groups detected) and the simulation should complete normally. + """ + with tempfile.TemporaryDirectory() as tmpdir: + config = Config( + runtime="12fs", + output_directory=tmpdir, + energy_frequency="4fs", + checkpoint_frequency="4fs", + frame_frequency="4fs", + platform="cpu", + max_threads=1, + num_lambda=2, + terminal_flip_frequency="4fs", + ) + runner = Runner(ethane_methanol, config) + assert runner._terminal_groups == [] + runner.run() + + +def test_runner_with_terminal_flip(phenethyl_mols): + """ + With terminal_flip_frequency set and a terminal ring present, the runner + should detect one group and complete the simulation successfully. + """ + with tempfile.TemporaryDirectory() as tmpdir: + config = Config( + runtime="12fs", + output_directory=tmpdir, + energy_frequency="4fs", + checkpoint_frequency="4fs", + frame_frequency="4fs", + platform="cpu", + max_threads=1, + num_lambda=2, + terminal_flip_frequency="4fs", + ) + runner = Runner(phenethyl_mols, config) + assert len(runner._terminal_groups) == 1 + runner.run() + + +def test_runner_validation_frequency_multiple(ethane_methanol): + """ + terminal_flip_frequency must be a multiple of energy_frequency. + A non-multiple should raise ValueError during runner initialisation. + """ + with tempfile.TemporaryDirectory() as tmpdir: + config = Config( + output_directory=tmpdir, + platform="cpu", + num_lambda=2, + energy_frequency="4fs", + terminal_flip_frequency="3fs", # not a multiple of 4fs + ) + with pytest.raises( + ValueError, match="must be a multiple of 'energy_frequency'" + ): + Runner(ethane_methanol, config) From 1ac5ad08f24bc47f10610f5716d7a9abf829b515 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 31 Mar 2026 14:44:32 +0100 Subject: [PATCH 02/14] Use encoding check for Unicode symbols. --- src/somd2/_utils/__init__.py | 15 +++++++++------ src/somd2/runner/_terminal_flip.py | 6 +++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/somd2/_utils/__init__.py b/src/somd2/_utils/__init__.py index 74897957..0fd4719a 100644 --- a/src/somd2/_utils/__init__.py +++ b/src/somd2/_utils/__init__.py @@ -19,16 +19,19 @@ # along with SOMD2. If not, see . ##################################################################### -import platform as _platform +import sys as _sys -if _platform.system() == "Windows": - _lam_sym = "lambda" - _delta_sym = "delta" -else: +try: + "λΔ°".encode(_sys.stdout.encoding or "utf-8") _lam_sym = "λ" _delta_sym = "ΔE" + _degree_sym = "°" +except (UnicodeEncodeError, LookupError): + _lam_sym = "lambda" + _delta_sym = "delta" + _degree_sym = "deg" -del _platform +del _sys def _has_ghost(mol, idxs, is_lambda1=False): diff --git a/src/somd2/runner/_terminal_flip.py b/src/somd2/runner/_terminal_flip.py index 67901c69..cede8b54 100644 --- a/src/somd2/runner/_terminal_flip.py +++ b/src/somd2/runner/_terminal_flip.py @@ -30,7 +30,7 @@ import sire.legacy.Mol as _Mol from somd2 import _logger -from somd2._utils import _delta_sym +from somd2._utils import _delta_sym, _degree_sym def _auto_flip_angle(mol, anchor_idx, pivot_idx, ring_neighbor_idxs): @@ -258,14 +258,14 @@ def detect_terminal_groups(system, flip_angle=None): if group_angle is None: _logger.warning( f"Terminal group at pivot atom {j} has no recognised " - f"rotational symmetry (raw angle = {raw:.1f}°). " + f"rotational symmetry (raw angle = {raw:.1f}{_degree_sym}). " "Skipping group." ) continue _logger.debug( f"Terminal group at pivot atom {j}: auto-detected flip " - f"angle = {group_angle}° (raw = {raw:.1f}°)." + f"angle = {group_angle}{_degree_sym} (raw = {raw:.1f}{_degree_sym})." ) # Map molecule-local indices to absolute system indices. From ea0ba1299757c708862519487bb3d461e09e084d Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 31 Mar 2026 15:35:11 +0100 Subject: [PATCH 03/14] Optionally randomise velocities after flip. --- README.md | 29 +++++++++++++++++++++++++++++ src/somd2/config/_config.py | 3 ++- src/somd2/runner/_repex.py | 6 +++++- src/somd2/runner/_runner.py | 30 +++++++++++++++++++++++++----- src/somd2/runner/_terminal_flip.py | 11 ++++++++++- 5 files changed, 71 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index bd212329..3d63603f 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,35 @@ require a different `nvcc` to that provided by conda, you can set the Depending on your setup, you may also need to install the `cuda-nvvm` package from `conda-forge`. +## Terminal ring flip Monte Carlo + +SOMD2 supports terminal ring flip Monte Carlo (MC) moves to improve sampling +of terminal aromatic rings in perturbable ligands, as described in +[this paper](https://chemrxiv.org/doi/full/10.26434/chemrxiv-2025-2zkx5). +Each move attempts a discrete rotation of a terminal ring around the bond +connecting it to the rest of the molecule, accepted or rejected via the +Metropolis criterion. Terminal ring groups are detected automatically from +the molecular connectivity of perturbable molecules. + +To enable terminal flip MC, set the frequency at which moves are attempted: + +``` +somd2 perturbable_system.bss --terminal-flip-frequency "1 ps" +``` + +The flip angle for each group is determined automatically from the ring +geometry. To override this for all groups: + +``` +somd2 perturbable_system.bss --terminal-flip-frequency "1 ps" --terminal-flip-angle "180 degrees" +``` + +To see all terminal flip related options, run: + +``` +somd2 --help | grep -A2 ' --terminal-flip' +``` + ## Analysis Simulation output will be written to the directory specified using the diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index ba1dba26..8070d20a 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -372,7 +372,8 @@ def __init__( GPU resources are available. randomise_velocities: bool - Whether to randomise velocities at the start of each replica exchange cycle. + Whether to randomise velocities at the start of each replica exchange cycle + or following a terminal flip Monte Carlo move. perturbed_system: str The path to a stream file containing a Sire system for the equilibrated perturbed diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 6e60b34e..4ed30f1f 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -1315,7 +1315,11 @@ def _run_block( # Perform a terminal flip move before dynamics if requested. if self._terminal_flip_sampler is not None and is_terminal_flip: _logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}") - self._terminal_flip_sampler.move(dynamics.context()) + if ( + self._terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() _logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}") diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 314e36fb..df2082ea 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -769,7 +769,11 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - terminal_flip_sampler.move(dynamics.context()) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() # Write ghost residues immediately after the GCMC # move if a frame will be saved in the upcoming @@ -826,7 +830,11 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - terminal_flip_sampler.move(dynamics.context()) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() dynamics.run( checkpoint_frequency, @@ -973,7 +981,11 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - terminal_flip_sampler.move(dynamics.context()) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() dynamics.run( rem, @@ -1072,7 +1084,11 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - terminal_flip_sampler.move(dynamics.context()) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() # Write ghost residues immediately after the GCMC # move if a frame will be saved in the upcoming @@ -1114,7 +1130,11 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - terminal_flip_sampler.move(dynamics.context()) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() dynamics.run( time, diff --git a/src/somd2/runner/_terminal_flip.py b/src/somd2/runner/_terminal_flip.py index cede8b54..001ce094 100644 --- a/src/somd2/runner/_terminal_flip.py +++ b/src/somd2/runner/_terminal_flip.py @@ -429,11 +429,18 @@ def move(self, context): context : openmm.Context The active OpenMM context. + + Returns + ------- + + bool + True if the move was accepted, False otherwise. Returns False + immediately if there are no terminal groups. """ from openmm import unit as _omm_unit if not self._terminal_groups: - return + return False self._num_attempted += 1 @@ -468,6 +475,7 @@ def move(self, context): f"{_delta_sym} = {e_new - e_old:.2f} kJ/mol, " f"acc = {min(1.0, _np.exp(-delta_e)):.3f})" ) + return True else: context.setPositions(old_positions * _omm_unit.nanometer) _logger.debug( @@ -475,6 +483,7 @@ def move(self, context): f"{_delta_sym} = {e_new - e_old:.2f} kJ/mol, " f"acc = {_np.exp(-delta_e):.3f})" ) + return False @property def num_attempted(self): From 5b9c319300f03b26ecba0b2e78d4bd596b1211a5 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 31 Mar 2026 15:41:23 +0100 Subject: [PATCH 04/14] Remove redundant section. --- README.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/README.md b/README.md index 3d63603f..3c0c8ef4 100644 --- a/README.md +++ b/README.md @@ -242,12 +242,6 @@ geometry. To override this for all groups: somd2 perturbable_system.bss --terminal-flip-frequency "1 ps" --terminal-flip-angle "180 degrees" ``` -To see all terminal flip related options, run: - -``` -somd2 --help | grep -A2 ' --terminal-flip' -``` - ## Analysis Simulation output will be written to the directory specified using the From 70af0fd52f48d414e7ad626f76f35dec280e8237 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 09:18:22 +0100 Subject: [PATCH 05/14] Avoid duplicate velocity randomisation. --- src/somd2/runner/_repex.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 4ed30f1f..6e60b34e 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -1315,11 +1315,7 @@ def _run_block( # Perform a terminal flip move before dynamics if requested. if self._terminal_flip_sampler is not None and is_terminal_flip: _logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}") - if ( - self._terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() + self._terminal_flip_sampler.move(dynamics.context()) _logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}") From c5839c9343fe01ada10ae0ab9a66862268fb0248 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 09:24:19 +0100 Subject: [PATCH 06/14] Refactor into an internal _samplers sub-package. --- src/somd2/runner/_base.py | 2 +- src/somd2/runner/_repex.py | 2 +- src/somd2/runner/_runner.py | 2 +- src/somd2/runner/_samplers/__init__.py | 23 +++++++++++++++++++ .../runner/{ => _samplers}/_terminal_flip.py | 0 tests/runner/test_terminal_flip.py | 2 +- 6 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 src/somd2/runner/_samplers/__init__.py rename src/somd2/runner/{ => _samplers}/_terminal_flip.py (100%) diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 18e67407..81ae61a5 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -680,7 +680,7 @@ def __init__(self, system, config): raise ValueError(msg) # Auto-detect terminal ring groups using Sire connectivity. - from ._terminal_flip import detect_terminal_groups + from ._samplers import detect_terminal_groups if isinstance(self._system, list): mols = self._system[0] diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 6e60b34e..a015190d 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -856,7 +856,7 @@ def __init__(self, system, config): # Create the terminal flip sampler (if terminal groups were detected). if self._terminal_groups: - from ._terminal_flip import TerminalFlipSampler + from ._samplers import TerminalFlipSampler self._terminal_flip_sampler = TerminalFlipSampler( self._terminal_groups, diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index df2082ea..fe34a011 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -467,7 +467,7 @@ def generate_lam_vals(lambda_base, increment=0.001): # Create the terminal flip sampler (if terminal groups were detected). if self._terminal_groups: - from ._terminal_flip import TerminalFlipSampler + from ._samplers import TerminalFlipSampler terminal_flip_sampler = TerminalFlipSampler( self._terminal_groups, diff --git a/src/somd2/runner/_samplers/__init__.py b/src/somd2/runner/_samplers/__init__.py new file mode 100644 index 00000000..5397014f --- /dev/null +++ b/src/somd2/runner/_samplers/__init__.py @@ -0,0 +1,23 @@ +###################################################################### +# SOMD2: GPU accelerated alchemical free-energy engine. +# +# Copyright: 2023-2026 +# +# Authors: The OpenBioSim Team +# +# SOMD2 is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SOMD2 is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with SOMD2. If not, see . +##################################################################### + +from ._terminal_flip import TerminalFlipSampler as TerminalFlipSampler +from ._terminal_flip import detect_terminal_groups as detect_terminal_groups diff --git a/src/somd2/runner/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py similarity index 100% rename from src/somd2/runner/_terminal_flip.py rename to src/somd2/runner/_samplers/_terminal_flip.py diff --git a/tests/runner/test_terminal_flip.py b/tests/runner/test_terminal_flip.py index ea6dea9b..79011214 100644 --- a/tests/runner/test_terminal_flip.py +++ b/tests/runner/test_terminal_flip.py @@ -22,7 +22,7 @@ from somd2.config import Config from somd2.runner import Runner -from somd2.runner._terminal_flip import TerminalFlipSampler, detect_terminal_groups +from somd2.runner._samplers import TerminalFlipSampler, detect_terminal_groups # --------------------------------------------------------------------------- # detect_terminal_groups From 0a366f2036f196e7095eaf5253ec428cc190ff72 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 09:49:23 +0100 Subject: [PATCH 07/14] Sample all discrete flip states, not just nearest neighbours. --- src/somd2/runner/_samplers/_terminal_flip.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/somd2/runner/_samplers/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py index 001ce094..9bfbcf5a 100644 --- a/src/somd2/runner/_samplers/_terminal_flip.py +++ b/src/somd2/runner/_samplers/_terminal_flip.py @@ -455,9 +455,12 @@ def move(self, context): ) e_old = state.getPotentialEnergy().value_in_unit(_omm_unit.kilojoule_per_mole) - # Random sign gives a symmetric proposal (detailed balance). - signed_angle = float(_np.random.choice([-1, 1])) * angle - self._rotate(context, group_idx, signed_angle) + # Pick uniformly from the n-1 non-current states, where n = 360 / angle. + # For 180° (n=2) this is equivalent to a random sign; for higher + # symmetry orders it correctly samples any non-current state in one move. + n = round(360.0 / angle) + step = int(_np.random.randint(1, n)) + self._rotate(context, group_idx, step * angle) # Evaluate the energy of the proposed configuration. e_new = ( From 7f3e67f173607b43a4defb173cedc20625359f4d Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 09:53:36 +0100 Subject: [PATCH 08/14] Add reference to original paper. --- src/somd2/runner/_samplers/_terminal_flip.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/somd2/runner/_samplers/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py index 9bfbcf5a..e5f226a7 100644 --- a/src/somd2/runner/_samplers/_terminal_flip.py +++ b/src/somd2/runner/_samplers/_terminal_flip.py @@ -22,6 +22,9 @@ # Adapted from the terminal ring flip MC implemenation in GrandFEP: # https://github.com/deGrootLab/GrandFEP # (Released under the MIT License.) +# +# Original method: Wang et al., ChemRxiv, 2025. +# https://doi.org/10.26434/chemrxiv-2025-2zkx5 __all__ = ["TerminalFlipSampler", "detect_terminal_groups"] From d457482bfde129a895362cb02cd4b19add569db6 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 10:27:52 +0100 Subject: [PATCH 09/14] Allow functions to be used standalone. --- src/somd2/runner/_samplers/_terminal_flip.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/somd2/runner/_samplers/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py index e5f226a7..3bb97e29 100644 --- a/src/somd2/runner/_samplers/_terminal_flip.py +++ b/src/somd2/runner/_samplers/_terminal_flip.py @@ -173,7 +173,11 @@ def detect_terminal_groups(system, flip_angle=None): # All atoms in the system, used to obtain absolute (OpenMM) atom indices. all_atoms = system.atoms() + import sire.morph as _morph + for mol in pert_mols: + mol = _morph.link_to_reference(mol) + try: connectivity = mol.property("connectivity") except Exception: From d1a0e9e41d6069c8aa3372d2fd16f538600b328e Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 11:39:32 +0100 Subject: [PATCH 10/14] Simplify logger message. --- src/somd2/runner/_repex.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index a015190d..c400733e 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -863,8 +863,7 @@ def __init__(self, system, config): float(self._config.temperature.value()), ) _logger.info( - f"Terminal flip sampler ready for replica exchange " - f"({len(self._terminal_groups)} group(s))" + f"Terminal flip sampler ready ({len(self._terminal_groups)} group(s))" ) else: self._terminal_flip_sampler = None From 3745398ea4a3e707abb7d2ff37e8d7f8433c15eb Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 11:39:54 +0100 Subject: [PATCH 11/14] Fall back to hybridisation check when geometric check fails. --- src/somd2/runner/_samplers/_terminal_flip.py | 40 +++++++++++++++++--- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/src/somd2/runner/_samplers/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py index 3bb97e29..f184cd92 100644 --- a/src/somd2/runner/_samplers/_terminal_flip.py +++ b/src/somd2/runner/_samplers/_terminal_flip.py @@ -202,6 +202,7 @@ def detect_terminal_groups(system, flip_angle=None): num_atoms = mol.num_atoms() seen_bonds = set() + rdmol = None # lazily initialised if geometric detection fails for i in range(num_atoms): atom_i_idx = _Mol.AtomIdx(i) @@ -263,12 +264,39 @@ def detect_terminal_groups(system, flip_angle=None): group_angle = _round_to_symmetry_angle(raw) if group_angle is None: - _logger.warning( - f"Terminal group at pivot atom {j} has no recognised " - f"rotational symmetry (raw angle = {raw:.1f}{_degree_sym}). " - "Skipping group." - ) - continue + # Geometric detection failed; fall back to hybridization. + try: + if rdmol is None: + from sire.convert import to_rdkit as _to_rdkit + from rdkit.Chem import HybridizationType as _HybType + + rdmol = _to_rdkit(mol) + hyb = rdmol.GetAtomWithIdx(j).GetHybridization() + if hyb == _HybType.SP2: + group_angle = 180.0 + elif hyb == _HybType.SP3: + group_angle = 120.0 + else: + _logger.warning( + f"Terminal group at pivot atom {j}: geometric " + f"detection gave unrecognised angle " + f"({raw:.1f}{_degree_sym}) and hybridization " + f"({hyb}) has no defined flip angle. Skipping." + ) + continue + _logger.warning( + f"Terminal group at pivot atom {j}: geometric " + f"detection gave unrecognised angle " + f"({raw:.1f}{_degree_sym}), using hybridization-based " + f"angle {group_angle}{_degree_sym} (pivot is {hyb.name})." + ) + except Exception as e: + _logger.warning( + f"Terminal group at pivot atom {j} has no recognised " + f"rotational symmetry (raw angle = {raw:.1f}{_degree_sym}) " + f"and hybridization fallback failed: {e}. Skipping." + ) + continue _logger.debug( f"Terminal group at pivot atom {j}: auto-detected flip " From 88a76f13907fb1b2d98390bd7edaf8ce37c59d51 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 11:49:38 +0100 Subject: [PATCH 12/14] Use a unique TerminalFlipSampler per replica. --- src/somd2/runner/_repex.py | 41 ++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index c400733e..b8a8392e 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -854,19 +854,22 @@ def __init__(self, system, config): else: self._start_block = 0 - # Create the terminal flip sampler (if terminal groups were detected). + # Create a terminal flip sampler per replica (if terminal groups were detected). if self._terminal_groups: from ._samplers import TerminalFlipSampler - self._terminal_flip_sampler = TerminalFlipSampler( - self._terminal_groups, - float(self._config.temperature.value()), - ) + self._terminal_flip_samplers = [ + TerminalFlipSampler( + self._terminal_groups, + float(self._config.temperature.value()), + ) + for _ in self._lambda_values + ] _logger.info( - f"Terminal flip sampler ready ({len(self._terminal_groups)} group(s))" + f"Terminal flip samplers ready ({len(self._terminal_groups)} group(s))" ) else: - self._terminal_flip_sampler = None + self._terminal_flip_samplers = None from threading import Lock @@ -1018,7 +1021,7 @@ def run(self): # Work out the number of cycles per terminal flip move. if ( self._config.terminal_flip_frequency is not None - and self._terminal_flip_sampler is not None + and self._terminal_flip_samplers is not None ): cycles_per_flip = max( 1, @@ -1163,15 +1166,6 @@ def run(self): ) self._dynamics_cache.mix_states() - # Log terminal flip acceptance rate at each cycle. - if self._terminal_flip_sampler is not None: - _logger.info( - f"Terminal flip acceptance rate: " - f"{self._terminal_flip_sampler.acceptance_rate:.3f} " - f"({self._terminal_flip_sampler.num_accepted}/" - f"{self._terminal_flip_sampler.num_attempted})" - ) - # This is a checkpoint cycle. if is_checkpoint: # Update the block number. @@ -1312,9 +1306,9 @@ def _run_block( gcmc_sampler.write_ghost_residues() # Perform a terminal flip move before dynamics if requested. - if self._terminal_flip_sampler is not None and is_terminal_flip: + if self._terminal_flip_samplers is not None and is_terminal_flip: _logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}") - self._terminal_flip_sampler.move(dynamics.context()) + self._terminal_flip_samplers[index].move(dynamics.context()) _logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}") @@ -1770,6 +1764,15 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): # Remove the PyCUDA context from the stack. gcmc_sampler.pop() + # Log terminal flip acceptance rate for this replica. + if self._terminal_flip_samplers is not None: + sampler = self._terminal_flip_samplers[index] + _logger.info( + f"Terminal flip acceptance rate at {_lam_sym} = {lam:.5f}: " + f"{sampler.acceptance_rate:.3f} " + f"({sampler.num_accepted}/{sampler.num_attempted})" + ) + if is_final_block: _logger.success(f"{_lam_sym} = {lam:.5f} complete") From 795f3773a9dbb3adb32d3221acc6884933640384 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 12:03:38 +0100 Subject: [PATCH 13/14] Add serialisation for sampler statistics. --- src/somd2/runner/_base.py | 1 + src/somd2/runner/_repex.py | 42 ++++++++++- src/somd2/runner/_runner.py | 76 ++++++++++++++++++++ src/somd2/runner/_samplers/_terminal_flip.py | 16 +++++ 4 files changed, 134 insertions(+), 1 deletion(-) diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 81ae61a5..2341336e 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -1197,6 +1197,7 @@ def increment_filename(base_filename, suffix): output_directory / f"energy_components_{lam}.txt" ) filenames["gcmc_ghosts"] = str(output_directory / f"gcmc_ghosts_{lam}.txt") + filenames["sampler_stats"] = str(output_directory / f"sampler_stats_{lam}.pkl") if restart: filenames["config"] = str( output_directory / increment_filename("config", "yaml") diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index b8a8392e..931d26bb 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -106,6 +106,8 @@ def __init__( self._openmm_states = [None] * len(lambdas) self._gcmc_samplers = [None] * len(lambdas) self._gcmc_states = [None] * len(lambdas) + self._gcmc_stats = [None] * len(lambdas) + self._terminal_flip_stats = [[0, 0]] * len(lambdas) self._num_proposed = _np.matrix(_np.zeros((len(lambdas), len(lambdas)))) self._num_accepted = _np.matrix(_np.zeros((len(lambdas), len(lambdas)))) self._num_swaps = _np.matrix(_np.zeros((len(lambdas), len(lambdas)))) @@ -130,6 +132,14 @@ def __setstate__(self, state): for key, value in state.items(): setattr(self, key, value) + # Provide defaults for attributes added after the initial release, + # so that old checkpoint files can still be loaded. + n = len(self._lambdas) + if not hasattr(self, "_gcmc_stats"): + self._gcmc_stats = [None] * n + if not hasattr(self, "_terminal_flip_stats"): + self._terminal_flip_stats = [[0, 0]] * n + def __getstate__(self): """ Get the state of the object. @@ -145,6 +155,8 @@ def __getstate__(self): # Don't pickle the GCMC samplers since they need to be recreated. "_gcmc_samplers": len(self._gcmc_samplers) * [None], "_gcmc_states": self._gcmc_states, + "_gcmc_stats": self._gcmc_stats, + "_terminal_flip_stats": self._terminal_flip_stats, "_num_proposed": self._num_proposed, "_num_accepted": self._num_accepted, "_num_swaps": self._num_swaps, @@ -823,7 +835,7 @@ def __init__(self, system, config): state = self._dynamics_cache._states[i] dynamics.context().setState(self._dynamics_cache._openmm_states[state]) - # Reset the GCMC water state. + # Reset the GCMC water state and restore statistics. if gcmc_sampler is not None: gcmc_sampler.push() try: @@ -834,6 +846,13 @@ def __init__(self, system, config): ) finally: gcmc_sampler.pop() + if self._dynamics_cache._gcmc_stats[i] is not None: + gcmc_sampler.restore_stats(self._dynamics_cache._gcmc_stats[i]) + + # Restore terminal flip sampler statistics. + if self._terminal_flip_samplers is not None: + attempted, accepted = self._dynamics_cache._terminal_flip_stats[i] + self._terminal_flip_samplers[i].reset(attempted, accepted) # Conversion factor for reduced potential. kT = (_sr.units.k_boltz * self._config.temperature).to(_sr.units.kcal_per_mol) @@ -1190,6 +1209,7 @@ def run(self): # Pickle the dynamics cache. _logger.info("Saving replica exchange state") + self._save_sampler_stats() with open(self._repex_state, "wb") as f: _pickle.dump(self._dynamics_cache, f) @@ -1211,6 +1231,11 @@ def run(self): # Pickle final state of the dynamics cache. _logger.info("Saving final replica exchange state") + if self._terminal_flip_samplers is not None: + self._dynamics_cache._terminal_flip_stats = [ + [s.num_attempted, s.num_accepted] + for s in self._terminal_flip_samplers + ] with open(self._repex_state, "wb") as f: _pickle.dump(self._dynamics_cache, f) @@ -1842,6 +1867,21 @@ def _mix_replicas(num_replicas, energy_matrix, proposed, accepted): return states + def _save_sampler_stats(self): + """ + Save GCMC and terminal flip sampler statistics to the dynamics cache + prior to pickling. + """ + for i in range(len(self._lambda_values)): + _, gcmc_sampler = self._dynamics_cache.get(i) + if gcmc_sampler is not None: + self._dynamics_cache._gcmc_stats[i] = gcmc_sampler.get_stats() + + if self._terminal_flip_samplers is not None: + self._dynamics_cache._terminal_flip_stats = [ + [s.num_attempted, s.num_accepted] for s in self._terminal_flip_samplers + ] + def _save_transition_matrix(self): """ Internal method to save the replica exchange transition matrix. diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index fe34a011..54345e43 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -695,6 +695,16 @@ def generate_lam_vals(lambda_base, increment=0.001): finally: gcmc_sampler.pop() + # Restore sampler statistics from a previous run. + if self._is_restart: + stats = self._load_sampler_stats(index) + if stats is not None: + if gcmc_sampler is not None and "gcmc" in stats: + gcmc_sampler.restore_stats(stats["gcmc"]) + if terminal_flip_sampler is not None and "terminal_flip" in stats: + attempted, accepted = stats["terminal_flip"] + terminal_flip_sampler.reset(attempted, accepted) + # Set the number of neighbours used for the energy calculation. # If not None, then we add one to account for the extra windows # used for finite-difference gradient analysis. @@ -924,6 +934,11 @@ def generate_lam_vals(lambda_base, increment=0.001): if error is not None: raise error + # Save sampler statistics alongside the checkpoint. + self._save_sampler_stats( + index, gcmc_sampler, terminal_flip_sampler + ) + # Delete all trajectory frames from the Sire system within the # dynamics object. dynamics._d._sire_mols.delete_all_frames() @@ -1213,12 +1228,73 @@ def generate_lam_vals(lambda_base, increment=0.001): _logger.error(msg) raise RuntimeError(msg) + # Save sampler statistics alongside the final checkpoint. + self._save_sampler_stats(index, gcmc_sampler, terminal_flip_sampler) + _logger.success( f"{_lam_sym} = {lambda_value:.5f} complete, speed = {speed:.2f} ns day-1" ) return time + def _save_sampler_stats(self, index, gcmc_sampler, terminal_flip_sampler): + """ + Save GCMC and terminal flip sampler statistics to a pickle file. + + Parameters + ---------- + + index : int + The index of the lambda value. + + gcmc_sampler : GCMCSampler or None + The GCMC sampler for this replica. + + terminal_flip_sampler : TerminalFlipSampler or None + The terminal flip sampler for this replica. + """ + import pickle as _pickle + + stats = {} + if gcmc_sampler is not None: + stats["gcmc"] = gcmc_sampler.get_stats() + if terminal_flip_sampler is not None: + stats["terminal_flip"] = [ + terminal_flip_sampler.num_attempted, + terminal_flip_sampler.num_accepted, + ] + with open(self._filenames[index]["sampler_stats"], "wb") as f: + _pickle.dump(stats, f) + + def _load_sampler_stats(self, index): + """ + Load sampler statistics from a pickle file. + + Parameters + ---------- + + index : int + The index of the lambda value. + + Returns + ------- + + dict or None + The sampler statistics, or None if the file does not exist. + """ + import pickle as _pickle + from pathlib import Path as _Path + + path = _Path(self._filenames[index]["sampler_stats"]) + if not path.exists(): + return None + try: + with open(path, "rb") as f: + return _pickle.load(f) + except Exception as e: + _logger.warning(f"Could not load sampler stats for index {index}: {e}") + return None + def _minimisation( self, system, diff --git a/src/somd2/runner/_samplers/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py index f184cd92..9f3a6d69 100644 --- a/src/somd2/runner/_samplers/_terminal_flip.py +++ b/src/somd2/runner/_samplers/_terminal_flip.py @@ -539,3 +539,19 @@ def acceptance_rate(self): if self._num_attempted == 0: return 0.0 return self._num_accepted / self._num_attempted + + def reset(self, num_attempted=0, num_accepted=0): + """ + Reset the move counters. + + Parameters + ---------- + + num_attempted : int + Value to restore ``num_attempted`` to. Defaults to 0. + + num_accepted : int + Value to restore ``num_accepted`` to. Defaults to 0. + """ + self._num_attempted = num_attempted + self._num_accepted = num_accepted From 3f281a7d712f7bcabd38a3b5e8b41ddadbb8e584 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 1 Apr 2026 12:40:09 +0100 Subject: [PATCH 14/14] Add option to limit size of ring flip region. --- src/somd2/config/_config.py | 27 ++++++++++++++++++++ src/somd2/runner/_base.py | 6 ++++- src/somd2/runner/_samplers/_terminal_flip.py | 15 ++++++++++- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 8070d20a..41bb8d47 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -141,6 +141,7 @@ def __init__( perturbed_system=None, terminal_flip_frequency=None, terminal_flip_angle=None, + terminal_flip_max_mobile_atoms=None, gcmc=False, gcmc_frequency=None, gcmc_selection=None, @@ -391,6 +392,11 @@ def __init__( ``"180 degrees"``. If None (the default), the angle is determined automatically for each group from its geometry. + terminal_flip_max_mobile_atoms: int or None + Maximum number of mobile atoms allowed in a terminal ring group. + Groups with more mobile atoms than this threshold are skipped during + detection. Defaults to None (no limit). + gcmc: bool Whether to perform Grand Canonical Monte Carlo (GCMC) water insertions/deletions. @@ -575,6 +581,7 @@ def __init__( self.perturbed_system = perturbed_system self.terminal_flip_frequency = terminal_flip_frequency self.terminal_flip_angle = terminal_flip_angle + self.terminal_flip_max_mobile_atoms = terminal_flip_max_mobile_atoms self.gcmc = gcmc self.gcmc_frequency = gcmc_frequency self.gcmc_selection = gcmc_selection @@ -2064,6 +2071,26 @@ def terminal_flip_angle(self, terminal_flip_angle): else: self._terminal_flip_angle = None + @property + def terminal_flip_max_mobile_atoms(self): + return self._terminal_flip_max_mobile_atoms + + @terminal_flip_max_mobile_atoms.setter + def terminal_flip_max_mobile_atoms(self, terminal_flip_max_mobile_atoms): + if terminal_flip_max_mobile_atoms is not None: + if not isinstance(terminal_flip_max_mobile_atoms, int): + try: + terminal_flip_max_mobile_atoms = int(terminal_flip_max_mobile_atoms) + except: + raise ValueError( + "'terminal_flip_max_mobile_atoms' must be of type 'int'" + ) + if terminal_flip_max_mobile_atoms < 1: + raise ValueError( + "'terminal_flip_max_mobile_atoms' must be greater than 0" + ) + self._terminal_flip_max_mobile_atoms = terminal_flip_max_mobile_atoms + @property def gcmc(self): return self._gcmc diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 2341336e..00d4aa76 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -692,7 +692,11 @@ def __init__(self, system, config): if self._config.terminal_flip_angle is not None else None ) - self._terminal_groups = detect_terminal_groups(mols, flip_angle=flip_angle) + self._terminal_groups = detect_terminal_groups( + mols, + flip_angle=flip_angle, + max_mobile_atoms=self._config.terminal_flip_max_mobile_atoms, + ) if not self._terminal_groups: _logger.warning( diff --git a/src/somd2/runner/_samplers/_terminal_flip.py b/src/somd2/runner/_samplers/_terminal_flip.py index 9f3a6d69..971a85ac 100644 --- a/src/somd2/runner/_samplers/_terminal_flip.py +++ b/src/somd2/runner/_samplers/_terminal_flip.py @@ -125,7 +125,7 @@ def _round_to_symmetry_angle(raw_angle, tolerance=10.0): return symmetry_angles[min_idx] -def detect_terminal_groups(system, flip_angle=None): +def detect_terminal_groups(system, flip_angle=None, max_mobile_atoms=None): """ Detect terminal ring groups in perturbable molecules using Sire's native connectivity. @@ -151,6 +151,11 @@ def detect_terminal_groups(system, flip_angle=None): a float is given it overrides the geometric measurement for all groups. + max_mobile_atoms : int or None + Maximum number of mobile atoms allowed in a terminal ring group. + Groups with more mobile atoms than this threshold are skipped. + Defaults to None (no limit). + Returns ------- @@ -238,6 +243,14 @@ def detect_terminal_groups(system, flip_angle=None): if not mobile: continue + # Skip groups with too many mobile atoms. + if max_mobile_atoms is not None and len(mobile) > max_mobile_atoms: + _logger.warning( + f"Terminal group at pivot atom {j} has {len(mobile)} mobile " + f"atoms (max_mobile_atoms={max_mobile_atoms}). Skipping group." + ) + continue + # Determine the flip angle for this group. if flip_angle is not None: group_angle = flip_angle