diff --git a/.gitignore b/.gitignore
index f2e103a..04a99aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -113,3 +113,4 @@ ENV/
# Local development
scratch
+CLAUDE.md
diff --git a/smee/_constants.py b/smee/_constants.py
index d04e395..ca96e79 100644
--- a/smee/_constants.py
+++ b/smee/_constants.py
@@ -40,6 +40,7 @@ class EnergyFn(_StrEnum):
"""An enumeration of the energy functions supported by ``smee`` out of the box."""
COULOMB = "coul"
+ POLARIZATION = "coul+pol"
VDW_LJ = "4*epsilon*((sigma/r)**12-(sigma/r)**6)"
VDW_DEXP = (
@@ -48,6 +49,7 @@ class EnergyFn(_StrEnum):
"alpha/(alpha-beta)*exp(beta*(1-r/r_min)))"
)
# VDW_BUCKINGHAM = "a*exp(-b*r)-c*r^-6"
+ VDW_DAMPEDEXP6810 = "force_at_zero*beta**-1*exp(-beta*(r-rho))-f_6(beta*r)*c6**6-f_8(beta*r)*c8**8-f_10(beta*r)*c10**10"
BOND_HARMONIC = "k/2*(r-length)**2"
diff --git a/smee/converters/openff/_openff.py b/smee/converters/openff/_openff.py
index 73ea3b4..1297f38 100644
--- a/smee/converters/openff/_openff.py
+++ b/smee/converters/openff/_openff.py
@@ -71,6 +71,26 @@ def _get_value(
) -> float:
"""Returns the value of a parameter in its default units"""
default_units = default_units[parameter]
+
+ # Handle missing parameters by using default values
+ if parameter not in potential.parameters:
+ if default_value is None:
+ # Set specific defaults for known multipole parameters
+ if parameter == "thole":
+ default_value = 0.39 * openff.units.unit.dimensionless
+ elif parameter == "dampingFactor":
+ # Will be computed from polarizability if not provided
+ return 0.0 # Placeholder, will be computed later
+ elif parameter.startswith("dipole"):
+ default_value = 0.0 * openff.units.unit.elementary_charge * openff.units.unit.angstrom
+ elif parameter.startswith("quadrupole"):
+ default_value = 0.0 * openff.units.unit.elementary_charge * openff.units.unit.angstrom**2
+ elif parameter in ["axisType", "multipoleAtomZ", "multipoleAtomX", "multipoleAtomY"]:
+ default_value = -1 * openff.units.unit.dimensionless # -1 indicates undefined
+ else:
+ raise KeyError(f"Parameter '{parameter}' not found and no default provided")
+ return default_value.m_as(default_units)
+
value = potential.parameters[parameter]
return (value if value is not None else default_value).m_as(default_units)
diff --git a/smee/converters/openff/nonbonded.py b/smee/converters/openff/nonbonded.py
index c695a46..267f3fa 100644
--- a/smee/converters/openff/nonbonded.py
+++ b/smee/converters/openff/nonbonded.py
@@ -217,6 +217,282 @@ def convert_dexp(
return potential, parameter_maps
+@smee.converters.smirnoff_parameter_converter(
+ "DampedExp6810",
+ {
+ "rho": _ANGSTROM,
+ "beta": _ANGSTROM**-1,
+ "c6": _KCAL_PER_MOL * _ANGSTROM**6,
+ "c8": _KCAL_PER_MOL * _ANGSTROM**8,
+ "c10": _KCAL_PER_MOL * _ANGSTROM**10,
+ "force_at_zero": _KCAL_PER_MOL * _ANGSTROM**-1,
+ "scale_12": _UNITLESS,
+ "scale_13": _UNITLESS,
+ "scale_14": _UNITLESS,
+ "scale_15": _UNITLESS,
+ "cutoff": _ANGSTROM,
+ "switch_width": _ANGSTROM,
+ },
+)
+def convert_dampedexp6810(
+ handlers: list[
+ "smirnoff_plugins.collections.nonbonded.SMIRNOFFDampedExp6810Collection"
+ ],
+ topologies: list[openff.toolkit.Topology],
+ v_site_maps: list[smee.VSiteMap | None],
+) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]:
+ import smee.potentials.nonbonded
+
+ (
+ potential,
+ parameter_maps,
+ ) = smee.converters.openff.nonbonded.convert_nonbonded_handlers(
+ handlers,
+ "DampedExp6810",
+ topologies,
+ v_site_maps,
+ ("rho", "beta", "c6", "c8", "c10"),
+ ("cutoff", "switch_width", "force_at_zero"),
+ )
+ potential.type = smee.PotentialType.VDW
+ potential.fn = smee.EnergyFn.VDW_DAMPEDEXP6810
+
+ return potential, parameter_maps
+
+
+@smee.converters.smirnoff_parameter_converter(
+ "Multipole",
+ {
+ # Molecular multipole moments
+ "dipoleX": _ELEMENTARY_CHARGE * _ANGSTROM,
+ "dipoleY": _ELEMENTARY_CHARGE * _ANGSTROM,
+ "dipoleZ": _ELEMENTARY_CHARGE * _ANGSTROM,
+ "quadrupoleXX": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleXY": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleXZ": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleYX": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleYY": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleYZ": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleZX": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleZY": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ "quadrupoleZZ": _ELEMENTARY_CHARGE * _ANGSTROM**2,
+ # Local frame definition
+ "axisType": _UNITLESS,
+ "multipoleAtomZ": _UNITLESS,
+ "multipoleAtomX": _UNITLESS,
+ "multipoleAtomY": _UNITLESS,
+ # Damping and polarizability (these may not be present in current force fields)
+ "thole": _UNITLESS,
+ "dampingFactor": _ANGSTROM,
+ "polarity": _ANGSTROM**3,
+ # Cutoff and scaling
+ "cutoff": _ANGSTROM,
+ "scale_12": _UNITLESS,
+ "scale_13": _UNITLESS,
+ "scale_14": _UNITLESS,
+ "scale_15": _UNITLESS,
+ },
+ depends_on=["Electrostatics"],
+)
+def convert_multipole(
+ handlers: list[
+ "smirnoff_plugins.collections.nonbonded.SMIRNOFFMultipoleCollection"
+ ],
+ topologies: list[openff.toolkit.Topology],
+ v_site_maps: list[smee.VSiteMap | None],
+ dependencies: dict[
+ str, tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]
+ ],
+) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]:
+
+ potential_chg, parameter_maps_chg = dependencies["Electrostatics"]
+
+ (
+ potential_pol,
+ parameter_maps_pol,
+ ) = smee.converters.openff.nonbonded.convert_nonbonded_handlers(
+ handlers,
+ "Multipole",
+ topologies,
+ v_site_maps,
+ (
+ "dipoleX", "dipoleY", "dipoleZ",
+ "quadrupoleXX", "quadrupoleXY", "quadrupoleXZ",
+ "quadrupoleYX", "quadrupoleYY", "quadrupoleYZ",
+ "quadrupoleZX", "quadrupoleZY", "quadrupoleZZ",
+ "axisType", "multipoleAtomZ", "multipoleAtomX", "multipoleAtomY",
+ "thole", "dampingFactor", "polarity"
+ ),
+ ("cutoff",),
+ has_exclusions=False,
+ )
+
+ cutoff_idx_pol = potential_pol.attribute_cols.index("cutoff")
+ cutoff_idx_chg = potential_chg.attribute_cols.index("cutoff")
+
+ assert torch.isclose(
+ potential_pol.attributes[cutoff_idx_pol],
+ potential_chg.attributes[cutoff_idx_chg],
+ )
+
+ potential_chg.fn = smee.EnergyFn.POLARIZATION
+
+ potential_chg.parameter_cols = (
+ *potential_chg.parameter_cols,
+ *potential_pol.parameter_cols,
+ )
+ potential_chg.parameter_units = (
+ *potential_chg.parameter_units,
+ *potential_pol.parameter_units,
+ )
+ potential_chg.parameter_keys = [
+ *potential_chg.parameter_keys,
+ *potential_pol.parameter_keys,
+ ]
+
+ # Handle different numbers of columns between charge and polarizability potentials
+ n_chg_cols = potential_chg.parameters.shape[1]
+ n_pol_cols = potential_pol.parameters.shape[1]
+
+ # Pad charge parameters with zeros for the new polarizability columns
+ parameters_chg = torch.cat(
+ (potential_chg.parameters, torch.zeros(potential_chg.parameters.shape[0], n_pol_cols, dtype=potential_chg.parameters.dtype)), dim=1
+ )
+
+ # Resolve multipole axis atoms from SMIRKS indices to actual topology indices.
+ # The OFFXML stores multipoleAtomZ/X/Y as 1-based SMIRKS atom map numbers.
+ # We need to resolve these to actual 0-based topology indices for each atom.
+ # Columns in potential_pol.parameters: 13=multipoleAtomZ, 14=multipoleAtomX, 15=multipoleAtomY
+
+ parameter_key_to_idx = {
+ key: i for i, key in enumerate(potential_pol.parameter_keys)
+ }
+
+ all_resolved_pol_params = []
+ resolved_parameter_maps_pol = []
+
+ for handler, topology, v_site_map, param_map_pol in zip(
+ handlers, topologies, v_site_maps, parameter_maps_pol, strict=True
+ ):
+ n_atoms = topology.n_atoms
+
+ # Create per-atom resolved parameters for multipoles
+ # Each atom gets its own parameter row with resolved axis indices
+ resolved_params = []
+
+ for atom_idx in range(n_atoms):
+ # Find the topology_key for this atom
+ matched_key = None
+ matched_param_idx = None
+
+ for topology_key, parameter_key in handler.key_map.items():
+ if isinstance(topology_key, openff.interchange.models.VirtualSiteKey):
+ continue
+ if topology_key.atom_indices[0] == atom_idx:
+ matched_key = topology_key
+ matched_param_idx = parameter_key_to_idx[parameter_key]
+ break
+
+ if matched_key is None:
+ # No multipole parameters for this atom, create zero row
+ resolved_params.append(torch.zeros(n_pol_cols, dtype=torch.float64))
+ continue
+
+ # Get the base parameters for this atom
+ base_params = potential_pol.parameters[matched_param_idx].clone()
+
+ # Resolve SMIRKS indices to actual topology indices
+ # SMIRKS indices are 1-based, so multipoleAtomZ=2 means atom_indices[1]
+ smirks_atom_z = int(base_params[13]) # multipoleAtomZ
+ smirks_atom_x = int(base_params[14]) # multipoleAtomX
+ smirks_atom_y = int(base_params[15]) # multipoleAtomY
+
+ # Resolve using atom_indices from the match
+ # SMIRKS :N corresponds to atom_indices[N-1]
+ atom_indices = matched_key.atom_indices
+
+ if smirks_atom_z > 0 and smirks_atom_z <= len(atom_indices):
+ base_params[13] = atom_indices[smirks_atom_z - 1]
+ else:
+ base_params[13] = -1
+
+ if smirks_atom_x > 0 and smirks_atom_x <= len(atom_indices):
+ base_params[14] = atom_indices[smirks_atom_x - 1]
+ else:
+ base_params[14] = -1
+
+ if smirks_atom_y > 0 and smirks_atom_y <= len(atom_indices):
+ base_params[15] = atom_indices[smirks_atom_y - 1]
+ else:
+ base_params[15] = -1
+
+ # Compute damping factor from polarity
+ base_params[17] = base_params[18] ** (1/6)
+
+ resolved_params.append(base_params)
+
+ # Stack into a tensor for this topology
+ resolved_params_tensor = torch.stack(resolved_params)
+ all_resolved_pol_params.append(resolved_params_tensor)
+
+ # Create identity-like assignment matrix (each atom maps to its own row)
+ assignment_matrix = torch.eye(n_atoms, dtype=torch.float64).to_sparse()
+ resolved_parameter_maps_pol.append(
+ smee.NonbondedParameterMap(
+ assignment_matrix=assignment_matrix,
+ exclusions=param_map_pol.exclusions,
+ exclusion_scale_idxs=param_map_pol.exclusion_scale_idxs,
+ )
+ )
+
+ # Combine all resolved parameters across topologies into a single tensor
+ # Each topology's parameters are separate, so we need to track offsets
+ total_pol_params = sum(p.shape[0] for p in all_resolved_pol_params)
+
+ # Create combined parameters tensor
+ if all_resolved_pol_params:
+ combined_pol_params = torch.cat(all_resolved_pol_params, dim=0)
+ else:
+ combined_pol_params = torch.zeros((0, n_pol_cols), dtype=torch.float64)
+
+ # Pad with charge columns
+ parameters_pol = torch.cat(
+ (torch.zeros(combined_pol_params.shape[0], n_chg_cols, dtype=torch.float64), combined_pol_params), dim=1
+ )
+
+ potential_chg.parameters = torch.cat((parameters_chg, parameters_pol), dim=0)
+
+ # Update assignment matrices with proper offsets
+ param_offset = 0
+ for i, (parameter_map_chg, resolved_map_pol) in enumerate(zip(
+ parameter_maps_chg, resolved_parameter_maps_pol, strict=True
+ )):
+ n_chg_params = parameter_map_chg.assignment_matrix.shape[1]
+ n_atoms = resolved_map_pol.assignment_matrix.shape[0]
+
+ # The resolved_map_pol assignment matrix is identity, but we need to offset
+ # the column indices by (n_charge_params + param_offset)
+ pol_assignment_dense = resolved_map_pol.assignment_matrix.to_dense()
+
+ # Create the full assignment matrix
+ n_total_params = parameters_chg.shape[0] + combined_pol_params.shape[0]
+ full_assignment = torch.zeros((n_atoms, n_total_params), dtype=torch.float64)
+
+ # Copy charge assignments
+ chg_assignment = parameter_map_chg.assignment_matrix.to_dense()
+ full_assignment[:, :n_chg_params] = chg_assignment
+
+ # Add multipole assignments with proper offset
+ pol_start_col = parameters_chg.shape[0] + param_offset
+ for atom_idx in range(n_atoms):
+ full_assignment[atom_idx, pol_start_col + atom_idx] = 1.0
+
+ parameter_map_chg.assignment_matrix = full_assignment.to_sparse()
+ param_offset += n_atoms
+
+ return potential_chg, parameter_maps_chg
+
+
def _make_v_site_electrostatics_compatible(
handlers: list[openff.interchange.smirnoff.SMIRNOFFElectrostaticsCollection],
):
diff --git a/smee/converters/openmm/nonbonded.py b/smee/converters/openmm/nonbonded.py
index 2188d87..56d2d27 100644
--- a/smee/converters/openmm/nonbonded.py
+++ b/smee/converters/openmm/nonbonded.py
@@ -7,6 +7,8 @@
import openmm
import torch
+from tholedipoleplugin import TholeDipoleForce
+
import smee
import smee.converters.openmm
import smee.potentials.nonbonded
@@ -22,9 +24,9 @@
def _create_nonbonded_force(
- potential: smee.TensorPotential,
- system: smee.TensorSystem,
- cls: typing.Type[_T] = openmm.NonbondedForce,
+ potential: smee.TensorPotential,
+ system: smee.TensorSystem,
+ cls: typing.Type[_T] = openmm.NonbondedForce,
) -> _T:
"""Create a non-bonded force for a given potential and system, making sure to set
the appropriate method and cutoffs."""
@@ -71,10 +73,10 @@ def _create_nonbonded_force(
def _eval_mixing_fn(
- potential: smee.TensorPotential,
- mixing_fn: dict[str, str],
- param_1: torch.Tensor,
- param_2: torch.Tensor,
+ potential: smee.TensorPotential,
+ mixing_fn: dict[str, str],
+ param_1: torch.Tensor,
+ param_2: torch.Tensor,
) -> dict[str, float]:
import symengine
@@ -96,8 +98,8 @@ def _eval_mixing_fn(
def _build_vdw_lookup(
- potential: smee.TensorPotential,
- mixing_fn: dict[str, str],
+ potential: smee.TensorPotential,
+ mixing_fn: dict[str, str],
) -> dict[str, list[float]]:
"""Build the ``n_param x n_param`` vdW parameter lookup table containing
parameters for all interactions.
@@ -155,7 +157,7 @@ def _prepend_scale_to_energy_fn(fn: str, scale_var: str = _INTRA_SCALE_VAR) -> s
def _detect_parameters(
- potential: smee.TensorPotential, energy_fn: str, mixing_fn: dict[str, str]
+ potential: smee.TensorPotential, energy_fn: str, mixing_fn: dict[str, str]
) -> tuple[list[str], list[str]]:
"""Detect the required parameters and attributes for a given energy function
and associated mixing rules."""
@@ -209,7 +211,7 @@ def _detect_parameters(
def _extract_parameters(
- potential: smee.TensorPotential, parameter: torch.Tensor, cols: list[str]
+ potential: smee.TensorPotential, parameter: torch.Tensor, cols: list[str]
) -> list[float]:
"""Extract the values of a subset of parameters from a parameter tensor."""
@@ -230,13 +232,13 @@ def _extract_parameters(
def _add_parameters_to_vdw_without_lookup(
- potential: smee.TensorPotential,
- system: smee.TensorSystem,
- energy_fn: str,
- mixing_fn: dict[str, str],
- inter_force: openmm.CustomNonbondedForce,
- intra_force: openmm.CustomBondForce,
- used_parameters: list[str],
+ potential: smee.TensorPotential,
+ system: smee.TensorSystem,
+ energy_fn: str,
+ mixing_fn: dict[str, str],
+ inter_force: openmm.CustomNonbondedForce,
+ intra_force: openmm.CustomBondForce,
+ used_parameters: list[str],
):
"""Add parameters to a vdW force directly, i.e. without using a lookup table."""
@@ -292,12 +294,12 @@ def _add_parameters_to_vdw_without_lookup(
def _add_parameters_to_vdw_with_lookup(
- potential: smee.TensorPotential,
- system: smee.TensorSystem,
- energy_fn: str,
- mixing_fn: dict[str, str],
- inter_force: openmm.CustomNonbondedForce,
- intra_force: openmm.CustomBondForce,
+ potential: smee.TensorPotential,
+ system: smee.TensorSystem,
+ energy_fn: str,
+ mixing_fn: dict[str, str],
+ inter_force: openmm.CustomNonbondedForce,
+ intra_force: openmm.CustomBondForce,
):
"""Add parameters to a vdW force, explicitly defining all pairwise parameters
using a lookup table."""
@@ -356,10 +358,10 @@ def _add_parameters_to_vdw_with_lookup(
def convert_custom_vdw_potential(
- potential: smee.TensorPotential,
- system: smee.TensorSystem,
- energy_fn: str,
- mixing_fn: dict[str, str],
+ potential: smee.TensorPotential,
+ system: smee.TensorSystem,
+ energy_fn: str,
+ mixing_fn: dict[str, str],
) -> tuple[openmm.CustomNonbondedForce, openmm.CustomBondForce]:
"""Converts an arbitrary vdW potential to OpenMM forces.
@@ -444,7 +446,7 @@ def convert_custom_vdw_potential(
smee.PotentialType.VDW, smee.EnergyFn.VDW_LJ
)
def convert_lj_potential(
- potential: smee.TensorPotential, system: smee.TensorSystem
+ potential: smee.TensorPotential, system: smee.TensorSystem
) -> openmm.NonbondedForce | list[openmm.CustomNonbondedForce | openmm.CustomBondForce]:
"""Convert a Lennard-Jones potential to an OpenMM force.
@@ -501,7 +503,7 @@ def convert_lj_potential(
smee.PotentialType.VDW, smee.EnergyFn.VDW_DEXP
)
def convert_dexp_potential(
- potential: smee.TensorPotential, system: smee.TensorSystem
+ potential: smee.TensorPotential, system: smee.TensorSystem
) -> tuple[openmm.CustomNonbondedForce, openmm.CustomBondForce]:
"""Convert a DEXP potential to OpenMM forces.
@@ -526,11 +528,212 @@ def convert_dexp_potential(
return convert_custom_vdw_potential(potential, system, energy_fn, mixing_fn)
+@smee.converters.openmm.potential_converter(
+ smee.PotentialType.VDW, smee.EnergyFn.VDW_DAMPEDEXP6810
+)
+def convert_dampedexp6810_potential(
+ potential: smee.TensorPotential, system: smee.TensorSystem
+) -> tuple[openmm.CustomNonbondedForce, openmm.CustomBondForce]:
+ """Convert a DampedExp6810 potential to OpenMM forces.
+
+ The intermolcular interactions are described by a custom nonbonded force, while the
+ intramolecular interactions are described by a custom bond force.
+
+ If the potential has custom mixing rules (i.e. exceptions), a lookup table will be
+ used to store the parameters. Otherwise, the mixing rules will be applied directly
+ in the energy function.
+ """
+ energy_fn = (
+ "repulsion - ttdamp6*c6*invR^6 - ttdamp8*c8*invR^8 - ttdamp10*c10*invR^10;"
+ "repulsion = force_at_zero*invbeta*exp(-beta*(r-rho));"
+ "ttdamp10 = select(expbr, 1.0 - expbr * ttdamp10Sum, 1);"
+ "ttdamp8 = select(expbr, 1.0 - expbr * ttdamp8Sum, 1);"
+ "ttdamp6 = select(expbr, 1.0 - expbr * ttdamp6Sum, 1);"
+ "ttdamp10Sum = ttdamp8Sum + br^9/362880 + br^10/3628800;"
+ "ttdamp8Sum = ttdamp6Sum + br^7/5040 + br^8/40320;"
+ "ttdamp6Sum = 1.0 + br + br^2/2 + br^3/6 + br^4/24 + br^5/120 + br^6/720;"
+ "expbr = exp(-br);"
+ "br = beta*r;"
+ "invR = 1.0/r;"
+ "invbeta = 1.0/beta;"
+ )
+ mixing_fn = {
+ "beta": "2.0 * beta1 * beta2 / (beta1 + beta2)",
+ "rho": "0.5 * (rho1 + rho2)",
+ "c6": "sqrt(c61*c62)",
+ "c8": "sqrt(c81*c82)",
+ "c10": "sqrt(c101*c102)",
+ }
+
+ return convert_custom_vdw_potential(potential, system, energy_fn, mixing_fn)
+
+
+@smee.converters.openmm.potential_converter(
+ smee.PotentialType.ELECTROSTATICS, smee.EnergyFn.POLARIZATION
+)
+def convert_multipole_potential(
+ potential: smee.TensorPotential, system: smee.TensorSystem
+) -> TholeDipoleForce:
+ """Convert a Multipole potential to OpenMM TholeDipoleForce.
+
+ TholeDipole parameter layout (9 columns):
+ Column 0: charge (e)
+ Columns 1-3: molecularDipole (e·Å, x, y, z)
+ Column 4: axisType (int, 0-5)
+ Column 5: multipoleAtomZ (int)
+ Column 6: multipoleAtomX (int)
+ Column 7: multipoleAtomY (int)
+ Column 8: polarity (ų)
+ """
+ cutoff_idx = potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)
+ cutoff = float(potential.attributes[cutoff_idx]) * 0.1 # Å to nm
+
+ force = TholeDipoleForce()
+
+ if system.is_periodic:
+ force.setNonbondedMethod(TholeDipoleForce.PME)
+ else:
+ force.setNonbondedMethod(TholeDipoleForce.NoCutoff)
+
+ force.setPolarizationType(TholeDipoleForce.Mutual)
+ force.setCutoffDistance(cutoff)
+ force.setEwaldErrorTolerance(0.0001)
+ force.setMutualInducedTargetEpsilon(0.00001)
+ force.setMutualInducedMaxIterations(60)
+ force.setExtrapolationCoefficients([-0.154, 0.017, 0.658, 0.474])
+ force.setTholeDampingType(TholeDipoleForce.Amoeba)
+ force.setTholeDampingParameter(0.39)
+
+ # Map AMOEBA axis types to TholeDipole axis types
+ # AMOEBA: NoAxisType=0, ZOnly=1, ZThenX=2, Bisector=3, ZBisect=4, ThreeFold=5
+ # TholeDipole: ZThenX=0, Bisector=1, ZBisect=2, ThreeFold=3, ZOnly=4, NoAxisType=5
+ amoeba_to_thole_axis = {
+ 0: TholeDipoleForce.NoAxisType, # AMOEBA NoAxisType -> TholeDipole NoAxisType
+ 1: TholeDipoleForce.ZOnly, # AMOEBA ZOnly -> TholeDipole ZOnly
+ 2: TholeDipoleForce.ZThenX, # AMOEBA ZThenX -> TholeDipole ZThenX
+ 3: TholeDipoleForce.Bisector, # AMOEBA Bisector -> TholeDipole Bisector
+ 4: TholeDipoleForce.ZBisect, # AMOEBA ZBisect -> TholeDipole ZBisect
+ 5: TholeDipoleForce.ThreeFold, # AMOEBA ThreeFold -> TholeDipole ThreeFold
+ }
+
+ idx_offset = 0
+
+ for topology, n_copies in zip(system.topologies, system.n_copies):
+ parameter_map = topology.parameters[potential.type]
+ parameters = parameter_map.assignment_matrix @ potential.parameters
+ parameters = parameters.detach()
+
+ n_particles = topology.n_particles
+ n_params = parameters.shape[1]
+
+ for _ in range(n_copies):
+ for atom_idx in range(n_particles):
+ # Get charge from first n_particles rows
+ charge = float(parameters[atom_idx, 0])
+
+ # Get dipole, axisType, frame atoms, polarity from rows n_particles to 2*n_particles
+ if parameters.shape[0] > n_particles:
+ pol_row = parameters[n_particles + atom_idx]
+
+ if n_params == 20:
+ # AMOEBA-style 20-column layout (current PHAST force field):
+ # Col 0: charge, 1-3: dipole, 4-12: quadrupole (ignored)
+ # Col 13: axisType, 14-16: atomZ/X/Y, 17: thole, 18: dampingFactor, 19: polarity
+ dipole = [float(pol_row[1]) * 0.1, float(pol_row[2]) * 0.1, float(pol_row[3]) * 0.1] # e·Å to e·nm
+ amoeba_axis_type = int(pol_row[13])
+ atom_z = int(pol_row[14])
+ atom_x = int(pol_row[15])
+ atom_y = int(pol_row[16])
+ polarity = float(pol_row[19]) * 0.001 # ų to nm³
+
+ # Map axis type, but force NoAxisType if no valid axis atoms
+ if atom_z < 0:
+ axis_type = TholeDipoleForce.NoAxisType
+ else:
+ axis_type = amoeba_to_thole_axis.get(amoeba_axis_type, TholeDipoleForce.NoAxisType)
+ elif n_params == 9:
+ # TholeDipole 9-column layout:
+ # Col 0: charge, 1-3: dipole, 4: axisType, 5-7: atomZ/X/Y, 8: polarity
+ dipole = [float(pol_row[1]) * 0.1, float(pol_row[2]) * 0.1, float(pol_row[3]) * 0.1]
+ axis_type = int(pol_row[4])
+ atom_z = int(pol_row[5])
+ atom_x = int(pol_row[6])
+ atom_y = int(pol_row[7])
+ polarity = float(pol_row[8]) * 0.001
+ else:
+ # Fallback: assume dipole at 1-3, polarity at last column
+ dipole = [float(pol_row[1]) * 0.1, float(pol_row[2]) * 0.1, float(pol_row[3]) * 0.1]
+ axis_type = TholeDipoleForce.NoAxisType
+ atom_z = -1
+ atom_x = -1
+ atom_y = -1
+ polarity = float(pol_row[n_params - 1]) * 0.001 if n_params > 4 else 0.0
+ else:
+ dipole = [0.0, 0.0, 0.0]
+ axis_type = TholeDipoleForce.NoAxisType
+ atom_z = -1
+ atom_x = -1
+ atom_y = -1
+ polarity = 0.0
+
+ # The axis atom indices in the parameters are now actual 0-based topology indices
+ # (resolved from SMIRKS indices during OpenFF conversion). We just need to
+ # add the idx_offset for the current molecule copy.
+ force.addParticle(
+ charge,
+ dipole,
+ polarity,
+ axis_type,
+ int(atom_z) + idx_offset if atom_z >= 0 else -1,
+ int(atom_x) + idx_offset if atom_x >= 0 else -1,
+ int(atom_y) + idx_offset if atom_y >= 0 else -1,
+ )
+
+ # Set up covalent maps (TholeDipole uses 4 types: Covalent12-15)
+ covalent_12_maps = {}
+ covalent_13_maps = {}
+ covalent_14_maps = {}
+ covalent_15_maps = {}
+
+ for (i, j), scale_idx in zip(parameter_map.exclusions, parameter_map.exclusion_scale_idxs):
+ i = int(i) + idx_offset
+ j = int(j) + idx_offset
+
+ if scale_idx == 0: # 1-2 interactions
+ covalent_maps = covalent_12_maps
+ elif scale_idx == 1: # 1-3 interactions
+ covalent_maps = covalent_13_maps
+ elif scale_idx == 2: # 1-4 interactions
+ covalent_maps = covalent_14_maps
+ else: # 1-5+ interactions
+ covalent_maps = covalent_15_maps
+
+ if i not in covalent_maps:
+ covalent_maps[i] = []
+ if j not in covalent_maps:
+ covalent_maps[j] = []
+ covalent_maps[i].append(j)
+ covalent_maps[j].append(i)
+
+ for i, atoms in covalent_12_maps.items():
+ force.setCovalentMap(i, TholeDipoleForce.Covalent12, atoms)
+ for i, atoms in covalent_13_maps.items():
+ force.setCovalentMap(i, TholeDipoleForce.Covalent13, atoms)
+ for i, atoms in covalent_14_maps.items():
+ force.setCovalentMap(i, TholeDipoleForce.Covalent14, atoms)
+ for i, atoms in covalent_15_maps.items():
+ force.setCovalentMap(i, TholeDipoleForce.Covalent15, atoms)
+
+ idx_offset += n_particles
+
+ return force
+
+
@smee.converters.openmm.potential_converter(
smee.PotentialType.ELECTROSTATICS, smee.EnergyFn.COULOMB
)
def convert_coulomb_potential(
- potential: smee.TensorPotential, system: smee.TensorSystem
+ potential: smee.TensorPotential, system: smee.TensorSystem
) -> openmm.NonbondedForce:
"""Convert a Coulomb potential to an OpenMM force."""
force = _create_nonbonded_force(potential, system)
diff --git a/smee/potentials/multipole.py b/smee/potentials/multipole.py
new file mode 100644
index 0000000..0d27d3d
--- /dev/null
+++ b/smee/potentials/multipole.py
@@ -0,0 +1,424 @@
+"""Multipole potential energy functions."""
+
+import math
+import typing
+
+import torch
+
+import smee.potentials
+
+
+@smee.potentials.potential_energy_fn(
+ smee.PotentialType.ELECTROSTATICS, smee.EnergyFn.POLARIZATION
+)
+def compute_multipole_energy(
+ system: smee.TensorSystem,
+ potential: smee.TensorPotential,
+ conformer: torch.Tensor,
+ box_vectors: torch.Tensor | None = None,
+ pairwise=None,
+ polarization_type: str = "mutual",
+ extrapolation_coefficients: list[float] | None = None,
+) -> torch.Tensor:
+ """Compute the multipole energy including polarization effects.
+
+ This function supports the full AMOEBA multipole model with the following parameters
+ per atom (arranged in columns):
+
+ Column 0: charge (double) - the particle's charge
+ Columns 1-3: molecularDipole (vector[3]) - the particle's molecular dipole (x, y, z)
+ Columns 4-12: molecularQuadrupole (vector[9]) - the particle's molecular quadrupole
+ Column 13: axisType (int) - the particle's axis type (0=NoAxisType, 1=ZOnly, etc.)
+ Column 14: multipoleAtomZ (int) - index of first atom for lab<->molecular frames
+ Column 15: multipoleAtomX (int) - index of second atom for lab<->molecular frames
+ Column 16: multipoleAtomY (int) - index of third atom for lab<->molecular frames
+ Column 17: thole (double) - Thole parameter (default 0.39)
+ Column 18: dampingFactor (double) - dampingFactor parameter
+ Column 19: polarity (double) - polarity/polarizability parameter
+
+ For backwards compatibility, if fewer columns are provided, sensible defaults are used.
+
+ Args:
+ system: The system.
+ potential: The potential containing multipole parameters.
+ conformer: The conformer.
+ box_vectors: The box vectors.
+ pairwise: The pairwise distances.
+ polarization_type: The polarization solver type. Options are:
+ - "mutual": Full iterative SCF solver (default, ~60 iterations)
+ - "direct": Direct polarization with no mutual coupling (0 iterations)
+ - "extrapolated": Extrapolated polarization using OPT method
+ extrapolation_coefficients: Custom extrapolation coefficients for "extrapolated" type.
+ If None, uses OPT3 coefficients [-0.154, 0.017, 0.657, 0.475].
+ Must sum to approximately 1.0 for energy conservation.
+
+ Returns:
+ The energy.
+ """
+ from smee.potentials.nonbonded import (
+ _COULOMB_PRE_FACTOR,
+ _compute_pme_exclusions,
+ _compute_pme_grid,
+ _broadcast_exclusions,
+ compute_pairwise,
+ compute_pairwise_scales,
+ )
+
+ # Validate polarization_type
+ valid_types = ["mutual", "direct", "extrapolated"]
+ if polarization_type not in valid_types:
+ raise ValueError(f"polarization_type must be one of {valid_types}, got {polarization_type}")
+
+ box_vectors = None if not system.is_periodic else box_vectors
+
+ cutoff = potential.attributes[potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)]
+
+ pairwise = compute_pairwise(system, conformer, box_vectors, cutoff)
+
+ # Initialize parameter lists for all AMOEBA multipole parameters
+ charges = []
+ dipoles = [] # 3 components per atom
+ quadrupoles = [] # 9 components per atom
+ axis_types = []
+ multipole_atom_z = [] # Z-axis defining atom indices
+ multipole_atom_x = [] # X-axis defining atom indices
+ multipole_atom_y = [] # Y-axis defining atom indices
+ thole_params = []
+ polarizabilities = []
+
+ # Extract parameters from parameter matrix
+ for topology, n_copies in zip(system.topologies, system.n_copies):
+ parameter_map = topology.parameters[potential.type]
+ topology_parameters = parameter_map.assignment_matrix @ potential.parameters
+
+ n_particles = topology.n_particles
+ n_params = topology_parameters.shape[1]
+
+ # Expected parameter layout for full AMOEBA multipole:
+ # Column 0: charge
+ # Columns 1-3: molecular dipole (x, y, z)
+ # Columns 4-12: molecular quadrupole (9 components)
+ # Column 13: axis type (int)
+ # Column 14: multipole atom Z index
+ # Column 15: multipole atom X index
+ # Column 16: multipole atom Y index
+ # Column 17: thole parameter
+ # Column 18: damping factor
+ # Column 19: polarizability
+
+ charges.append(topology_parameters[:n_particles, 0].repeat(n_copies))
+ dipoles.append(topology_parameters[n_particles:, 1:4].repeat(n_copies, 1))
+ quadrupoles.append(topology_parameters[n_particles:, 4:13].repeat(n_copies, 1))
+ axis_types.append(topology_parameters[n_particles:, 13].repeat(n_copies).int())
+ multipole_atom_z.append(topology_parameters[n_particles:, 14].repeat(n_copies).int())
+ multipole_atom_x.append(topology_parameters[n_particles:, 15].repeat(n_copies).int())
+ multipole_atom_y.append(topology_parameters[n_particles:, 16].repeat(n_copies).int())
+ thole_params.append(topology_parameters[n_particles:, 17].repeat(n_copies))
+ polarizabilities.append(topology_parameters[n_particles:, 19].repeat(n_copies))
+
+ # Concatenate all parameter lists
+ charges = torch.cat(charges) # Shape: (n_total_particles,)
+ dipoles = torch.cat(dipoles) # Shape: (n_total_particles, 3)
+ quadrupoles = torch.cat(quadrupoles) # Shape: (n_total_particles, 9)
+ axis_types = torch.cat(axis_types) # Shape: (n_total_particles,)
+ multipole_atom_z = torch.cat(multipole_atom_z) # Shape: (n_total_particles,)
+ multipole_atom_x = torch.cat(multipole_atom_x) # Shape: (n_total_particles,)
+ multipole_atom_y = torch.cat(multipole_atom_y) # Shape: (n_total_particles,)
+ thole_params = torch.cat(thole_params) # Shape: (n_total_particles,)
+ polarizabilities = torch.cat(polarizabilities) # Shape: (n_total_particles,)
+
+ # Override scale factors to match TholeDipole convention
+ # TholeDipole uses [0, 0, 0.5, 1.0] for 1-2, 1-3, 1-4, 1-5 (vs AMOEBA's [0, 0, 0.4, 0.8/1.0])
+ if 'scale_14' in potential.attribute_cols:
+ scale_14_idx = potential.attribute_cols.index('scale_14')
+ potential.attributes[scale_14_idx] = 0.5
+
+ pair_scales = compute_pairwise_scales(system, potential)
+
+ # static partial charge - partial charge energy
+ if system.is_periodic == False:
+ coul_energy = (
+ _COULOMB_PRE_FACTOR
+ * pair_scales
+ * charges[pairwise.idxs[:, 0]]
+ * charges[pairwise.idxs[:, 1]]
+ / pairwise.distances
+ ).sum(-1)
+ else:
+ import NNPOps.pme
+
+ cutoff = potential.attributes[potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)]
+ error_tol = torch.tensor(0.0001)
+
+ exceptions = _compute_pme_exclusions(system, potential).to(charges.device)
+
+ grid_x, grid_y, grid_z, alpha = _compute_pme_grid(box_vectors, cutoff, error_tol)
+
+ pme = NNPOps.pme.PME(
+ grid_x, grid_y, grid_z, 5, alpha, _COULOMB_PRE_FACTOR, exceptions
+ )
+
+ energy_direct = torch.ops.pme.pme_direct(
+ conformer.float(),
+ charges.float(),
+ pairwise.idxs.T,
+ pairwise.deltas,
+ pairwise.distances,
+ pme.exclusions,
+ pme.alpha,
+ pme.coulomb,
+ )
+ energy_self = -torch.sum(charges ** 2) * pme.coulomb * pme.alpha / math.sqrt(torch.pi)
+ energy_recip = energy_self + torch.ops.pme.pme_reciprocal(
+ conformer.float(),
+ charges.float(),
+ box_vectors.float(),
+ pme.gridx,
+ pme.gridy,
+ pme.gridz,
+ pme.order,
+ pme.alpha,
+ pme.coulomb,
+ pme.moduli[0].to(charges.device),
+ pme.moduli[1].to(charges.device),
+ pme.moduli[2].to(charges.device),
+ )
+
+ exclusion_idxs, exclusion_scales = _broadcast_exclusions(system, potential)
+
+ exclusion_distances = (
+ conformer[exclusion_idxs[:, 0], :] - conformer[exclusion_idxs[:, 1], :]
+ ).norm(dim=-1)
+
+ energy_exclusion = (
+ _COULOMB_PRE_FACTOR
+ * exclusion_scales
+ * charges[exclusion_idxs[:, 0]]
+ * charges[exclusion_idxs[:, 1]]
+ / exclusion_distances
+ ).sum(-1)
+
+ coul_energy = energy_direct + energy_recip + energy_exclusion
+
+ # If all polarizabilities are zero, just return the Coulomb energy
+ if torch.allclose(polarizabilities, torch.tensor(0.0, dtype=torch.float64)):
+ return coul_energy
+
+ # Handle batch vs single conformer - process each conformer individually
+ is_batch = conformer.ndim == 3
+
+ if is_batch:
+ # Process each conformer individually and return results for each
+ n_conformers = conformer.shape[0]
+ batch_energies = []
+
+ for conf_idx in range(n_conformers):
+ # Extract single conformer
+ single_conformer = conformer[conf_idx]
+
+ # Compute pairwise for this conformer
+ single_pairwise = compute_pairwise(system, single_conformer, box_vectors, cutoff)
+
+ # Recursively call this function for single conformer
+ single_energy = compute_multipole_energy(
+ system, potential, single_conformer, box_vectors, single_pairwise,
+ polarization_type, extrapolation_coefficients
+ )
+ batch_energies.append(single_energy)
+
+ return torch.stack(batch_energies)
+
+ # Continue with single conformer processing
+ efield_static = torch.zeros((system.n_particles, 3), dtype=torch.float64, device=conformer.device)
+ efield_static_polar = torch.zeros((system.n_particles, 3), dtype=torch.float64, device=conformer.device)
+
+ # TholeDipole scale factors (different from AMOEBA):
+ # mScale: permanent-permanent interactions [1-2, 1-3, 1-4, 1-5, 1-6+]
+ # iScale: induced-induced interactions
+ # TholeDipole uses [0, 0, 0.5, 1.0, 1.0] for mScale (vs AMOEBA's [0, 0, 0, 0.4, 0.8])
+ mScale = torch.tensor([0.0, 0.0, 0.5, 1.0, 1.0], dtype=torch.float64)
+ dScale = torch.tensor([0.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float64)
+ pScale = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0], dtype=torch.float64)
+ uScale = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float64)
+
+ # calculate electric field due to partial charges by hand
+ # TODO ewald or wolf summation for periodic
+ _SQRT_COULOMB_PRE_FACTOR = _COULOMB_PRE_FACTOR ** (1 / 2)
+ # Calculate fixed dipole field from permanent charges
+ # TholeDipole applies mScale to the field calculation
+ for distance, delta, idx, scale in zip(
+ pairwise.distances, pairwise.deltas, pairwise.idxs, pair_scales
+ ):
+ # Use actual scale value (don't zero out partial scales)
+ if polarizabilities[idx[0]] * polarizabilities[idx[1]] != 0:
+ u = distance / (polarizabilities[idx[0]] * polarizabilities[idx[1]]) ** (
+ 1.0 / 6.0
+ )
+ else:
+ u = distance
+
+ # Use the Thole parameter (typically 0.39) in the damping function
+ # Using the minimum of the two atoms' Thole parameters
+ thole_a = torch.min(thole_params[idx[0]], thole_params[idx[1]])
+ au3 = thole_a * u ** 3
+ exp_au3 = torch.exp(-au3)
+
+ # Thole damping for charge-dipole interactions (thole_c in OpenMM)
+ damping_term1 = 1 - exp_au3
+
+ efield_static[idx[0]] -= (
+ _SQRT_COULOMB_PRE_FACTOR
+ * scale
+ * damping_term1
+ * charges[idx[1]]
+ * delta
+ / distance ** 3
+ )
+ efield_static[idx[1]] += (
+ _SQRT_COULOMB_PRE_FACTOR
+ * scale
+ * damping_term1
+ * charges[idx[0]]
+ * delta
+ / distance ** 3
+ )
+
+ # reshape to (3*N) vector
+ efield_static = efield_static.reshape(3 * system.n_particles)
+
+ # If there's no electric field, no polarization energy
+ if torch.allclose(efield_static, torch.tensor(0.0, dtype=torch.float64)):
+ return coul_energy
+
+ # induced dipole vector - start with direct polarization
+ ind_dipoles = torch.repeat_interleave(polarizabilities, 3) * efield_static
+
+ # Build A matrix for mutual/extrapolated methods
+ if polarization_type in ["mutual", "extrapolated"]:
+ A = torch.nan_to_num(torch.diag(torch.repeat_interleave(1.0 / polarizabilities, 3)))
+
+ # Build A matrix for induced-induced coupling
+ # NOTE: TholeDipole uses iScale=1.0 for ALL covalent types (no exclusions)
+ for distance, delta, idx in zip(
+ pairwise.distances, pairwise.deltas, pairwise.idxs
+ ):
+ if polarizabilities[idx[0]] * polarizabilities[idx[1]] != 0:
+ u = distance / (polarizabilities[idx[0]] * polarizabilities[idx[1]]) ** (
+ 1.0 / 6.0
+ )
+ else:
+ u = distance
+
+ # Use the Thole parameter (typically 0.39) in the damping function
+ thole_a = torch.min(thole_params[idx[0]], thole_params[idx[1]])
+ au3 = thole_a * u ** 3
+ exp_au3 = torch.exp(-au3)
+ thole3 = 1 - exp_au3
+ thole5 = 1 - (1 + au3) * exp_au3 # Note: (1 + au3), NOT (1 + 1.5*au3)
+
+ t = (
+ torch.eye(3, dtype=torch.float64, device=conformer.device) * thole3 * distance ** -3
+ - 3 * thole5 * torch.einsum("i,j->ij", delta, delta) * distance ** -5
+ )
+ # No scale factor - iScale is always 1.0 in TholeDipole
+ A[3 * idx[0]: 3 * idx[0] + 3, 3 * idx[1]: 3 * idx[1] + 3] = t
+ A[3 * idx[1]: 3 * idx[1] + 3, 3 * idx[0]: 3 * idx[0] + 3] = t
+
+ # Handle different polarization types
+ if polarization_type == "direct":
+ # Direct polarization: μ = α * E (no mutual coupling)
+ # ind_dipoles is already μ^(0) = α * E, so no additional work needed
+ pass
+ elif polarization_type == "extrapolated":
+ # TholeDipole extrapolation uses simple iterative updates (NOT conjugate gradient)
+ # μ_0 = α * E_fixed (direct polarization)
+ # μ_n = α * (E_fixed + E_induced(μ_{n-1})) for n = 1, 2, 3
+ # μ_final = Σ c_k * μ_k
+ if extrapolation_coefficients is None:
+ # TholeDipole default OPT4 coefficients
+ extrapolation_coefficients = [-0.154, 0.017, 0.658, 0.474]
+
+ opt_coeffs = torch.tensor(extrapolation_coefficients, dtype=torch.float64, device=conformer.device)
+ n_orders = len(opt_coeffs)
+
+ # Store perturbation theory orders
+ pt_dipoles = []
+ pt_dipoles.append(ind_dipoles.clone()) # PT0: direct polarization μ = α * E_fixed
+
+ # Build dipole-dipole interaction tensor T for computing induced fields
+ # E_induced[i] = Σ_j T[i,j] @ μ[j]
+ T = torch.zeros((3 * system.n_particles, 3 * system.n_particles), dtype=torch.float64,
+ device=conformer.device)
+ for distance, delta, idx in zip(pairwise.distances, pairwise.deltas, pairwise.idxs):
+ if polarizabilities[idx[0]] * polarizabilities[idx[1]] != 0:
+ u = distance / (polarizabilities[idx[0]] * polarizabilities[idx[1]]) ** (1.0 / 6.0)
+ else:
+ u = distance
+
+ thole_a = torch.min(thole_params[idx[0]], thole_params[idx[1]])
+ au3 = thole_a * u ** 3
+ exp_au3 = torch.exp(-au3)
+ thole3 = 1 - exp_au3
+ thole5 = 1 - (1 + au3) * exp_au3
+
+ # Dipole field tensor: E = [-thole3*μ/r³ + 3*thole5*(μ·r̂)r̂/r³]
+ t = (
+ -torch.eye(3, dtype=torch.float64, device=conformer.device) * thole3 * distance ** -3
+ + 3 * thole5 * torch.einsum("i,j->ij", delta, delta) * distance ** -5
+ )
+ T[3 * idx[0]: 3 * idx[0] + 3, 3 * idx[1]: 3 * idx[1] + 3] = t
+ T[3 * idx[1]: 3 * idx[1] + 3, 3 * idx[0]: 3 * idx[0] + 3] = t
+
+ # Generate higher order PT terms
+ current_dipoles = ind_dipoles.clone()
+ for order in range(1, n_orders):
+ # Calculate induced field from current dipoles
+ efield_induced = T @ current_dipoles
+
+ # Update dipoles: μ_n = α * (E_fixed + E_induced)
+ current_dipoles = torch.repeat_interleave(polarizabilities, 3) * (efield_static + efield_induced)
+ pt_dipoles.append(current_dipoles.clone())
+
+ # Combine PT orders with extrapolation coefficients
+ ind_dipoles = torch.zeros_like(ind_dipoles)
+ for k in range(n_orders):
+ ind_dipoles += opt_coeffs[k] * pt_dipoles[k]
+
+ else: # mutual
+ # Mutual polarization using conjugate gradient
+ precondition_m = torch.repeat_interleave(polarizabilities, 3)
+ residual = efield_static - A @ ind_dipoles
+ z = torch.einsum("i,i->i", precondition_m, residual)
+ p = torch.clone(z)
+
+ converged_iter = 60
+ for iter_num in range(60):
+ alpha = torch.dot(residual, z) / (p @ A @ p)
+ ind_dipoles = ind_dipoles + alpha * p
+
+ prev_residual = torch.clone(residual)
+ prev_z = torch.clone(z)
+
+ residual = residual - alpha * A @ p
+
+ rms_residual = torch.sqrt(torch.dot(residual, residual) / len(residual))
+ if torch.dot(residual, residual) < 1e-7:
+ converged_iter = iter_num + 1
+ break
+
+ z = torch.einsum("i,i->i", precondition_m, residual)
+ beta = torch.dot(z, residual) / torch.dot(prev_z, prev_residual)
+ p = z + beta * p
+
+ # Reshape induced dipoles back to (N, 3) for energy calculations
+ ind_dipoles_3d = ind_dipoles.reshape(system.n_particles, 3)
+
+ # Calculate polarization energy
+ # TholeDipole uses the same formula for ALL polarization types: -0.5 * μ · E_fixed
+ # This works because:
+ # - For Direct: μ = α * E_fixed, so U = -0.5 * α * |E_fixed|²
+ # - For Mutual: At SCF convergence, this gives the correct variational energy
+ # - For Extrapolated: The extrapolated dipoles approximate the SCF result
+ coul_energy += -0.5 * torch.dot(ind_dipoles, efield_static)
+
+ return coul_energy
diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py
index a8ac45a..4cee816 100644
--- a/smee/potentials/nonbonded.py
+++ b/smee/potentials/nonbonded.py
@@ -4,6 +4,7 @@
import math
import typing
+import numpy as np
import openff.units
import torch
@@ -796,6 +797,163 @@ def compute_dexp_energy(
return energy
+def _compute_dampedexp6810_lrc(
+ system: smee.TensorSystem,
+ potential: smee.TensorPotential,
+ rs: torch.Tensor | None,
+ rc: torch.Tensor | None,
+ volume: torch.Tensor,
+) -> torch.Tensor:
+ """Computes the long range dispersion correction due to the double exponential
+ potential, possibly with a switching function."""
+
+ raise NotImplementedError
+
+
+@smee.potentials.potential_energy_fn(
+ smee.PotentialType.VDW, smee.EnergyFn.VDW_DAMPEDEXP6810
+)
+def compute_dampedexp6810_energy(
+ system: smee.TensorSystem,
+ potential: smee.TensorPotential,
+ conformer: torch.Tensor,
+ box_vectors: torch.Tensor | None = None,
+ pairwise: PairwiseDistances | None = None,
+) -> torch.Tensor:
+ """Compute the potential energy [kcal / mol] of the vdW interactions using the
+ DampedExp6810 potential.
+
+ Notes:
+ * No cutoff function will be applied if the system is not periodic.
+
+ Args:
+ system: The system to compute the energy for.
+ potential: The potential energy function to evaluate.
+ conformer: The conformer [Å] to evaluate the potential at with
+ ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``.
+ box_vectors: The box vectors [Å] of the system with ``shape=(n_confs, 3, 3)``
+ or ``shape=(3, 3)`` if the system is periodic, or ``None`` otherwise.
+ pairwise: Pre-computed distances between each pair of particles
+ in the system.
+
+ Returns:
+ The evaluated potential energy [kcal / mol].
+ """
+ box_vectors = None if not system.is_periodic else box_vectors
+
+ cutoff = potential.attributes[potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)]
+
+ pairwise = (
+ pairwise
+ if pairwise is not None
+ else compute_pairwise(system, conformer, box_vectors, cutoff)
+ )
+
+ if system.is_periodic and not torch.isclose(pairwise.cutoff, cutoff):
+ raise ValueError("the pairwise cutoff does not match the potential.")
+
+ parameters = smee.potentials.broadcast_parameters(system, potential)
+ pair_scales = compute_pairwise_scales(system, potential)
+
+ pairs_1d = smee.utils.to_upper_tri_idx(
+ pairwise.idxs[:, 0], pairwise.idxs[:, 1], len(parameters)
+ )
+ pair_scales = pair_scales[pairs_1d]
+
+ rho_column = potential.parameter_cols.index("rho")
+ beta_column = potential.parameter_cols.index("beta")
+ c6_column = potential.parameter_cols.index("c6")
+ c8_column = potential.parameter_cols.index("c8")
+ c10_column = potential.parameter_cols.index("c10")
+
+ rho_a = parameters[pairwise.idxs[:, 0], rho_column]
+ rho_b = parameters[pairwise.idxs[:, 1], rho_column]
+ beta_a = parameters[pairwise.idxs[:, 0], beta_column]
+ beta_b = parameters[pairwise.idxs[:, 1], beta_column]
+ c6_a = parameters[pairwise.idxs[:, 0], c6_column]
+ c6_b = parameters[pairwise.idxs[:, 1], c6_column]
+ c8_a = parameters[pairwise.idxs[:, 0], c8_column]
+ c8_b = parameters[pairwise.idxs[:, 1], c8_column]
+ c10_a = parameters[pairwise.idxs[:, 0], c10_column]
+ c10_b = parameters[pairwise.idxs[:, 1], c10_column]
+
+ rho = 0.5 * (rho_a + rho_b)
+ beta = 2.0 * beta_a * beta_b / (beta_a + beta_b)
+ c6 = smee.utils.geometric_mean(c6_a, c6_b)
+ c8 = smee.utils.geometric_mean(c8_a, c8_b)
+ c10 = smee.utils.geometric_mean(c10_a, c10_b)
+
+ if potential.exceptions is not None:
+ exception_idxs, exceptions = smee.potentials.broadcast_exceptions(
+ system, potential, pairwise.idxs[:, 0], pairwise.idxs[:, 1]
+ )
+
+ rho = rho.clone() # prevent in-place modification
+ beta = beta.clone()
+ c6 = c6.clone()
+ c8 = c8.clone()
+ c10 = c10.clone()
+
+ rho[exception_idxs] = exceptions[:, rho_column]
+ beta[exception_idxs] = exceptions[:, beta_column]
+ c6[exception_idxs] = exceptions[:, c6_column]
+ c8[exception_idxs] = exceptions[:, c8_column]
+ c10[exception_idxs] = exceptions[:, c10_column]
+
+ force_at_zero = potential.attributes[
+ potential.attribute_cols.index("force_at_zero")
+ ]
+
+ x = pairwise.distances
+
+ invR = 1.0 / x
+ br = beta * x
+ expbr = torch.exp(-beta * x)
+
+ ttdamp6_sum = (
+ 1.0 + br + br**2 / 2 + br**3 / 6 + br**4 / 24 + br**5 / 120 + br**6 / 720
+ )
+ ttdamp8_sum = ttdamp6_sum + br**7 / 5040 + br**8 / 40320
+ ttdamp10_sum = ttdamp8_sum + br**9 / 362880 + br**10 / 3628800
+
+ ttdamp6 = 1.0 - expbr * ttdamp6_sum
+ ttdamp8 = 1.0 - expbr * ttdamp8_sum
+ ttdamp10 = 1.0 - expbr * ttdamp10_sum
+
+ repulsion = force_at_zero * 1.0 / beta * torch.exp(-beta * (x - rho))
+ energies = (
+ repulsion
+ - ttdamp6 * c6 * x**-6
+ - ttdamp8 * c8 * x**-8
+ - ttdamp10 * c10 * x**-10
+ )
+
+ # Apply exclusion scaling factors
+ energies *= pair_scales
+
+ if not system.is_periodic:
+ return energies.sum(-1)
+
+ switch_fn, switch_width = _compute_switch_fn(potential, pairwise)
+ energies *= switch_fn
+
+ energy = energies.sum(-1)
+
+ energy += _compute_dampedexp6810_lrc(
+ system,
+ potential.to(precision="double"),
+ switch_width.double(),
+ pairwise.cutoff.double(),
+ torch.det(box_vectors),
+ )
+
+ return energy
+
+
+# Import compute_multipole_energy from the new multipole module
+from smee.potentials.multipole import compute_multipole_energy
+
+
def _compute_pme_exclusions(
system: smee.TensorSystem, potential: smee.TensorPotential
) -> torch.Tensor:
diff --git a/smee/tests/convertors/openff/test_nonbonded.py b/smee/tests/convertors/openff/test_nonbonded.py
index 5dbce55..e9f4fcb 100644
--- a/smee/tests/convertors/openff/test_nonbonded.py
+++ b/smee/tests/convertors/openff/test_nonbonded.py
@@ -7,6 +7,7 @@
import smee
import smee.converters
from smee.converters.openff.nonbonded import (
+ convert_dampedexp6810,
convert_dexp,
convert_electrostatics,
convert_vdw,
@@ -312,3 +313,24 @@ def test_convert_dexp(ethanol, test_data_dir):
assert potential.type == "vdW"
assert potential.fn == smee.EnergyFn.VDW_DEXP
+
+
+def test_convert_dampedexp6810(ethanol, test_data_dir):
+ ff = openff.toolkit.ForceField(
+ str(test_data_dir / "PHAST-H2CNO-nonpolar-2.0.0.offxml"), load_plugins=True
+ )
+
+ interchange = openff.interchange.Interchange.from_smirnoff(
+ ff, ethanol.to_topology()
+ )
+ vdw_collection = interchange.collections["DampedExp6810"]
+
+ potential, parameter_maps = convert_dampedexp6810(
+ [vdw_collection], [ethanol.to_topology()], [None]
+ )
+
+ assert potential.attribute_cols[-1] == "force_at_zero"
+ assert potential.parameter_cols == ("rho", "beta", "c6", "c8", "c10")
+
+ assert potential.type == "vdW"
+ assert potential.fn == smee.EnergyFn.VDW_DAMPEDEXP6810
diff --git a/smee/tests/convertors/openmm/test_openmm.py b/smee/tests/convertors/openmm/test_openmm.py
index 7f9606c..18719cc 100644
--- a/smee/tests/convertors/openmm/test_openmm.py
+++ b/smee/tests/convertors/openmm/test_openmm.py
@@ -291,6 +291,45 @@ def test_convert_to_openmm_system_dexp_periodic(test_data_dir):
)
+def test_convert_to_openmm_system_damped6810_periodic(test_data_dir):
+ ff = openff.toolkit.ForceField(
+ str(test_data_dir / "PHAST-H2CNO-nonpolar-2.0.0.offxml"), load_plugins=True
+ )
+ top = openff.toolkit.Topology()
+
+ interchanges = []
+
+ n_copies_per_mol = [5, 5]
+
+ for smiles, n_copies in zip(["OCCO", "O"], n_copies_per_mol):
+ mol = openff.toolkit.Molecule.from_smiles(smiles)
+ mol.generate_conformers(n_conformers=1)
+
+ interchange = openff.interchange.Interchange.from_smirnoff(
+ ff, mol.to_topology()
+ )
+ interchanges.append(interchange)
+
+ for _ in range(n_copies):
+ top.add_molecule(mol)
+
+ tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)
+ tensor_system = smee.TensorSystem(tensor_tops, n_copies_per_mol, True)
+
+ coords, _ = smee.mm.generate_system_coords(
+ tensor_system, None, smee.mm.GenerateCoordsConfig()
+ )
+ box_vectors = numpy.eye(3) * 20.0 * openmm.unit.angstrom
+
+ top.box_vectors = box_vectors
+
+ interchange_top = openff.interchange.Interchange.from_smirnoff(ff, top)
+
+ _compare_smee_and_interchange(
+ tensor_ff, tensor_system, interchange_top, coords, box_vectors
+ )
+
+
def test_convert_to_openmm_topology():
formaldehyde_interchange = openff.interchange.Interchange.from_smirnoff(
openff.toolkit.ForceField("openff-2.0.0.offxml"),
diff --git a/smee/tests/data/PHAST-H2CNO-2.0.0.offxml b/smee/tests/data/PHAST-H2CNO-2.0.0.offxml
new file mode 100644
index 0000000..9e53d83
--- /dev/null
+++ b/smee/tests/data/PHAST-H2CNO-2.0.0.offxml
@@ -0,0 +1,388 @@
+
+
+ Adam Hogan
+ 2023-04-26
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/smee/tests/data/PHAST-H2CNO-nonpolar-2.0.0.offxml b/smee/tests/data/PHAST-H2CNO-nonpolar-2.0.0.offxml
new file mode 100644
index 0000000..85be63b
--- /dev/null
+++ b/smee/tests/data/PHAST-H2CNO-nonpolar-2.0.0.offxml
@@ -0,0 +1,387 @@
+
+
+ Adam Hogan
+ 2023-04-26
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/smee/tests/potentials/test_nonbonded.py b/smee/tests/potentials/test_nonbonded.py
index 456c75f..8c59679 100644
--- a/smee/tests/potentials/test_nonbonded.py
+++ b/smee/tests/potentials/test_nonbonded.py
@@ -2,10 +2,13 @@
import math
import numpy
+import openff
import openmm.unit
import pytest
import torch
+from tholedipoleplugin import TholeDipoleForce
+
import smee
import smee.converters
import smee.converters.openmm
@@ -18,19 +21,23 @@
_compute_lj_lrc,
_compute_pme_exclusions,
compute_coulomb_energy,
+ compute_dampedexp6810_energy,
compute_dexp_energy,
compute_lj_energy,
compute_pairwise,
compute_pairwise_scales,
prepare_lrc_types,
)
+from smee.potentials.multipole import compute_multipole_energy
def _compute_openmm_energy(
- system: smee.TensorSystem,
- coords: torch.Tensor,
- box_vectors: torch.Tensor | None,
- potential: smee.TensorPotential,
+ system: smee.TensorSystem,
+ coords: torch.Tensor,
+ box_vectors: torch.Tensor | None,
+ potential: smee.TensorPotential,
+ polarization_type: str | None = None,
+ return_omm_forces: bool | None = None,
) -> torch.Tensor:
coords = coords.numpy() * openmm.unit.angstrom
@@ -40,6 +47,19 @@ def _compute_openmm_energy(
omm_forces = smee.converters.convert_to_openmm_force(potential, system)
omm_system = smee.converters.openmm.create_openmm_system(system, None)
+ # Handle polarization type for TholeDipoleForce
+ if polarization_type is not None:
+ for omm_force in omm_forces:
+ if isinstance(omm_force, TholeDipoleForce):
+ if polarization_type == "direct":
+ omm_force.setPolarizationType(TholeDipoleForce.Direct)
+ elif polarization_type == "mutual":
+ omm_force.setPolarizationType(TholeDipoleForce.Mutual)
+ elif polarization_type == "extrapolated":
+ omm_force.setPolarizationType(TholeDipoleForce.Extrapolated)
+ else:
+ raise ValueError(f"Unknown polarization_type: {polarization_type}")
+
for omm_force in omm_forces:
omm_system.addForce(omm_force)
@@ -47,7 +67,9 @@ def _compute_openmm_energy(
omm_system.setDefaultPeriodicBoxVectors(*box_vectors)
omm_integrator = openmm.VerletIntegrator(1.0 * openmm.unit.femtoseconds)
- omm_context = openmm.Context(omm_system, omm_integrator)
+ # Use Reference platform for TholeDipole (OpenCL may segfault)
+ platform = openmm.Platform.getPlatformByName('Reference')
+ omm_context = openmm.Context(omm_system, omm_integrator, platform)
if box_vectors is not None:
omm_context.setPeriodicBoxVectors(*box_vectors)
@@ -57,7 +79,31 @@ def _compute_openmm_energy(
omm_energy = omm_context.getState(getEnergy=True).getPotentialEnergy()
omm_energy = omm_energy.value_in_unit(openmm.unit.kilocalories_per_mole)
- return torch.tensor(omm_energy, dtype=torch.float64)
+ # Get induced dipoles from TholeDipoleForce
+ try:
+ thole_force = None
+ for force in omm_forces:
+ if isinstance(force, TholeDipoleForce):
+ thole_force = force
+ break
+
+ if thole_force:
+ induced_dipoles = thole_force.getInducedDipoles(omm_context)
+
+ # Convert from e·nm to e·Å (multiply by 10)
+ conversion_factor = 10.0
+ induced_dipoles_angstrom = [[d * conversion_factor for d in dipole] for dipole in induced_dipoles]
+ print(f"\nOpenMM induced dipoles (e·Å):")
+ for i, dipole in enumerate(induced_dipoles_angstrom):
+ print(f" Particle {i}: [{dipole[0]:.10f}, {dipole[1]:.10f}, {dipole[2]:.10f}]")
+
+ except Exception as e:
+ print(f"Could not get induced dipoles: {e}")
+
+ if return_omm_forces:
+ return torch.tensor(omm_energy, dtype=torch.float64), omm_forces
+ else:
+ return torch.tensor(omm_energy, dtype=torch.float64)
def _parameter_key_to_idx(potential: smee.TensorPotential, key: str):
@@ -295,7 +341,7 @@ def test_compute_xxx_lrc_with_exceptions(lrc_fn, convert_fn):
rs = torch.tensor(8.0)
rc = torch.tensor(9.0)
- volume = torch.tensor(18.0**3)
+ volume = torch.tensor(18.0 ** 3)
lrc_no_exceptions = lrc_fn(system, vdw_potential_no_exceptions, rs, rc, volume)
@@ -320,7 +366,7 @@ def test_compute_xxx_lrc_with_exceptions(lrc_fn, convert_fn):
],
)
def test_compute_xxx_energy_periodic(
- energy_fn, convert_fn, etoh_water_system, with_exceptions
+ energy_fn, convert_fn, etoh_water_system, with_exceptions
):
tensor_sys, tensor_ff, coords, box_vectors = etoh_water_system
@@ -378,17 +424,17 @@ def _expected_energy_lj_exceptions(params: dict[str, smee.tests.utils.LJParam]):
sqrt_2 = math.sqrt(2)
expected_energy = 4.0 * (
- params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6)
- + params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6)
- #
- + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6)
- + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6)
- #
- + params["hh"].eps
- * ((params["hh"].sig / sqrt_2) ** 12 - (params["hh"].sig / sqrt_2) ** 6)
- #
- + params["oa"].eps
- * ((params["oa"].sig / sqrt_2) ** 12 - (params["oa"].sig / sqrt_2) ** 6)
+ params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6)
+ + params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6)
+ #
+ + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6)
+ + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6)
+ #
+ + params["hh"].eps
+ * ((params["hh"].sig / sqrt_2) ** 12 - (params["hh"].sig / sqrt_2) ** 6)
+ #
+ + params["oa"].eps
+ * ((params["oa"].sig / sqrt_2) ** 12 - (params["oa"].sig / sqrt_2) ** 6)
)
return expected_energy
@@ -410,15 +456,15 @@ def _dexp(eps, sig, dist):
sqrt_2 = math.sqrt(2)
expected_energy = (
- _dexp(params["oh"].eps, params["oh"].sig, 1.0)
- + _dexp(params["oh"].eps, params["oh"].sig, 1.0)
- #
- + _dexp(params["ah"].eps, params["ah"].sig, 1.0)
- + _dexp(params["ah"].eps, params["ah"].sig, 1.0)
- #
- + _dexp(params["hh"].eps, params["hh"].sig, sqrt_2)
- #
- + _dexp(params["oa"].eps, params["oa"].sig, sqrt_2)
+ _dexp(params["oh"].eps, params["oh"].sig, 1.0)
+ + _dexp(params["oh"].eps, params["oh"].sig, 1.0)
+ #
+ + _dexp(params["ah"].eps, params["ah"].sig, 1.0)
+ + _dexp(params["ah"].eps, params["ah"].sig, 1.0)
+ #
+ + _dexp(params["hh"].eps, params["hh"].sig, sqrt_2)
+ #
+ + _dexp(params["oa"].eps, params["oa"].sig, sqrt_2)
)
return expected_energy
@@ -428,9 +474,9 @@ def _dexp(eps, sig, dist):
[
(compute_lj_energy, lambda p: p, _expected_energy_lj_exceptions),
(
- compute_dexp_energy,
- smee.tests.utils.convert_lj_to_dexp,
- _expected_energy_dexp_exceptions,
+ compute_dexp_energy,
+ smee.tests.utils.convert_lj_to_dexp,
+ _expected_energy_dexp_exceptions,
),
],
)
@@ -530,3 +576,1039 @@ def test_compute_coulomb_energy_non_periodic():
)
assert torch.isclose(energy, expected_energy, atol=1.0e-4)
+
+
+def test_compute_dampedexp6810_energy_ne_scan_non_periodic(test_data_dir):
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["[Ne]", "[Ne]"],
+ [1, 1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / "PHAST-H2CNO-nonpolar-2.0.0.offxml"), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = torch.stack(
+ [
+ torch.vstack(
+ [
+ torch.tensor([0, 0, 0]),
+ torch.tensor([0, 0, 0]) + torch.tensor([0, 0, 1.5 + i * 0.5]),
+ ]
+ )
+ for i in range(20)
+ ]
+ )
+
+ energies = smee.compute_energy(tensor_sys, tensor_ff, coords)
+ expected_energies = []
+ for coord in coords:
+ expected_energies.append(
+ _compute_openmm_energy(
+ tensor_sys, coord, None, tensor_ff.potentials_by_type["vdW"]
+ )
+ )
+ expected_energies = torch.tensor(expected_energies)
+ assert torch.allclose(energies, expected_energies.float(), atol=1.0e-4)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated")
+ ]
+)
+def test_compute_multipole_energy_CC_O_non_periodic(test_data_dir, forcefield_name, polarization_type):
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["CC", "O"],
+ [3, 2],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ # Use fixed coordinates to ensure reproducibility
+ coords = torch.tensor([
+ [5.9731, 4.8234, 5.1358],
+ [5.6308, 3.4725, 5.7007],
+ [5.0358, 5.2467, 4.7020],
+ [6.2522, 5.4850, 5.9780],
+ [6.7136, 4.7967, 4.3256],
+ [4.9936, 3.6648, 6.6100],
+ [6.5061, 2.9131, 6.0617],
+ [5.0991, 2.8173, 4.9849],
+ [0.9326, 2.8105, 5.2711],
+ [0.9434, 1.3349, 5.5607],
+ [1.1295, 2.9339, 4.1794],
+ [-0.0939, 3.1853, 5.4460],
+ [1.7103, 3.3774, 5.7996],
+ [0.0123, 0.9149, 5.0849],
+ [0.8655, 1.0972, 6.6316],
+ [1.8432, 0.8172, 5.1776],
+ [3.2035, 0.7561, 3.0346],
+ [3.4468, 1.0277, 1.5757],
+ [4.1522, 0.3467, 3.4566],
+ [3.0323, 1.7263, 3.5387],
+ [2.4222, 0.0093, 3.2280],
+ [4.1461, 1.9103, 1.5332],
+ [2.5430, 1.3356, 1.0299],
+ [3.8762, 0.1647, 1.0324],
+ [6.3764, 1.9600, 2.9162],
+ [6.1056, 1.3456, 3.6328],
+ [6.6023, 2.8357, 3.3122],
+ [3.0792, 6.2544, 4.6979],
+ [3.5093, 6.6131, 5.5045],
+ [3.5237, 6.6324, 3.9016]], dtype=torch.float64)
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+ es_potential.parameters.requires_grad = True
+
+ energy = compute_multipole_energy(tensor_sys, es_potential, coords.float(), None,
+ polarization_type=polarization_type)
+ energy.backward()
+ expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential,
+ polarization_type=polarization_type, return_omm_forces=True)
+ print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces)
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ #("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ #("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ #("PHAST-H2CNO-2.0.0.offxml", "extrapolated")
+ ]
+)
+def test_compute_multipole_energy_charged_ne_non_periodic(test_data_dir, forcefield_name, polarization_type):
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["[Ne]", "[Ne]"],
+ [1, 1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = torch.vstack([torch.tensor([0, 0, 0]), torch.tensor([0, 0, 3.0])])
+
+ # give each atom a charge otherwise the system is neutral
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+ es_potential.parameters[0, 0] = 1
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords, None, polarization_type=polarization_type
+ )
+
+ expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential,
+ polarization_type=polarization_type, return_omm_forces=True)
+ print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces)
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated")
+ ]
+)
+def test_compute_multipole_energy_c_xe_non_periodic(test_data_dir, forcefield_name, polarization_type):
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["C", "[Xe]"],
+ [1, 1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = torch.tensor(
+ [
+ [+0.00000, +0.00000, +0.00000],
+ [+0.00000, +0.00000, +1.08900],
+ [+1.02672, +0.00000, -0.36300],
+ [-0.51336, -0.88916, -0.36300],
+ [-0.51336, +0.88916, -0.36300],
+ [+3.00000, +0.00000, +0.00000],
+ ]
+ )
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords, None, polarization_type=polarization_type
+ )
+
+ expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential,
+ polarization_type=polarization_type, return_omm_forces=True)
+ print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces)
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated")
+ ]
+)
+def test_compute_phast2_energy_water_conformers_non_periodic(test_data_dir, forcefield_name, polarization_type):
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["O"],
+ [2],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = torch.tensor([[[-5.5964e-02, 8.1693e-01, -5.3445e-01],
+ [+2.5174e-01, -5.8659e-01, -8.1979e-01],
+ [+0.0000e+00, 0.0000e+00, 0.0000e+00],
+ [+7.6271e+00, -6.6103e-01, -5.7262e-01],
+ [+7.7119e+00, -4.1601e-01, 9.3098e-01],
+ [+7.9377e+00, 0.0000e+00, 0.0000e+00]],
+ [[+7.1041e-01, 4.7487e-01, -1.5602e-01],
+ [-4.8097e-01, 7.2769e-01, -2.2119e-01],
+ [+0.0000e+00, 0.0000e+00, 0.0000e+00],
+ [+8.1144e+00, -8.7009e-01, -3.9085e-01],
+ [+8.1329e+00, 9.2279e-01, -4.4597e-01],
+ [+7.9377e+00, 0.0000e+00, 0.0000e+00]],
+ [[+2.1348e-01, 3.6725e-01, 8.3273e-01],
+ [-5.7851e-01, -6.4377e-01, 5.4664e-01],
+ [+0.0000e+00, 0.0000e+00, 0.0000e+00],
+ [+7.2758e+00, 3.1414e-01, -5.9182e-01],
+ [+7.7279e+00, -5.7537e-01, 7.6088e-01],
+ [+7.9377e+00, 0.0000e+00, 0.0000e+00]]])
+
+ multipole_potential = tensor_ff.potentials_by_type["Electrostatics"]
+ vdw_potential = tensor_ff.potentials_by_type["vdW"]
+
+ multipole_energy = compute_multipole_energy(
+ tensor_sys, multipole_potential, coords, None, polarization_type=polarization_type
+ )
+
+ multipole_expected_energy = torch.tensor(
+ [_compute_openmm_energy(tensor_sys, coord, None, multipole_potential, polarization_type=polarization_type) for coord in coords]
+ )
+
+ vdw_energy = compute_dampedexp6810_energy(
+ tensor_sys, vdw_potential, coords, None
+ )
+
+ vdw_expected_energy = torch.tensor(
+ [_compute_openmm_energy(tensor_sys, coord, None, vdw_potential) for coord in coords]
+ )
+
+ assert torch.allclose(multipole_energy, multipole_expected_energy, atol=1.0e-3)
+ assert torch.allclose(vdw_energy, vdw_expected_energy, atol=1.0e-4)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated")
+ ]
+)
+@pytest.mark.parametrize("smiles", ["CC", "CCC", "CCCC", "CCCCC"])
+def test_compute_multipole_energy_isolated_non_periodic(test_data_dir, forcefield_name, polarization_type, smiles):
+
+ print(f"\n{forcefield_name} - {polarization_type} - {smiles}")
+
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ [smiles],
+ [1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords, _ = smee.mm.generate_system_coords(tensor_sys, None)
+ coords = torch.tensor(coords.value_in_unit(openmm.unit.angstrom))
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+ es_potential.parameters.requires_grad = True
+
+ energy = compute_multipole_energy(tensor_sys, es_potential, coords.float(), None,
+ polarization_type=polarization_type)
+ energy.backward()
+ expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential,
+ polarization_type=polarization_type, return_omm_forces=True)
+ print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces)
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated")
+ ]
+)
+def test_compute_multipole_energy_CC_O_periodic(test_data_dir, forcefield_name, polarization_type):
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["CC", "O"],
+ [10, 15],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = True
+
+ # Use lower density to ensure box size >= 2*cutoff (18.0 Å)
+ config = smee.mm.GenerateCoordsConfig(
+ target_density=0.4 * openmm.unit.gram / openmm.unit.milliliter,
+ scale_factor=1.3,
+ padding=3.0 * openmm.unit.angstrom,
+ )
+ coords, box_vectors = smee.mm.generate_system_coords(tensor_sys, None, config)
+ coords = torch.tensor(coords.value_in_unit(openmm.unit.angstrom))
+ box_vectors = torch.tensor(box_vectors.value_in_unit(openmm.unit.angstrom))
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+ es_potential.parameters.requires_grad = True
+
+ energy = compute_multipole_energy(tensor_sys, es_potential, coords.float(), box_vectors.float(),
+ polarization_type=polarization_type)
+ energy.backward()
+ expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, box_vectors, es_potential,
+ polarization_type=polarization_type, return_omm_forces=True)
+ print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces)
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+def print_debug_info_multipole(energy: torch.Tensor,
+ expected_energy: torch.Tensor,
+ tensor_sys: smee.TensorSystem,
+ es_potential: smee.TensorPotential,
+ omm_forces: list[openmm.Force]):
+ print(f"Energy\nSMEE {energy} OpenMM {expected_energy}")
+
+ print(f"SMEE Parameters {es_potential.parameters}")
+
+ for idx, topology in enumerate(tensor_sys.topologies):
+ print(f"SMEE Topology {idx}")
+ print(f"Assignment Matrix {topology.parameters[es_potential.type].assignment_matrix.to_dense()}")
+
+ thole_force = None
+ for force in omm_forces:
+ if isinstance(force, TholeDipoleForce):
+ thole_force = force
+ break
+
+ print(thole_force)
+
+
+# =============================================================================
+# Water Model Tests - Testing SMEE with water force field from water_fitting
+# =============================================================================
+
+def _get_water_model_forcefield():
+ """Load the water model force field."""
+ import pathlib
+ water_ff_path = pathlib.Path(__file__).parent.parent.parent.parent / "water_fitting" / "water_model.offxml"
+ if not water_ff_path.exists():
+ pytest.skip(f"Water model not found at {water_ff_path}")
+ return openff.toolkit.ForceField(str(water_ff_path), load_plugins=True)
+
+
+def _get_water_dimer_coords():
+ """Return coordinates for a water dimer in Angstroms.
+
+ First water at origin, second water displaced along z-axis.
+ Standard water geometry: O-H bond length ~0.9572 Å, H-O-H angle ~104.5°
+ """
+ # Water 1 (near origin)
+ o1 = [0.0, 0.0, 0.0]
+ h1a = [0.9572, 0.0, 0.0]
+ h1b = [-0.2399, 0.9270, 0.0]
+
+ # Water 2 (displaced by ~2.8 Å along z-axis - typical H-bond distance)
+ z_offset = 2.8
+ o2 = [0.0, 0.0, z_offset]
+ h2a = [0.9572, 0.0, z_offset]
+ h2b = [-0.2399, 0.9270, z_offset]
+
+ coords = torch.tensor([o1, h1a, h1b, o2, h2a, h2b], dtype=torch.float64)
+ return coords
+
+
+def _get_water_trimer_coords():
+ """Return coordinates for a water trimer in Angstroms.
+
+ Three water molecules in a triangular arrangement.
+ """
+ import math
+
+ # Water 1 at origin
+ o1 = [0.0, 0.0, 0.0]
+ h1a = [0.9572, 0.0, 0.0]
+ h1b = [-0.2399, 0.9270, 0.0]
+
+ # Water 2 displaced along x
+ x_offset = 3.0
+ o2 = [x_offset, 0.0, 0.0]
+ h2a = [x_offset + 0.9572, 0.0, 0.0]
+ h2b = [x_offset - 0.2399, 0.9270, 0.0]
+
+ # Water 3 displaced to form triangle
+ angle = math.pi / 3 # 60 degrees
+ r = 3.0
+ o3 = [r * math.cos(angle), r * math.sin(angle), 0.0]
+ h3a = [o3[0] + 0.9572, o3[1], 0.0]
+ h3b = [o3[0] - 0.2399, o3[1] + 0.9270, 0.0]
+
+ coords = torch.tensor([o1, h1a, h1b, o2, h2a, h2b, o3, h3a, h3b], dtype=torch.float64)
+ return coords
+
+
+@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"])
+def test_water_dimer_energy(polarization_type):
+ """Test water dimer energy matches OpenMM/TholeDipoleForce."""
+ force_field = _get_water_model_forcefield()
+
+ water1 = openff.toolkit.Molecule.from_smiles("O")
+ water2 = openff.toolkit.Molecule.from_smiles("O")
+ topology = openff.toolkit.Topology.from_molecules([water1, water2])
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+ coords = _get_water_dimer_coords()
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Compute SMEE energy
+ energy = compute_multipole_energy(
+ system, es_potential, coords.float(), None, polarization_type=polarization_type
+ )
+
+ # Compute OpenMM reference energy
+ expected_energy = _compute_openmm_energy(
+ system, coords, None, es_potential, polarization_type=polarization_type
+ )
+
+ print(f"\nWater dimer ({polarization_type}):")
+ print(f" SMEE Energy: {energy.item():.6f} kcal/mol")
+ print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol")
+ print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol")
+
+ assert torch.allclose(energy, expected_energy, atol=1e-4)
+
+
+@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"])
+def test_water_dimer_forces(polarization_type):
+ """Test water dimer forces via autograd match finite difference."""
+ force_field = _get_water_model_forcefield()
+
+ water1 = openff.toolkit.Molecule.from_smiles("O")
+ water2 = openff.toolkit.Molecule.from_smiles("O")
+ topology = openff.toolkit.Topology.from_molecules([water1, water2])
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+ coords = _get_water_dimer_coords().float()
+ coords.requires_grad = True
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Compute SMEE energy and forces
+ energy = compute_multipole_energy(
+ system, es_potential, coords, None, polarization_type=polarization_type
+ )
+ energy.backward()
+ smee_forces = -coords.grad.detach()
+
+ # Compute numerical forces via finite difference
+ # Use h=1e-2 which is optimal for float32 precision (smaller h causes precision loss)
+ h = 1e-2
+ numerical_forces = torch.zeros_like(coords)
+
+ for i in range(coords.shape[0]):
+ for j in range(3):
+ coords_plus = coords.detach().clone()
+ coords_plus[i, j] += h
+ e_plus = compute_multipole_energy(
+ system, es_potential, coords_plus, None, polarization_type=polarization_type
+ )
+
+ coords_minus = coords.detach().clone()
+ coords_minus[i, j] -= h
+ e_minus = compute_multipole_energy(
+ system, es_potential, coords_minus, None, polarization_type=polarization_type
+ )
+
+ numerical_forces[i, j] = -(e_plus - e_minus) / (2 * h)
+
+ print(f"\nWater dimer forces ({polarization_type}):")
+ print(f" Max force difference (autograd vs numerical): {(smee_forces - numerical_forces).abs().max():.2e}")
+
+ assert torch.allclose(smee_forces, numerical_forces, atol=1e-3)
+
+
+@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"])
+def test_water_trimer_energy(polarization_type):
+ """Test water trimer energy matches OpenMM/TholeDipoleForce."""
+ force_field = _get_water_model_forcefield()
+
+ waters = [openff.toolkit.Molecule.from_smiles("O") for _ in range(3)]
+ topology = openff.toolkit.Topology.from_molecules(waters)
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+ coords = _get_water_trimer_coords()
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Compute SMEE energy
+ energy = compute_multipole_energy(
+ system, es_potential, coords.float(), None, polarization_type=polarization_type
+ )
+
+ # Compute OpenMM reference energy
+ expected_energy = _compute_openmm_energy(
+ system, coords, None, es_potential, polarization_type=polarization_type
+ )
+
+ print(f"\nWater trimer ({polarization_type}):")
+ print(f" SMEE Energy: {energy.item():.6f} kcal/mol")
+ print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol")
+ print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol")
+
+ assert torch.allclose(energy, expected_energy, atol=1e-4)
+
+
+def test_single_water_zero_energy():
+ """Test that a single isolated water molecule has zero intermolecular energy."""
+ force_field = _get_water_model_forcefield()
+
+ water = openff.toolkit.Molecule.from_smiles("O")
+ topology = openff.toolkit.Topology.from_molecules([water])
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+
+ # Standard water geometry
+ coords = torch.tensor([
+ [0.0, 0.0, 0.0],
+ [0.9572, 0.0, 0.0],
+ [-0.2399, 0.9270, 0.0],
+ ], dtype=torch.float32)
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Compute SMEE energy
+ energy = compute_multipole_energy(
+ system, es_potential, coords, None, polarization_type="direct"
+ )
+
+ # Single molecule should have zero intermolecular energy
+ # (intramolecular terms are excluded via covalent maps)
+ print(f"\nSingle water molecule energy: {energy.item():.6e} kcal/mol")
+
+ assert torch.allclose(energy, torch.tensor(0.0, dtype=energy.dtype), atol=1e-6)
+
+
+@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"])
+def test_water_dimer_multiple_conformers(polarization_type):
+ """Test water dimer with multiple conformers (batched computation)."""
+ force_field = _get_water_model_forcefield()
+
+ water1 = openff.toolkit.Molecule.from_smiles("O")
+ water2 = openff.toolkit.Molecule.from_smiles("O")
+ topology = openff.toolkit.Topology.from_molecules([water1, water2])
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+
+ # Create multiple conformers with different separations
+ base_coords = _get_water_dimer_coords()
+ conformers = []
+ for z_offset in [2.5, 3.0, 3.5, 4.0]:
+ coords = base_coords.clone()
+ # Shift second water
+ coords[3:, 2] = z_offset
+ conformers.append(coords)
+
+ coords_batch = torch.stack(conformers).float()
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Compute SMEE energies for batch
+ energies = compute_multipole_energy(
+ system, es_potential, coords_batch, None, polarization_type=polarization_type
+ )
+
+ # Compute OpenMM reference energies individually
+ expected_energies = torch.tensor([
+ _compute_openmm_energy(system, coords, None, es_potential, polarization_type=polarization_type)
+ for coords in conformers
+ ])
+
+ print(f"\nWater dimer multiple conformers ({polarization_type}):")
+ for i, (e, exp_e) in enumerate(zip(energies, expected_energies)):
+ print(f" Conformer {i}: SMEE={e.item():.6f}, OpenMM={exp_e.item():.6f}, diff={abs(e.item()-exp_e.item()):.2e}")
+
+ assert torch.allclose(energies, expected_energies, atol=1e-4)
+
+
+@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"])
+def test_water_dimer_axis_resolution(polarization_type):
+ """Test that axis atoms are correctly resolved from SMIRKS indices to topology indices."""
+ force_field = _get_water_model_forcefield()
+
+ water1 = openff.toolkit.Molecule.from_smiles("O")
+ water2 = openff.toolkit.Molecule.from_smiles("O")
+ topology = openff.toolkit.Topology.from_molecules([water1, water2])
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+ param_map = tensor_top.parameters[es_potential.type]
+ assigned_params = param_map.assignment_matrix @ es_potential.parameters
+
+ print(f"\nAxis resolution test ({polarization_type}):")
+ print("Checking that axis atoms are actual topology indices (not SMIRKS indices):")
+
+ # For 6 atoms (2 waters), axis atoms should be 0-5 (topology indices)
+ # NOT 1-3 (SMIRKS indices)
+ for i in range(6):
+ axis_type = int(assigned_params[i, 13])
+ atom_z = int(assigned_params[i, 14])
+ atom_x = int(assigned_params[i, 15])
+ print(f" Atom {i}: axisType={axis_type}, atomZ={atom_z}, atomX={atom_x}")
+
+ # Verify axis atoms are valid topology indices
+ if atom_z >= 0:
+ assert atom_z < 6, f"atomZ={atom_z} exceeds topology size (6 atoms)"
+ if atom_x >= 0:
+ assert atom_x < 6, f"atomX={atom_x} exceeds topology size (6 atoms)"
+
+ # Verify we can compute energy without errors
+ coords = _get_water_dimer_coords()
+ energy = compute_multipole_energy(
+ system, es_potential, coords.float(), None, polarization_type=polarization_type
+ )
+ assert torch.isfinite(energy)
+
+
+@pytest.mark.parametrize("n_waters", [2, 4, 8])
+def test_water_cluster_energy(n_waters):
+ """Test water clusters of various sizes."""
+ force_field = _get_water_model_forcefield()
+
+ waters = [openff.toolkit.Molecule.from_smiles("O") for _ in range(n_waters)]
+ topology = openff.toolkit.Topology.from_molecules(waters)
+
+ interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology)
+ tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
+
+ system = smee.TensorSystem([tensor_top], [1], is_periodic=False)
+
+ # Generate coordinates for water cluster
+ import numpy as np
+ np.random.seed(42)
+
+ coords_list = []
+ spacing = 3.0 # Å between water molecules
+
+ for i in range(n_waters):
+ # Place water molecules in a grid
+ x = (i % 2) * spacing
+ y = ((i // 2) % 2) * spacing
+ z = (i // 4) * spacing
+
+ # Standard water geometry
+ coords_list.append([x, y, z]) # O
+ coords_list.append([x + 0.9572, y, z]) # H
+ coords_list.append([x - 0.2399, y + 0.9270, z]) # H
+
+ coords = torch.tensor(coords_list, dtype=torch.float64)
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Compute SMEE energy
+ energy = compute_multipole_energy(
+ system, es_potential, coords.float(), None, polarization_type="mutual"
+ )
+
+ # Compute OpenMM reference energy
+ expected_energy = _compute_openmm_energy(
+ system, coords, None, es_potential, polarization_type="mutual"
+ )
+
+ print(f"\nWater cluster (n={n_waters}):")
+ print(f" SMEE Energy: {energy.item():.6f} kcal/mol")
+ print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol")
+ print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol")
+
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+# =============================================================================
+# Ammonia Tests - Testing ThreeFold axis type and different molecular geometry
+# =============================================================================
+
+def _get_ammonia_dimer_coords():
+ """Return coordinates for an ammonia dimer in Angstroms.
+
+ Ammonia geometry: N-H bond length ~1.012 Å, H-N-H angle ~106.7°
+ Tetrahedral-like geometry with lone pair.
+ """
+ import math
+
+ # Ammonia 1 at origin
+ # N at origin, 3 H atoms in tetrahedral-like arrangement
+ n1 = [0.0, 0.0, 0.0]
+ # H atoms roughly 1.012 Å from N, ~106.7° H-N-H angle
+ h1a = [0.9377, 0.0, -0.3816]
+ h1b = [-0.4689, 0.8121, -0.3816]
+ h1c = [-0.4689, -0.8121, -0.3816]
+
+ # Ammonia 2 displaced along z-axis
+ z_offset = 3.5
+ n2 = [0.0, 0.0, z_offset]
+ h2a = [0.9377, 0.0, z_offset - 0.3816]
+ h2b = [-0.4689, 0.8121, z_offset - 0.3816]
+ h2c = [-0.4689, -0.8121, z_offset - 0.3816]
+
+ coords = torch.tensor([n1, h1a, h1b, h1c, n2, h2a, h2b, h2c], dtype=torch.float64)
+ return coords
+
+
+def _get_ammonia_trimer_coords():
+ """Return coordinates for an ammonia trimer in Angstroms."""
+ import math
+
+ # Ammonia 1 at origin
+ n1 = [0.0, 0.0, 0.0]
+ h1a = [0.9377, 0.0, -0.3816]
+ h1b = [-0.4689, 0.8121, -0.3816]
+ h1c = [-0.4689, -0.8121, -0.3816]
+
+ # Ammonia 2 displaced along x
+ x_offset = 3.5
+ n2 = [x_offset, 0.0, 0.0]
+ h2a = [x_offset + 0.9377, 0.0, -0.3816]
+ h2b = [x_offset - 0.4689, 0.8121, -0.3816]
+ h2c = [x_offset - 0.4689, -0.8121, -0.3816]
+
+ # Ammonia 3 displaced to form triangle
+ angle = math.pi / 3
+ r = 3.5
+ n3 = [r * math.cos(angle), r * math.sin(angle), 0.0]
+ h3a = [n3[0] + 0.9377, n3[1], -0.3816]
+ h3b = [n3[0] - 0.4689, n3[1] + 0.8121, -0.3816]
+ h3c = [n3[0] - 0.4689, n3[1] - 0.8121, -0.3816]
+
+ coords = torch.tensor([
+ n1, h1a, h1b, h1c,
+ n2, h2a, h2b, h2c,
+ n3, h3a, h3b, h3c
+ ], dtype=torch.float64)
+ return coords
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"),
+ ]
+)
+def test_ammonia_dimer_energy(test_data_dir, forcefield_name, polarization_type):
+ """Test ammonia dimer energy matches OpenMM/TholeDipoleForce.
+
+ Ammonia uses ThreeFold axis type for N and ZOnly for H atoms.
+ """
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["N", "N"],
+ [1, 1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = _get_ammonia_dimer_coords()
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords.float(), None, polarization_type=polarization_type
+ )
+
+ expected_energy = _compute_openmm_energy(
+ tensor_sys, coords, None, es_potential, polarization_type=polarization_type
+ )
+
+ print(f"\nAmmonia dimer ({forcefield_name}, {polarization_type}):")
+ print(f" SMEE Energy: {energy.item():.6f} kcal/mol")
+ print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol")
+ print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol")
+
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"),
+ ]
+)
+def test_ammonia_dimer_forces(test_data_dir, forcefield_name, polarization_type):
+ """Test ammonia dimer forces via autograd match finite difference."""
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["N", "N"],
+ [1, 1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = _get_ammonia_dimer_coords().float()
+ coords.requires_grad = True
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords, None, polarization_type=polarization_type
+ )
+ energy.backward()
+ smee_forces = -coords.grad.detach()
+
+ # Use h=1e-2 for optimal float32 finite difference precision
+ h = 1e-2
+ numerical_forces = torch.zeros_like(coords)
+
+ for i in range(coords.shape[0]):
+ for j in range(3):
+ coords_plus = coords.detach().clone()
+ coords_plus[i, j] += h
+ e_plus = compute_multipole_energy(
+ tensor_sys, es_potential, coords_plus, None, polarization_type=polarization_type
+ )
+
+ coords_minus = coords.detach().clone()
+ coords_minus[i, j] -= h
+ e_minus = compute_multipole_energy(
+ tensor_sys, es_potential, coords_minus, None, polarization_type=polarization_type
+ )
+
+ numerical_forces[i, j] = -(e_plus - e_minus) / (2 * h)
+
+ print(f"\nAmmonia dimer forces ({forcefield_name}, {polarization_type}):")
+ print(f" Max force difference (autograd vs numerical): {(smee_forces - numerical_forces).abs().max():.2e}")
+
+ assert torch.allclose(smee_forces, numerical_forces, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"),
+ ]
+)
+def test_ammonia_trimer_energy(test_data_dir, forcefield_name, polarization_type):
+ """Test ammonia trimer energy matches OpenMM/TholeDipoleForce."""
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["N"],
+ [3],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ coords = _get_ammonia_trimer_coords()
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords.float(), None, polarization_type=polarization_type
+ )
+
+ expected_energy = _compute_openmm_energy(
+ tensor_sys, coords, None, es_potential, polarization_type=polarization_type
+ )
+
+ print(f"\nAmmonia trimer ({forcefield_name}, {polarization_type}):")
+ print(f" SMEE Energy: {energy.item():.6f} kcal/mol")
+ print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol")
+ print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol")
+
+ assert torch.allclose(energy, expected_energy, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name",
+ ["PHAST-H2CNO-nonpolar-2.0.0.offxml", "PHAST-H2CNO-2.0.0.offxml"]
+)
+def test_single_ammonia_zero_energy(test_data_dir, forcefield_name):
+ """Test that a single isolated ammonia molecule has zero intermolecular energy."""
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["N"],
+ [1],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ # Standard ammonia geometry
+ coords = torch.tensor([
+ [0.0, 0.0, 0.0], # N
+ [0.9377, 0.0, -0.3816], # H
+ [-0.4689, 0.8121, -0.3816], # H
+ [-0.4689, -0.8121, -0.3816], # H
+ ], dtype=torch.float32)
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords, None, polarization_type="direct"
+ )
+
+ print(f"\nSingle ammonia molecule energy ({forcefield_name}): {energy.item():.6e} kcal/mol")
+
+ assert torch.allclose(energy, torch.tensor(0.0, dtype=energy.dtype), atol=1e-6)
+
+
+@pytest.mark.parametrize(
+ "forcefield_name,polarization_type",
+ [
+ ("PHAST-H2CNO-2.0.0.offxml", "direct"),
+ ("PHAST-H2CNO-2.0.0.offxml", "mutual"),
+ ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"),
+ ]
+)
+def test_ammonia_axis_types(test_data_dir, forcefield_name, polarization_type):
+ """Test that ammonia axis types are correctly assigned."""
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["N"],
+ [2],
+ openff.toolkit.ForceField(
+ str(test_data_dir / forcefield_name), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ # Get assigned parameters for the single topology (4 atoms)
+ topology = tensor_sys.topologies[0]
+ param_map = topology.parameters[es_potential.type]
+ assigned_params = param_map.assignment_matrix @ es_potential.parameters
+
+ n_atoms_per_mol = topology.n_particles
+ print(f"\nAmmonia axis types ({forcefield_name}, {polarization_type}):")
+ print(f" Atoms per molecule: {n_atoms_per_mol}")
+
+ # Check axis types for each atom in the topology template
+ for i in range(n_atoms_per_mol):
+ axis_type = int(assigned_params[i, 13])
+ atom_z = int(assigned_params[i, 14])
+ atom_x = int(assigned_params[i, 15])
+ atom_y = int(assigned_params[i, 16])
+
+ atom_type = "N" if i == 0 else "H"
+ print(f" Atom {i} ({atom_type}): axisType={axis_type}, atomZ={atom_z}, atomX={atom_x}, atomY={atom_y}")
+
+ # Verify axis atoms are valid topology indices (or -1)
+ if atom_z >= 0:
+ assert atom_z < n_atoms_per_mol, f"atomZ={atom_z} exceeds topology size"
+ if atom_x >= 0:
+ assert atom_x < n_atoms_per_mol, f"atomX={atom_x} exceeds topology size"
+ if atom_y >= 0:
+ assert atom_y < n_atoms_per_mol, f"atomY={atom_y} exceeds topology size"
+
+ # Verify we can compute energy without errors
+ coords = _get_ammonia_dimer_coords()
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords.float(), None, polarization_type=polarization_type
+ )
+ assert torch.isfinite(energy)
+
+
+@pytest.mark.parametrize("n_ammonia", [2, 4, 6])
+def test_ammonia_cluster_energy(test_data_dir, n_ammonia):
+ """Test ammonia clusters of various sizes."""
+ tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(
+ ["N"],
+ [n_ammonia],
+ openff.toolkit.ForceField(
+ str(test_data_dir / "PHAST-H2CNO-2.0.0.offxml"), load_plugins=True
+ ),
+ )
+ tensor_sys.is_periodic = False
+
+ # Generate coordinates for ammonia cluster
+ coords_list = []
+ spacing = 4.0 # Å between molecules
+
+ for i in range(n_ammonia):
+ # Place ammonia molecules in a grid
+ x = (i % 2) * spacing
+ y = ((i // 2) % 2) * spacing
+ z = (i // 4) * spacing
+
+ # Standard ammonia geometry
+ coords_list.append([x, y, z]) # N
+ coords_list.append([x + 0.9377, y, z - 0.3816]) # H
+ coords_list.append([x - 0.4689, y + 0.8121, z - 0.3816]) # H
+ coords_list.append([x - 0.4689, y - 0.8121, z - 0.3816]) # H
+
+ coords = torch.tensor(coords_list, dtype=torch.float64)
+
+ es_potential = tensor_ff.potentials_by_type["Electrostatics"]
+
+ energy = compute_multipole_energy(
+ tensor_sys, es_potential, coords.float(), None, polarization_type="mutual"
+ )
+
+ expected_energy = _compute_openmm_energy(
+ tensor_sys, coords, None, es_potential, polarization_type="mutual"
+ )
+
+ print(f"\nAmmonia cluster (n={n_ammonia}):")
+ print(f" SMEE Energy: {energy.item():.6f} kcal/mol")
+ print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol")
+ print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol")
+
+ assert torch.allclose(energy, expected_energy, atol=1e-3)