"""File: mmmd.py
Description:
This module provides classes and functions for setting up, running, and
analyzing molecular dynamics (MD) simulations using GROMACS. The main
classes include:
- Mm: Provides methods to prepare simulation files, process PDB
files, run GROMACS commands, and perform various analyses on MD data.
- MmRun: A subclass of GmxSystem dedicated to executing MD simulations and
performing post-processing tasks (e.g., RMSF, RMSD, covariance analysis).
Usage:
Import this module and instantiate the MmSystem or MmRun classes to set up
and run your MD simulations.
Author: DY
"""
import logging
import os
import sys
import shutil
import numpy as np
import MDAnalysis as mda
from MDAnalysis.lib.util import get_ext
from MDAnalysis.lib.mdamath import triclinic_box
import openmm as mm
from openmm import app, unit
from pdbfixer.pdbfixer import PDBFixer
from reforge import cli, io
from reforge.utils import cd, clean_dir, timeit, memprofit
from reforge.mdsystem.mdsystem import MDSystem, MDRun
logger = logging.getLogger(__name__)
[docs]
class MmSystem(MDSystem):
"""Subclass for OpenMM"""
def __init__(self, sysdir, sysname, **kwargs):
"""Initialize the MD system with required directories and file paths."""
super().__init__(sysdir, sysname)
[docs]
def prepare_files(self, *args, **kwargs):
"""Extension for OpenMM system"""
super().prepare_files(*args, **kwargs)
[docs]
def clean_pdb(self, pdb_file, add_missing_atoms=False, add_hydrogens=False, pH=7.0, **kwargs):
"""Clean the starting PDB file using PDBfixer by OpenMM.
Parameters
----------
pdb_file : str
Path to the input PDB file.
add_missing_atoms : bool, optional
Whether to add missing atoms (default: False).
add_hydrogens : bool, optional
Whether to add missing hydrogens (default: False).
pH : float, optional
pH value for adding hydrogens (default: 7.0).
**kwargs : dict, optional
Additional keyword arguments (ignored).
"""
logger.info("Cleaning the PDB")
logger.info(f"Processing {pdb_file}")
pdb = PDBFixer(filename=str(pdb_file))
logger.info("Removing heterogens and checking for missing residues...")
pdb.removeHeterogens(False)
pdb.findMissingResidues()
logger.info("Replacing non-standard residues...")
pdb.findNonstandardResidues()
pdb.replaceNonstandardResidues()
if add_missing_atoms:
logger.info("Adding missing atoms...")
pdb.findMissingAtoms()
pdb.addMissingAtoms()
if add_hydrogens:
logger.info("Adding missing hydrogens...")
pdb.addMissingHydrogens(pH)
topology = pdb.topology
positions = pdb.positions
with open(self.inpdb, "w", encoding="utf-8") as outfile:
app.PDBFile.writeFile(topology, positions, outfile)
logger.info(f"Written cleaned PDB to {self.inpdb}")
################################################################################
# MmRun class
################################################################################
[docs]
class MmRun(MDRun):
def __init__(self, sysdir, sysname, runname):
"""Initializes the MD run environment with additional directories for analysis.
Parameters
----------
sysdir (str): Base directory for the system.
sysname (str): Name of the MD system.
runname (str): Name for the MD run.
"""
super().__init__(sysdir, sysname, runname)
self.sysxml = self.root / "system.xml"
self.systop = self.root / "system.top"
self.sysndx = self.root / "system.ndx"
self.mdpdir = self.root / "mdp"
self.str = self.rundir / "mdc.pdb" # Structure file
self.trj = self.rundir / "mdc.trr" # Trajectory file
self.trj = self.trj if self.trj.exists() else self.rundir / "mdc.xtc"
[docs]
def get_std_reporters(self, append, prefix='md', nlog=10000, nchk=10000, **kwargs):
kwargs.setdefault("step", True)
kwargs.setdefault("time", True)
kwargs.setdefault("elapsedTime", True)
kwargs.setdefault("potentialEnergy", True)
kwargs.setdefault("temperature", True)
kwargs.setdefault("density", False)
log_file = os.path.join(self.rundir, f"{prefix}.log")
xml_file = os.path.join(self.rundir, f"{prefix}.xml")
log_reporter = app.StateDataReporter(log_file, nlog, append=append, **kwargs)
stderr_reporter = app.StateDataReporter(sys.stderr, nlog, append=False, **kwargs)
xml_reporter = app.CheckpointReporter(xml_file, nchk, writeState=True)
reporters = [xml_reporter, log_reporter, stderr_reporter]
return reporters
[docs]
@memprofit(level=logging.INFO)
@timeit(level=logging.INFO, unit='auto')
def em(self, simulation, tolerance=10, max_iterations=1000):
"""Perform energy minimization for the simulation.
Parameters
----------
simulation : openmm.app.Simulation
The simulation object.
tolerance [kJ/nm/mol] : float, optional
RMSF force tolerance for energy minimization (default: 10).
max_iterations : int, optional
Maximum number of iterations (default: 1000).
Notes
-----
Minimizes the energy, saves the minimized state, and logs progress.
"""
logger.info("Minimizing energy...")
simulation.minimizeEnergy(tolerance=tolerance, maxIterations=max_iterations)
self.save_state(simulation, "em")
logger.info("Minimization completed.")
[docs]
@memprofit(level=logging.INFO)
@timeit(level=logging.INFO, unit='auto')
def hu(self, simulation, temperature,
n_cycles=100, steps_per_cycle=100, **kwargs):
"""Run equilibration.
Parameters
----------
simulation : openmm.app.Simulation
The simulation object.
nsteps : int, optional
Number of steps for equilibration (default: 10000).
**kwargs : dict, optional
Additional keyword arguments.
Notes
-----
Loads the minimized state, runs heatup, and saves the final state.
"""
logger.info("Heating up the system...")
in_xml = os.path.join(self.rundir, "em.xml")
simulation.loadState(in_xml)
for i in range(n_cycles):
simulation.integrator.setTemperature(temperature*i/n_cycles)
simulation.step(steps_per_cycle)
self.save_state(simulation, "hu")
logger.info("Heatup completed.")
[docs]
@memprofit(level=logging.INFO)
@timeit(level=logging.INFO, unit='auto')
def eq(self, simulation,
n_cycles=100, steps_per_cycle=100, **kwargs):
"""Run equilibration.
Parameters
----------
simulation : openmm.app.Simulation
The simulation object.
nsteps : int, optional
Number of steps for equilibration (default: 10000).
**kwargs : dict, optional
Additional keyword arguments.
Notes
-----
Loads the heated state, runs equilibration, and saves the equilibrated state.
"""
logger.info("Starting equilibration...")
in_xml = str(self.rundir / "hu.xml")
simulation.loadState(in_xml)
enum = enumerate(simulation.system.getForces())
idx, bb_restraint = [(idx, f) for idx, f in enum if f.getName() == 'BackboneRestraint'][0]
fc = bb_restraint.getGlobalParameterDefaultValue(0)
fcname = bb_restraint.getGlobalParameterName(0)
for i in range(n_cycles):
simulation.step(steps_per_cycle)
new_fc = fc * (1 - (i + 1) / n_cycles)
simulation.context.setParameter(fcname, new_fc)
# Remove the restraints and reinitialize context - we need to get rib of bb_fc
simulation.system.removeForce(idx)
state = simulation.context.getState(getPositions=True, getVelocities=True)
simulation.context.reinitialize(preserveState=False)
simulation.context.setPositions(state.getPositions())
simulation.context.setVelocities(state.getVelocities())
simulation.context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
simulation.saveState(str(self.rundir / "eq.xml"))
logger.info("Equilibration completed.")
[docs]
@memprofit(level=logging.INFO)
@timeit(level=logging.INFO, unit='auto')
def extend(self, simulation, curr_prefix, next_prefix, until_time=1000, nsteps=None, **kwargs):
"""Extend production MD simulation"""
logger.info("Extending run...")
in_xml = os.path.join(self.rundir, f"{curr_prefix}.xml")
simulation.loadState(in_xml)
state = simulation.context.getState()
curr_time = state.getTime()
if not nsteps:
dt = simulation.integrator.getStepSize()
logger.info(f"Current time: %s", curr_time)
logger.info(f"Extend until: %s", until_time)
nsteps = int((until_time - curr_time) / dt)
if nsteps <= 0:
logger.warning("Current simulation is longer than UNTIL_TIME, exiting!")
sys.exit(0)
logger.info(f"Number of steps left: %s", nsteps)
simulation.step(nsteps)
out_xml = self.rundir / f"{next_prefix}.xml"
simulation.saveState(str(out_xml))
logger.info("Production completed.")
[docs]
class MmReporter(object):
"""Most of this code is adapted from https://github.com/sef43/openmm-mdanalysis-reporter.
MDAReporter outputs a series of frames from a Simulation to any file format supported by MDAnalysis.
To use it, create a MDAReporter, then add it to the Simulation's list of reporters.
"""
def __init__(
self,
file,
reportInterval,
enforcePeriodicBox=None,
selection: str = None,
writer_kwargs: dict = None
):
"""Create a MDAReporter.
Parameters
----------
file : string
The file to write to
reportInterval : int
The interval (in time steps) at which to write frames
enforcePeriodicBox: bool
Specifies whether particle positions should be translated so the center of every molecule
lies in the same periodic box. If None (the default), it will automatically decide whether
to translate molecules based on whether the system being simulated uses periodic boundary
conditions.
selection : str
MDAnalysis selection string (https://docs.mdanalysis.org/stable/documentation_pages/selections.html)
which will be passed to MDAnalysis.Universe.select_atoms. If None (the default), all atoms will we selected.
writer_kwargs : dict
Additional keyword arguments to pass to the MDAnalysis.Writer object.
"""
self._reportInterval = reportInterval
self._enforcePeriodicBox = enforcePeriodicBox
self._filename = file
self._topology = None
self._nextModel = 0
self._mdaUniverse = None
self._mdaWriter = None
self._selection = selection
self._atomGroup = None
self._writer_kwargs = writer_kwargs or {}
[docs]
def describeNextReport(self, simulation):
"""Get information about the next report this object will generate.
Parameters
----------
simulation : Simulation
The Simulation to generate a report for
Returns
-------
tuple
A six element tuple. The first element is the number of steps
until the next report. The next four elements specify whether
that report will require positions, velocities, forces, and
energies respectively. The final element specifies whether
positions should be wrapped to lie in a single periodic box.
"""
steps = self._reportInterval - simulation.currentStep%self._reportInterval
root, ext = get_ext(self._filename)
if ext in ["trr"]:
positions, velocities, forces = True, True, True
else:
positions, velocities, forces = True, False, False
return steps, positions, velocities, forces, False, self._enforcePeriodicBox
[docs]
def report(self, simulation, state):
"""Generate a report.
Parameters
----------
simulation : Simulation
The Simulation to generate a report for
state : State
The current state of the simulation
"""
if self._nextModel == 0:
self._topology = simulation.topology
dt = simulation.integrator.getStepSize() * self._reportInterval # Time between frames in ps
self._mdaUniverse = mda.Universe(
simulation.topology,
simulation,
topology_format='OPENMMTOPOLOGY',
format='OPENMMSIMULATION',
dt=dt
)
if self._selection is not None:
self._atomGroup = self._mdaUniverse.select_atoms(self._selection)
else:
self._atomGroup = self._mdaUniverse.atoms
self._mdaWriter = mda.Writer(
self._filename,
n_atoms=len(self._atomGroup),
**self._writer_kwargs
)
self._nextModel += 1
# update the positions and velocities if present, convert from OpenMM nm to MDAnalysis angstroms
positions = state.getPositions(asNumpy=True).value_in_unit(unit.angstrom)
self._mdaUniverse.atoms.positions = positions
save_velocities = self.describeNextReport(simulation)[2]
if save_velocities:
velocities = state.getVelocities(asNumpy=True).value_in_unit(unit.angstrom/unit.picosecond)
self._mdaUniverse.atoms.velocities = velocities
# update box vectors
boxVectors = state.getPeriodicBoxVectors(asNumpy=True).value_in_unit(unit.angstrom)
self._mdaUniverse.dimensions = triclinic_box(*boxVectors)
self._mdaUniverse.dimensions[:3] = self._sanitize_box_angles(self._mdaUniverse.dimensions[:3])
# Set simulation time on the universe's trajectory timestep
sim_time = state.getTime().value_in_unit(unit.picosecond)
# Update the universe's timestep attributes
self._mdaUniverse.trajectory.ts.time = sim_time
self._mdaUniverse.trajectory.ts.frame = self._nextModel - 1
# write to the trajectory file
self._mdaWriter.write(self._atomGroup)
self._nextModel += 1
def __del__(self):
if self._mdaWriter:
self._mdaWriter.close()
@staticmethod
def _sanitize_box_angles(angles):
""" Ensure box angles correspond to first quadrant
See `discussion on unitcell angles <https://github.com/MDAnalysis/mdanalysis/pull/2917/files#r620558575>`_
"""
inverted = 180 - angles
return np.min(np.array([angles, inverted]), axis=0)
def _get_platform_info():
"""Report OpenMM platform and hardware information."""
info = {}
# Get number of available platforms and their names
num_platforms = mm.Platform.getNumPlatforms()
info['available_platforms'] = [mm.Platform.getPlatform(i).getName()
for i in range(num_platforms)]
# Try to get the fastest platform (usually CUDA or OpenCL)
platform = None
for platform_name in ['CUDA', 'OpenCL', 'CPU']:
try:
platform = mm.Platform.getPlatformByName(platform_name)
info['platform'] = platform_name
break
except Exception:
continue
if platform is None:
platform = mm.Platform.getPlatform(0)
info['platform'] = platform.getName()
# Get platform properties
info['properties'] = {}
try:
if info['platform'] in ['CUDA', 'OpenCL']:
info['properties']['device_index'] = platform.getPropertyDefaultValue('DeviceIndex')
info['properties']['precision'] = platform.getPropertyDefaultValue('Precision')
if info['platform'] == 'CUDA':
info['properties']['cuda_version'] = mm.version.cuda
info['properties']['gpu_name'] = platform.getPropertyValue(platform.createContext(), 'DeviceName')
info['properties']['cpu_threads'] = platform.getPropertyDefaultValue('Threads')
except Exception as e:
logger.warning(f"Could not get some platform properties: {str(e)}")
# Get OpenMM version
info['openmm_version'] = mm.version.full_version
# Log the information
logger.info("OpenMM Platform Information:")
logger.info(f"Available Platforms: {', '.join(info['available_platforms'])}")
logger.info(f"Selected Platform: {info['platform']}")
logger.info(f"OpenMM Version: {info['openmm_version']}")
logger.info("Platform Properties:")
for key, value in info['properties'].items():
logger.info(f" {key}: {value}")
return info