diff --git a/.gitignore b/.gitignore index f2e103a..04a99aa 100644 --- a/.gitignore +++ b/.gitignore @@ -113,3 +113,4 @@ ENV/ # Local development scratch +CLAUDE.md diff --git a/smee/_constants.py b/smee/_constants.py index d04e395..ca96e79 100644 --- a/smee/_constants.py +++ b/smee/_constants.py @@ -40,6 +40,7 @@ class EnergyFn(_StrEnum): """An enumeration of the energy functions supported by ``smee`` out of the box.""" COULOMB = "coul" + POLARIZATION = "coul+pol" VDW_LJ = "4*epsilon*((sigma/r)**12-(sigma/r)**6)" VDW_DEXP = ( @@ -48,6 +49,7 @@ class EnergyFn(_StrEnum): "alpha/(alpha-beta)*exp(beta*(1-r/r_min)))" ) # VDW_BUCKINGHAM = "a*exp(-b*r)-c*r^-6" + VDW_DAMPEDEXP6810 = "force_at_zero*beta**-1*exp(-beta*(r-rho))-f_6(beta*r)*c6**6-f_8(beta*r)*c8**8-f_10(beta*r)*c10**10" BOND_HARMONIC = "k/2*(r-length)**2" diff --git a/smee/converters/openff/_openff.py b/smee/converters/openff/_openff.py index 73ea3b4..1297f38 100644 --- a/smee/converters/openff/_openff.py +++ b/smee/converters/openff/_openff.py @@ -71,6 +71,26 @@ def _get_value( ) -> float: """Returns the value of a parameter in its default units""" default_units = default_units[parameter] + + # Handle missing parameters by using default values + if parameter not in potential.parameters: + if default_value is None: + # Set specific defaults for known multipole parameters + if parameter == "thole": + default_value = 0.39 * openff.units.unit.dimensionless + elif parameter == "dampingFactor": + # Will be computed from polarizability if not provided + return 0.0 # Placeholder, will be computed later + elif parameter.startswith("dipole"): + default_value = 0.0 * openff.units.unit.elementary_charge * openff.units.unit.angstrom + elif parameter.startswith("quadrupole"): + default_value = 0.0 * openff.units.unit.elementary_charge * openff.units.unit.angstrom**2 + elif parameter in ["axisType", "multipoleAtomZ", "multipoleAtomX", "multipoleAtomY"]: + default_value = -1 * openff.units.unit.dimensionless # -1 indicates undefined + else: + raise KeyError(f"Parameter '{parameter}' not found and no default provided") + return default_value.m_as(default_units) + value = potential.parameters[parameter] return (value if value is not None else default_value).m_as(default_units) diff --git a/smee/converters/openff/nonbonded.py b/smee/converters/openff/nonbonded.py index c695a46..267f3fa 100644 --- a/smee/converters/openff/nonbonded.py +++ b/smee/converters/openff/nonbonded.py @@ -217,6 +217,282 @@ def convert_dexp( return potential, parameter_maps +@smee.converters.smirnoff_parameter_converter( + "DampedExp6810", + { + "rho": _ANGSTROM, + "beta": _ANGSTROM**-1, + "c6": _KCAL_PER_MOL * _ANGSTROM**6, + "c8": _KCAL_PER_MOL * _ANGSTROM**8, + "c10": _KCAL_PER_MOL * _ANGSTROM**10, + "force_at_zero": _KCAL_PER_MOL * _ANGSTROM**-1, + "scale_12": _UNITLESS, + "scale_13": _UNITLESS, + "scale_14": _UNITLESS, + "scale_15": _UNITLESS, + "cutoff": _ANGSTROM, + "switch_width": _ANGSTROM, + }, +) +def convert_dampedexp6810( + handlers: list[ + "smirnoff_plugins.collections.nonbonded.SMIRNOFFDampedExp6810Collection" + ], + topologies: list[openff.toolkit.Topology], + v_site_maps: list[smee.VSiteMap | None], +) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]: + import smee.potentials.nonbonded + + ( + potential, + parameter_maps, + ) = smee.converters.openff.nonbonded.convert_nonbonded_handlers( + handlers, + "DampedExp6810", + topologies, + v_site_maps, + ("rho", "beta", "c6", "c8", "c10"), + ("cutoff", "switch_width", "force_at_zero"), + ) + potential.type = smee.PotentialType.VDW + potential.fn = smee.EnergyFn.VDW_DAMPEDEXP6810 + + return potential, parameter_maps + + +@smee.converters.smirnoff_parameter_converter( + "Multipole", + { + # Molecular multipole moments + "dipoleX": _ELEMENTARY_CHARGE * _ANGSTROM, + "dipoleY": _ELEMENTARY_CHARGE * _ANGSTROM, + "dipoleZ": _ELEMENTARY_CHARGE * _ANGSTROM, + "quadrupoleXX": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleXY": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleXZ": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleYX": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleYY": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleYZ": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleZX": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleZY": _ELEMENTARY_CHARGE * _ANGSTROM**2, + "quadrupoleZZ": _ELEMENTARY_CHARGE * _ANGSTROM**2, + # Local frame definition + "axisType": _UNITLESS, + "multipoleAtomZ": _UNITLESS, + "multipoleAtomX": _UNITLESS, + "multipoleAtomY": _UNITLESS, + # Damping and polarizability (these may not be present in current force fields) + "thole": _UNITLESS, + "dampingFactor": _ANGSTROM, + "polarity": _ANGSTROM**3, + # Cutoff and scaling + "cutoff": _ANGSTROM, + "scale_12": _UNITLESS, + "scale_13": _UNITLESS, + "scale_14": _UNITLESS, + "scale_15": _UNITLESS, + }, + depends_on=["Electrostatics"], +) +def convert_multipole( + handlers: list[ + "smirnoff_plugins.collections.nonbonded.SMIRNOFFMultipoleCollection" + ], + topologies: list[openff.toolkit.Topology], + v_site_maps: list[smee.VSiteMap | None], + dependencies: dict[ + str, tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]] + ], +) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]: + + potential_chg, parameter_maps_chg = dependencies["Electrostatics"] + + ( + potential_pol, + parameter_maps_pol, + ) = smee.converters.openff.nonbonded.convert_nonbonded_handlers( + handlers, + "Multipole", + topologies, + v_site_maps, + ( + "dipoleX", "dipoleY", "dipoleZ", + "quadrupoleXX", "quadrupoleXY", "quadrupoleXZ", + "quadrupoleYX", "quadrupoleYY", "quadrupoleYZ", + "quadrupoleZX", "quadrupoleZY", "quadrupoleZZ", + "axisType", "multipoleAtomZ", "multipoleAtomX", "multipoleAtomY", + "thole", "dampingFactor", "polarity" + ), + ("cutoff",), + has_exclusions=False, + ) + + cutoff_idx_pol = potential_pol.attribute_cols.index("cutoff") + cutoff_idx_chg = potential_chg.attribute_cols.index("cutoff") + + assert torch.isclose( + potential_pol.attributes[cutoff_idx_pol], + potential_chg.attributes[cutoff_idx_chg], + ) + + potential_chg.fn = smee.EnergyFn.POLARIZATION + + potential_chg.parameter_cols = ( + *potential_chg.parameter_cols, + *potential_pol.parameter_cols, + ) + potential_chg.parameter_units = ( + *potential_chg.parameter_units, + *potential_pol.parameter_units, + ) + potential_chg.parameter_keys = [ + *potential_chg.parameter_keys, + *potential_pol.parameter_keys, + ] + + # Handle different numbers of columns between charge and polarizability potentials + n_chg_cols = potential_chg.parameters.shape[1] + n_pol_cols = potential_pol.parameters.shape[1] + + # Pad charge parameters with zeros for the new polarizability columns + parameters_chg = torch.cat( + (potential_chg.parameters, torch.zeros(potential_chg.parameters.shape[0], n_pol_cols, dtype=potential_chg.parameters.dtype)), dim=1 + ) + + # Resolve multipole axis atoms from SMIRKS indices to actual topology indices. + # The OFFXML stores multipoleAtomZ/X/Y as 1-based SMIRKS atom map numbers. + # We need to resolve these to actual 0-based topology indices for each atom. + # Columns in potential_pol.parameters: 13=multipoleAtomZ, 14=multipoleAtomX, 15=multipoleAtomY + + parameter_key_to_idx = { + key: i for i, key in enumerate(potential_pol.parameter_keys) + } + + all_resolved_pol_params = [] + resolved_parameter_maps_pol = [] + + for handler, topology, v_site_map, param_map_pol in zip( + handlers, topologies, v_site_maps, parameter_maps_pol, strict=True + ): + n_atoms = topology.n_atoms + + # Create per-atom resolved parameters for multipoles + # Each atom gets its own parameter row with resolved axis indices + resolved_params = [] + + for atom_idx in range(n_atoms): + # Find the topology_key for this atom + matched_key = None + matched_param_idx = None + + for topology_key, parameter_key in handler.key_map.items(): + if isinstance(topology_key, openff.interchange.models.VirtualSiteKey): + continue + if topology_key.atom_indices[0] == atom_idx: + matched_key = topology_key + matched_param_idx = parameter_key_to_idx[parameter_key] + break + + if matched_key is None: + # No multipole parameters for this atom, create zero row + resolved_params.append(torch.zeros(n_pol_cols, dtype=torch.float64)) + continue + + # Get the base parameters for this atom + base_params = potential_pol.parameters[matched_param_idx].clone() + + # Resolve SMIRKS indices to actual topology indices + # SMIRKS indices are 1-based, so multipoleAtomZ=2 means atom_indices[1] + smirks_atom_z = int(base_params[13]) # multipoleAtomZ + smirks_atom_x = int(base_params[14]) # multipoleAtomX + smirks_atom_y = int(base_params[15]) # multipoleAtomY + + # Resolve using atom_indices from the match + # SMIRKS :N corresponds to atom_indices[N-1] + atom_indices = matched_key.atom_indices + + if smirks_atom_z > 0 and smirks_atom_z <= len(atom_indices): + base_params[13] = atom_indices[smirks_atom_z - 1] + else: + base_params[13] = -1 + + if smirks_atom_x > 0 and smirks_atom_x <= len(atom_indices): + base_params[14] = atom_indices[smirks_atom_x - 1] + else: + base_params[14] = -1 + + if smirks_atom_y > 0 and smirks_atom_y <= len(atom_indices): + base_params[15] = atom_indices[smirks_atom_y - 1] + else: + base_params[15] = -1 + + # Compute damping factor from polarity + base_params[17] = base_params[18] ** (1/6) + + resolved_params.append(base_params) + + # Stack into a tensor for this topology + resolved_params_tensor = torch.stack(resolved_params) + all_resolved_pol_params.append(resolved_params_tensor) + + # Create identity-like assignment matrix (each atom maps to its own row) + assignment_matrix = torch.eye(n_atoms, dtype=torch.float64).to_sparse() + resolved_parameter_maps_pol.append( + smee.NonbondedParameterMap( + assignment_matrix=assignment_matrix, + exclusions=param_map_pol.exclusions, + exclusion_scale_idxs=param_map_pol.exclusion_scale_idxs, + ) + ) + + # Combine all resolved parameters across topologies into a single tensor + # Each topology's parameters are separate, so we need to track offsets + total_pol_params = sum(p.shape[0] for p in all_resolved_pol_params) + + # Create combined parameters tensor + if all_resolved_pol_params: + combined_pol_params = torch.cat(all_resolved_pol_params, dim=0) + else: + combined_pol_params = torch.zeros((0, n_pol_cols), dtype=torch.float64) + + # Pad with charge columns + parameters_pol = torch.cat( + (torch.zeros(combined_pol_params.shape[0], n_chg_cols, dtype=torch.float64), combined_pol_params), dim=1 + ) + + potential_chg.parameters = torch.cat((parameters_chg, parameters_pol), dim=0) + + # Update assignment matrices with proper offsets + param_offset = 0 + for i, (parameter_map_chg, resolved_map_pol) in enumerate(zip( + parameter_maps_chg, resolved_parameter_maps_pol, strict=True + )): + n_chg_params = parameter_map_chg.assignment_matrix.shape[1] + n_atoms = resolved_map_pol.assignment_matrix.shape[0] + + # The resolved_map_pol assignment matrix is identity, but we need to offset + # the column indices by (n_charge_params + param_offset) + pol_assignment_dense = resolved_map_pol.assignment_matrix.to_dense() + + # Create the full assignment matrix + n_total_params = parameters_chg.shape[0] + combined_pol_params.shape[0] + full_assignment = torch.zeros((n_atoms, n_total_params), dtype=torch.float64) + + # Copy charge assignments + chg_assignment = parameter_map_chg.assignment_matrix.to_dense() + full_assignment[:, :n_chg_params] = chg_assignment + + # Add multipole assignments with proper offset + pol_start_col = parameters_chg.shape[0] + param_offset + for atom_idx in range(n_atoms): + full_assignment[atom_idx, pol_start_col + atom_idx] = 1.0 + + parameter_map_chg.assignment_matrix = full_assignment.to_sparse() + param_offset += n_atoms + + return potential_chg, parameter_maps_chg + + def _make_v_site_electrostatics_compatible( handlers: list[openff.interchange.smirnoff.SMIRNOFFElectrostaticsCollection], ): diff --git a/smee/converters/openmm/nonbonded.py b/smee/converters/openmm/nonbonded.py index 2188d87..56d2d27 100644 --- a/smee/converters/openmm/nonbonded.py +++ b/smee/converters/openmm/nonbonded.py @@ -7,6 +7,8 @@ import openmm import torch +from tholedipoleplugin import TholeDipoleForce + import smee import smee.converters.openmm import smee.potentials.nonbonded @@ -22,9 +24,9 @@ def _create_nonbonded_force( - potential: smee.TensorPotential, - system: smee.TensorSystem, - cls: typing.Type[_T] = openmm.NonbondedForce, + potential: smee.TensorPotential, + system: smee.TensorSystem, + cls: typing.Type[_T] = openmm.NonbondedForce, ) -> _T: """Create a non-bonded force for a given potential and system, making sure to set the appropriate method and cutoffs.""" @@ -71,10 +73,10 @@ def _create_nonbonded_force( def _eval_mixing_fn( - potential: smee.TensorPotential, - mixing_fn: dict[str, str], - param_1: torch.Tensor, - param_2: torch.Tensor, + potential: smee.TensorPotential, + mixing_fn: dict[str, str], + param_1: torch.Tensor, + param_2: torch.Tensor, ) -> dict[str, float]: import symengine @@ -96,8 +98,8 @@ def _eval_mixing_fn( def _build_vdw_lookup( - potential: smee.TensorPotential, - mixing_fn: dict[str, str], + potential: smee.TensorPotential, + mixing_fn: dict[str, str], ) -> dict[str, list[float]]: """Build the ``n_param x n_param`` vdW parameter lookup table containing parameters for all interactions. @@ -155,7 +157,7 @@ def _prepend_scale_to_energy_fn(fn: str, scale_var: str = _INTRA_SCALE_VAR) -> s def _detect_parameters( - potential: smee.TensorPotential, energy_fn: str, mixing_fn: dict[str, str] + potential: smee.TensorPotential, energy_fn: str, mixing_fn: dict[str, str] ) -> tuple[list[str], list[str]]: """Detect the required parameters and attributes for a given energy function and associated mixing rules.""" @@ -209,7 +211,7 @@ def _detect_parameters( def _extract_parameters( - potential: smee.TensorPotential, parameter: torch.Tensor, cols: list[str] + potential: smee.TensorPotential, parameter: torch.Tensor, cols: list[str] ) -> list[float]: """Extract the values of a subset of parameters from a parameter tensor.""" @@ -230,13 +232,13 @@ def _extract_parameters( def _add_parameters_to_vdw_without_lookup( - potential: smee.TensorPotential, - system: smee.TensorSystem, - energy_fn: str, - mixing_fn: dict[str, str], - inter_force: openmm.CustomNonbondedForce, - intra_force: openmm.CustomBondForce, - used_parameters: list[str], + potential: smee.TensorPotential, + system: smee.TensorSystem, + energy_fn: str, + mixing_fn: dict[str, str], + inter_force: openmm.CustomNonbondedForce, + intra_force: openmm.CustomBondForce, + used_parameters: list[str], ): """Add parameters to a vdW force directly, i.e. without using a lookup table.""" @@ -292,12 +294,12 @@ def _add_parameters_to_vdw_without_lookup( def _add_parameters_to_vdw_with_lookup( - potential: smee.TensorPotential, - system: smee.TensorSystem, - energy_fn: str, - mixing_fn: dict[str, str], - inter_force: openmm.CustomNonbondedForce, - intra_force: openmm.CustomBondForce, + potential: smee.TensorPotential, + system: smee.TensorSystem, + energy_fn: str, + mixing_fn: dict[str, str], + inter_force: openmm.CustomNonbondedForce, + intra_force: openmm.CustomBondForce, ): """Add parameters to a vdW force, explicitly defining all pairwise parameters using a lookup table.""" @@ -356,10 +358,10 @@ def _add_parameters_to_vdw_with_lookup( def convert_custom_vdw_potential( - potential: smee.TensorPotential, - system: smee.TensorSystem, - energy_fn: str, - mixing_fn: dict[str, str], + potential: smee.TensorPotential, + system: smee.TensorSystem, + energy_fn: str, + mixing_fn: dict[str, str], ) -> tuple[openmm.CustomNonbondedForce, openmm.CustomBondForce]: """Converts an arbitrary vdW potential to OpenMM forces. @@ -444,7 +446,7 @@ def convert_custom_vdw_potential( smee.PotentialType.VDW, smee.EnergyFn.VDW_LJ ) def convert_lj_potential( - potential: smee.TensorPotential, system: smee.TensorSystem + potential: smee.TensorPotential, system: smee.TensorSystem ) -> openmm.NonbondedForce | list[openmm.CustomNonbondedForce | openmm.CustomBondForce]: """Convert a Lennard-Jones potential to an OpenMM force. @@ -501,7 +503,7 @@ def convert_lj_potential( smee.PotentialType.VDW, smee.EnergyFn.VDW_DEXP ) def convert_dexp_potential( - potential: smee.TensorPotential, system: smee.TensorSystem + potential: smee.TensorPotential, system: smee.TensorSystem ) -> tuple[openmm.CustomNonbondedForce, openmm.CustomBondForce]: """Convert a DEXP potential to OpenMM forces. @@ -526,11 +528,212 @@ def convert_dexp_potential( return convert_custom_vdw_potential(potential, system, energy_fn, mixing_fn) +@smee.converters.openmm.potential_converter( + smee.PotentialType.VDW, smee.EnergyFn.VDW_DAMPEDEXP6810 +) +def convert_dampedexp6810_potential( + potential: smee.TensorPotential, system: smee.TensorSystem +) -> tuple[openmm.CustomNonbondedForce, openmm.CustomBondForce]: + """Convert a DampedExp6810 potential to OpenMM forces. + + The intermolcular interactions are described by a custom nonbonded force, while the + intramolecular interactions are described by a custom bond force. + + If the potential has custom mixing rules (i.e. exceptions), a lookup table will be + used to store the parameters. Otherwise, the mixing rules will be applied directly + in the energy function. + """ + energy_fn = ( + "repulsion - ttdamp6*c6*invR^6 - ttdamp8*c8*invR^8 - ttdamp10*c10*invR^10;" + "repulsion = force_at_zero*invbeta*exp(-beta*(r-rho));" + "ttdamp10 = select(expbr, 1.0 - expbr * ttdamp10Sum, 1);" + "ttdamp8 = select(expbr, 1.0 - expbr * ttdamp8Sum, 1);" + "ttdamp6 = select(expbr, 1.0 - expbr * ttdamp6Sum, 1);" + "ttdamp10Sum = ttdamp8Sum + br^9/362880 + br^10/3628800;" + "ttdamp8Sum = ttdamp6Sum + br^7/5040 + br^8/40320;" + "ttdamp6Sum = 1.0 + br + br^2/2 + br^3/6 + br^4/24 + br^5/120 + br^6/720;" + "expbr = exp(-br);" + "br = beta*r;" + "invR = 1.0/r;" + "invbeta = 1.0/beta;" + ) + mixing_fn = { + "beta": "2.0 * beta1 * beta2 / (beta1 + beta2)", + "rho": "0.5 * (rho1 + rho2)", + "c6": "sqrt(c61*c62)", + "c8": "sqrt(c81*c82)", + "c10": "sqrt(c101*c102)", + } + + return convert_custom_vdw_potential(potential, system, energy_fn, mixing_fn) + + +@smee.converters.openmm.potential_converter( + smee.PotentialType.ELECTROSTATICS, smee.EnergyFn.POLARIZATION +) +def convert_multipole_potential( + potential: smee.TensorPotential, system: smee.TensorSystem +) -> TholeDipoleForce: + """Convert a Multipole potential to OpenMM TholeDipoleForce. + + TholeDipole parameter layout (9 columns): + Column 0: charge (e) + Columns 1-3: molecularDipole (e·Å, x, y, z) + Column 4: axisType (int, 0-5) + Column 5: multipoleAtomZ (int) + Column 6: multipoleAtomX (int) + Column 7: multipoleAtomY (int) + Column 8: polarity (ų) + """ + cutoff_idx = potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE) + cutoff = float(potential.attributes[cutoff_idx]) * 0.1 # Å to nm + + force = TholeDipoleForce() + + if system.is_periodic: + force.setNonbondedMethod(TholeDipoleForce.PME) + else: + force.setNonbondedMethod(TholeDipoleForce.NoCutoff) + + force.setPolarizationType(TholeDipoleForce.Mutual) + force.setCutoffDistance(cutoff) + force.setEwaldErrorTolerance(0.0001) + force.setMutualInducedTargetEpsilon(0.00001) + force.setMutualInducedMaxIterations(60) + force.setExtrapolationCoefficients([-0.154, 0.017, 0.658, 0.474]) + force.setTholeDampingType(TholeDipoleForce.Amoeba) + force.setTholeDampingParameter(0.39) + + # Map AMOEBA axis types to TholeDipole axis types + # AMOEBA: NoAxisType=0, ZOnly=1, ZThenX=2, Bisector=3, ZBisect=4, ThreeFold=5 + # TholeDipole: ZThenX=0, Bisector=1, ZBisect=2, ThreeFold=3, ZOnly=4, NoAxisType=5 + amoeba_to_thole_axis = { + 0: TholeDipoleForce.NoAxisType, # AMOEBA NoAxisType -> TholeDipole NoAxisType + 1: TholeDipoleForce.ZOnly, # AMOEBA ZOnly -> TholeDipole ZOnly + 2: TholeDipoleForce.ZThenX, # AMOEBA ZThenX -> TholeDipole ZThenX + 3: TholeDipoleForce.Bisector, # AMOEBA Bisector -> TholeDipole Bisector + 4: TholeDipoleForce.ZBisect, # AMOEBA ZBisect -> TholeDipole ZBisect + 5: TholeDipoleForce.ThreeFold, # AMOEBA ThreeFold -> TholeDipole ThreeFold + } + + idx_offset = 0 + + for topology, n_copies in zip(system.topologies, system.n_copies): + parameter_map = topology.parameters[potential.type] + parameters = parameter_map.assignment_matrix @ potential.parameters + parameters = parameters.detach() + + n_particles = topology.n_particles + n_params = parameters.shape[1] + + for _ in range(n_copies): + for atom_idx in range(n_particles): + # Get charge from first n_particles rows + charge = float(parameters[atom_idx, 0]) + + # Get dipole, axisType, frame atoms, polarity from rows n_particles to 2*n_particles + if parameters.shape[0] > n_particles: + pol_row = parameters[n_particles + atom_idx] + + if n_params == 20: + # AMOEBA-style 20-column layout (current PHAST force field): + # Col 0: charge, 1-3: dipole, 4-12: quadrupole (ignored) + # Col 13: axisType, 14-16: atomZ/X/Y, 17: thole, 18: dampingFactor, 19: polarity + dipole = [float(pol_row[1]) * 0.1, float(pol_row[2]) * 0.1, float(pol_row[3]) * 0.1] # e·Å to e·nm + amoeba_axis_type = int(pol_row[13]) + atom_z = int(pol_row[14]) + atom_x = int(pol_row[15]) + atom_y = int(pol_row[16]) + polarity = float(pol_row[19]) * 0.001 # ų to nm³ + + # Map axis type, but force NoAxisType if no valid axis atoms + if atom_z < 0: + axis_type = TholeDipoleForce.NoAxisType + else: + axis_type = amoeba_to_thole_axis.get(amoeba_axis_type, TholeDipoleForce.NoAxisType) + elif n_params == 9: + # TholeDipole 9-column layout: + # Col 0: charge, 1-3: dipole, 4: axisType, 5-7: atomZ/X/Y, 8: polarity + dipole = [float(pol_row[1]) * 0.1, float(pol_row[2]) * 0.1, float(pol_row[3]) * 0.1] + axis_type = int(pol_row[4]) + atom_z = int(pol_row[5]) + atom_x = int(pol_row[6]) + atom_y = int(pol_row[7]) + polarity = float(pol_row[8]) * 0.001 + else: + # Fallback: assume dipole at 1-3, polarity at last column + dipole = [float(pol_row[1]) * 0.1, float(pol_row[2]) * 0.1, float(pol_row[3]) * 0.1] + axis_type = TholeDipoleForce.NoAxisType + atom_z = -1 + atom_x = -1 + atom_y = -1 + polarity = float(pol_row[n_params - 1]) * 0.001 if n_params > 4 else 0.0 + else: + dipole = [0.0, 0.0, 0.0] + axis_type = TholeDipoleForce.NoAxisType + atom_z = -1 + atom_x = -1 + atom_y = -1 + polarity = 0.0 + + # The axis atom indices in the parameters are now actual 0-based topology indices + # (resolved from SMIRKS indices during OpenFF conversion). We just need to + # add the idx_offset for the current molecule copy. + force.addParticle( + charge, + dipole, + polarity, + axis_type, + int(atom_z) + idx_offset if atom_z >= 0 else -1, + int(atom_x) + idx_offset if atom_x >= 0 else -1, + int(atom_y) + idx_offset if atom_y >= 0 else -1, + ) + + # Set up covalent maps (TholeDipole uses 4 types: Covalent12-15) + covalent_12_maps = {} + covalent_13_maps = {} + covalent_14_maps = {} + covalent_15_maps = {} + + for (i, j), scale_idx in zip(parameter_map.exclusions, parameter_map.exclusion_scale_idxs): + i = int(i) + idx_offset + j = int(j) + idx_offset + + if scale_idx == 0: # 1-2 interactions + covalent_maps = covalent_12_maps + elif scale_idx == 1: # 1-3 interactions + covalent_maps = covalent_13_maps + elif scale_idx == 2: # 1-4 interactions + covalent_maps = covalent_14_maps + else: # 1-5+ interactions + covalent_maps = covalent_15_maps + + if i not in covalent_maps: + covalent_maps[i] = [] + if j not in covalent_maps: + covalent_maps[j] = [] + covalent_maps[i].append(j) + covalent_maps[j].append(i) + + for i, atoms in covalent_12_maps.items(): + force.setCovalentMap(i, TholeDipoleForce.Covalent12, atoms) + for i, atoms in covalent_13_maps.items(): + force.setCovalentMap(i, TholeDipoleForce.Covalent13, atoms) + for i, atoms in covalent_14_maps.items(): + force.setCovalentMap(i, TholeDipoleForce.Covalent14, atoms) + for i, atoms in covalent_15_maps.items(): + force.setCovalentMap(i, TholeDipoleForce.Covalent15, atoms) + + idx_offset += n_particles + + return force + + @smee.converters.openmm.potential_converter( smee.PotentialType.ELECTROSTATICS, smee.EnergyFn.COULOMB ) def convert_coulomb_potential( - potential: smee.TensorPotential, system: smee.TensorSystem + potential: smee.TensorPotential, system: smee.TensorSystem ) -> openmm.NonbondedForce: """Convert a Coulomb potential to an OpenMM force.""" force = _create_nonbonded_force(potential, system) diff --git a/smee/potentials/multipole.py b/smee/potentials/multipole.py new file mode 100644 index 0000000..0d27d3d --- /dev/null +++ b/smee/potentials/multipole.py @@ -0,0 +1,424 @@ +"""Multipole potential energy functions.""" + +import math +import typing + +import torch + +import smee.potentials + + +@smee.potentials.potential_energy_fn( + smee.PotentialType.ELECTROSTATICS, smee.EnergyFn.POLARIZATION +) +def compute_multipole_energy( + system: smee.TensorSystem, + potential: smee.TensorPotential, + conformer: torch.Tensor, + box_vectors: torch.Tensor | None = None, + pairwise=None, + polarization_type: str = "mutual", + extrapolation_coefficients: list[float] | None = None, +) -> torch.Tensor: + """Compute the multipole energy including polarization effects. + + This function supports the full AMOEBA multipole model with the following parameters + per atom (arranged in columns): + + Column 0: charge (double) - the particle's charge + Columns 1-3: molecularDipole (vector[3]) - the particle's molecular dipole (x, y, z) + Columns 4-12: molecularQuadrupole (vector[9]) - the particle's molecular quadrupole + Column 13: axisType (int) - the particle's axis type (0=NoAxisType, 1=ZOnly, etc.) + Column 14: multipoleAtomZ (int) - index of first atom for lab<->molecular frames + Column 15: multipoleAtomX (int) - index of second atom for lab<->molecular frames + Column 16: multipoleAtomY (int) - index of third atom for lab<->molecular frames + Column 17: thole (double) - Thole parameter (default 0.39) + Column 18: dampingFactor (double) - dampingFactor parameter + Column 19: polarity (double) - polarity/polarizability parameter + + For backwards compatibility, if fewer columns are provided, sensible defaults are used. + + Args: + system: The system. + potential: The potential containing multipole parameters. + conformer: The conformer. + box_vectors: The box vectors. + pairwise: The pairwise distances. + polarization_type: The polarization solver type. Options are: + - "mutual": Full iterative SCF solver (default, ~60 iterations) + - "direct": Direct polarization with no mutual coupling (0 iterations) + - "extrapolated": Extrapolated polarization using OPT method + extrapolation_coefficients: Custom extrapolation coefficients for "extrapolated" type. + If None, uses OPT3 coefficients [-0.154, 0.017, 0.657, 0.475]. + Must sum to approximately 1.0 for energy conservation. + + Returns: + The energy. + """ + from smee.potentials.nonbonded import ( + _COULOMB_PRE_FACTOR, + _compute_pme_exclusions, + _compute_pme_grid, + _broadcast_exclusions, + compute_pairwise, + compute_pairwise_scales, + ) + + # Validate polarization_type + valid_types = ["mutual", "direct", "extrapolated"] + if polarization_type not in valid_types: + raise ValueError(f"polarization_type must be one of {valid_types}, got {polarization_type}") + + box_vectors = None if not system.is_periodic else box_vectors + + cutoff = potential.attributes[potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)] + + pairwise = compute_pairwise(system, conformer, box_vectors, cutoff) + + # Initialize parameter lists for all AMOEBA multipole parameters + charges = [] + dipoles = [] # 3 components per atom + quadrupoles = [] # 9 components per atom + axis_types = [] + multipole_atom_z = [] # Z-axis defining atom indices + multipole_atom_x = [] # X-axis defining atom indices + multipole_atom_y = [] # Y-axis defining atom indices + thole_params = [] + polarizabilities = [] + + # Extract parameters from parameter matrix + for topology, n_copies in zip(system.topologies, system.n_copies): + parameter_map = topology.parameters[potential.type] + topology_parameters = parameter_map.assignment_matrix @ potential.parameters + + n_particles = topology.n_particles + n_params = topology_parameters.shape[1] + + # Expected parameter layout for full AMOEBA multipole: + # Column 0: charge + # Columns 1-3: molecular dipole (x, y, z) + # Columns 4-12: molecular quadrupole (9 components) + # Column 13: axis type (int) + # Column 14: multipole atom Z index + # Column 15: multipole atom X index + # Column 16: multipole atom Y index + # Column 17: thole parameter + # Column 18: damping factor + # Column 19: polarizability + + charges.append(topology_parameters[:n_particles, 0].repeat(n_copies)) + dipoles.append(topology_parameters[n_particles:, 1:4].repeat(n_copies, 1)) + quadrupoles.append(topology_parameters[n_particles:, 4:13].repeat(n_copies, 1)) + axis_types.append(topology_parameters[n_particles:, 13].repeat(n_copies).int()) + multipole_atom_z.append(topology_parameters[n_particles:, 14].repeat(n_copies).int()) + multipole_atom_x.append(topology_parameters[n_particles:, 15].repeat(n_copies).int()) + multipole_atom_y.append(topology_parameters[n_particles:, 16].repeat(n_copies).int()) + thole_params.append(topology_parameters[n_particles:, 17].repeat(n_copies)) + polarizabilities.append(topology_parameters[n_particles:, 19].repeat(n_copies)) + + # Concatenate all parameter lists + charges = torch.cat(charges) # Shape: (n_total_particles,) + dipoles = torch.cat(dipoles) # Shape: (n_total_particles, 3) + quadrupoles = torch.cat(quadrupoles) # Shape: (n_total_particles, 9) + axis_types = torch.cat(axis_types) # Shape: (n_total_particles,) + multipole_atom_z = torch.cat(multipole_atom_z) # Shape: (n_total_particles,) + multipole_atom_x = torch.cat(multipole_atom_x) # Shape: (n_total_particles,) + multipole_atom_y = torch.cat(multipole_atom_y) # Shape: (n_total_particles,) + thole_params = torch.cat(thole_params) # Shape: (n_total_particles,) + polarizabilities = torch.cat(polarizabilities) # Shape: (n_total_particles,) + + # Override scale factors to match TholeDipole convention + # TholeDipole uses [0, 0, 0.5, 1.0] for 1-2, 1-3, 1-4, 1-5 (vs AMOEBA's [0, 0, 0.4, 0.8/1.0]) + if 'scale_14' in potential.attribute_cols: + scale_14_idx = potential.attribute_cols.index('scale_14') + potential.attributes[scale_14_idx] = 0.5 + + pair_scales = compute_pairwise_scales(system, potential) + + # static partial charge - partial charge energy + if system.is_periodic == False: + coul_energy = ( + _COULOMB_PRE_FACTOR + * pair_scales + * charges[pairwise.idxs[:, 0]] + * charges[pairwise.idxs[:, 1]] + / pairwise.distances + ).sum(-1) + else: + import NNPOps.pme + + cutoff = potential.attributes[potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)] + error_tol = torch.tensor(0.0001) + + exceptions = _compute_pme_exclusions(system, potential).to(charges.device) + + grid_x, grid_y, grid_z, alpha = _compute_pme_grid(box_vectors, cutoff, error_tol) + + pme = NNPOps.pme.PME( + grid_x, grid_y, grid_z, 5, alpha, _COULOMB_PRE_FACTOR, exceptions + ) + + energy_direct = torch.ops.pme.pme_direct( + conformer.float(), + charges.float(), + pairwise.idxs.T, + pairwise.deltas, + pairwise.distances, + pme.exclusions, + pme.alpha, + pme.coulomb, + ) + energy_self = -torch.sum(charges ** 2) * pme.coulomb * pme.alpha / math.sqrt(torch.pi) + energy_recip = energy_self + torch.ops.pme.pme_reciprocal( + conformer.float(), + charges.float(), + box_vectors.float(), + pme.gridx, + pme.gridy, + pme.gridz, + pme.order, + pme.alpha, + pme.coulomb, + pme.moduli[0].to(charges.device), + pme.moduli[1].to(charges.device), + pme.moduli[2].to(charges.device), + ) + + exclusion_idxs, exclusion_scales = _broadcast_exclusions(system, potential) + + exclusion_distances = ( + conformer[exclusion_idxs[:, 0], :] - conformer[exclusion_idxs[:, 1], :] + ).norm(dim=-1) + + energy_exclusion = ( + _COULOMB_PRE_FACTOR + * exclusion_scales + * charges[exclusion_idxs[:, 0]] + * charges[exclusion_idxs[:, 1]] + / exclusion_distances + ).sum(-1) + + coul_energy = energy_direct + energy_recip + energy_exclusion + + # If all polarizabilities are zero, just return the Coulomb energy + if torch.allclose(polarizabilities, torch.tensor(0.0, dtype=torch.float64)): + return coul_energy + + # Handle batch vs single conformer - process each conformer individually + is_batch = conformer.ndim == 3 + + if is_batch: + # Process each conformer individually and return results for each + n_conformers = conformer.shape[0] + batch_energies = [] + + for conf_idx in range(n_conformers): + # Extract single conformer + single_conformer = conformer[conf_idx] + + # Compute pairwise for this conformer + single_pairwise = compute_pairwise(system, single_conformer, box_vectors, cutoff) + + # Recursively call this function for single conformer + single_energy = compute_multipole_energy( + system, potential, single_conformer, box_vectors, single_pairwise, + polarization_type, extrapolation_coefficients + ) + batch_energies.append(single_energy) + + return torch.stack(batch_energies) + + # Continue with single conformer processing + efield_static = torch.zeros((system.n_particles, 3), dtype=torch.float64, device=conformer.device) + efield_static_polar = torch.zeros((system.n_particles, 3), dtype=torch.float64, device=conformer.device) + + # TholeDipole scale factors (different from AMOEBA): + # mScale: permanent-permanent interactions [1-2, 1-3, 1-4, 1-5, 1-6+] + # iScale: induced-induced interactions + # TholeDipole uses [0, 0, 0.5, 1.0, 1.0] for mScale (vs AMOEBA's [0, 0, 0, 0.4, 0.8]) + mScale = torch.tensor([0.0, 0.0, 0.5, 1.0, 1.0], dtype=torch.float64) + dScale = torch.tensor([0.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float64) + pScale = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0], dtype=torch.float64) + uScale = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float64) + + # calculate electric field due to partial charges by hand + # TODO ewald or wolf summation for periodic + _SQRT_COULOMB_PRE_FACTOR = _COULOMB_PRE_FACTOR ** (1 / 2) + # Calculate fixed dipole field from permanent charges + # TholeDipole applies mScale to the field calculation + for distance, delta, idx, scale in zip( + pairwise.distances, pairwise.deltas, pairwise.idxs, pair_scales + ): + # Use actual scale value (don't zero out partial scales) + if polarizabilities[idx[0]] * polarizabilities[idx[1]] != 0: + u = distance / (polarizabilities[idx[0]] * polarizabilities[idx[1]]) ** ( + 1.0 / 6.0 + ) + else: + u = distance + + # Use the Thole parameter (typically 0.39) in the damping function + # Using the minimum of the two atoms' Thole parameters + thole_a = torch.min(thole_params[idx[0]], thole_params[idx[1]]) + au3 = thole_a * u ** 3 + exp_au3 = torch.exp(-au3) + + # Thole damping for charge-dipole interactions (thole_c in OpenMM) + damping_term1 = 1 - exp_au3 + + efield_static[idx[0]] -= ( + _SQRT_COULOMB_PRE_FACTOR + * scale + * damping_term1 + * charges[idx[1]] + * delta + / distance ** 3 + ) + efield_static[idx[1]] += ( + _SQRT_COULOMB_PRE_FACTOR + * scale + * damping_term1 + * charges[idx[0]] + * delta + / distance ** 3 + ) + + # reshape to (3*N) vector + efield_static = efield_static.reshape(3 * system.n_particles) + + # If there's no electric field, no polarization energy + if torch.allclose(efield_static, torch.tensor(0.0, dtype=torch.float64)): + return coul_energy + + # induced dipole vector - start with direct polarization + ind_dipoles = torch.repeat_interleave(polarizabilities, 3) * efield_static + + # Build A matrix for mutual/extrapolated methods + if polarization_type in ["mutual", "extrapolated"]: + A = torch.nan_to_num(torch.diag(torch.repeat_interleave(1.0 / polarizabilities, 3))) + + # Build A matrix for induced-induced coupling + # NOTE: TholeDipole uses iScale=1.0 for ALL covalent types (no exclusions) + for distance, delta, idx in zip( + pairwise.distances, pairwise.deltas, pairwise.idxs + ): + if polarizabilities[idx[0]] * polarizabilities[idx[1]] != 0: + u = distance / (polarizabilities[idx[0]] * polarizabilities[idx[1]]) ** ( + 1.0 / 6.0 + ) + else: + u = distance + + # Use the Thole parameter (typically 0.39) in the damping function + thole_a = torch.min(thole_params[idx[0]], thole_params[idx[1]]) + au3 = thole_a * u ** 3 + exp_au3 = torch.exp(-au3) + thole3 = 1 - exp_au3 + thole5 = 1 - (1 + au3) * exp_au3 # Note: (1 + au3), NOT (1 + 1.5*au3) + + t = ( + torch.eye(3, dtype=torch.float64, device=conformer.device) * thole3 * distance ** -3 + - 3 * thole5 * torch.einsum("i,j->ij", delta, delta) * distance ** -5 + ) + # No scale factor - iScale is always 1.0 in TholeDipole + A[3 * idx[0]: 3 * idx[0] + 3, 3 * idx[1]: 3 * idx[1] + 3] = t + A[3 * idx[1]: 3 * idx[1] + 3, 3 * idx[0]: 3 * idx[0] + 3] = t + + # Handle different polarization types + if polarization_type == "direct": + # Direct polarization: μ = α * E (no mutual coupling) + # ind_dipoles is already μ^(0) = α * E, so no additional work needed + pass + elif polarization_type == "extrapolated": + # TholeDipole extrapolation uses simple iterative updates (NOT conjugate gradient) + # μ_0 = α * E_fixed (direct polarization) + # μ_n = α * (E_fixed + E_induced(μ_{n-1})) for n = 1, 2, 3 + # μ_final = Σ c_k * μ_k + if extrapolation_coefficients is None: + # TholeDipole default OPT4 coefficients + extrapolation_coefficients = [-0.154, 0.017, 0.658, 0.474] + + opt_coeffs = torch.tensor(extrapolation_coefficients, dtype=torch.float64, device=conformer.device) + n_orders = len(opt_coeffs) + + # Store perturbation theory orders + pt_dipoles = [] + pt_dipoles.append(ind_dipoles.clone()) # PT0: direct polarization μ = α * E_fixed + + # Build dipole-dipole interaction tensor T for computing induced fields + # E_induced[i] = Σ_j T[i,j] @ μ[j] + T = torch.zeros((3 * system.n_particles, 3 * system.n_particles), dtype=torch.float64, + device=conformer.device) + for distance, delta, idx in zip(pairwise.distances, pairwise.deltas, pairwise.idxs): + if polarizabilities[idx[0]] * polarizabilities[idx[1]] != 0: + u = distance / (polarizabilities[idx[0]] * polarizabilities[idx[1]]) ** (1.0 / 6.0) + else: + u = distance + + thole_a = torch.min(thole_params[idx[0]], thole_params[idx[1]]) + au3 = thole_a * u ** 3 + exp_au3 = torch.exp(-au3) + thole3 = 1 - exp_au3 + thole5 = 1 - (1 + au3) * exp_au3 + + # Dipole field tensor: E = [-thole3*μ/r³ + 3*thole5*(μ·r̂)r̂/r³] + t = ( + -torch.eye(3, dtype=torch.float64, device=conformer.device) * thole3 * distance ** -3 + + 3 * thole5 * torch.einsum("i,j->ij", delta, delta) * distance ** -5 + ) + T[3 * idx[0]: 3 * idx[0] + 3, 3 * idx[1]: 3 * idx[1] + 3] = t + T[3 * idx[1]: 3 * idx[1] + 3, 3 * idx[0]: 3 * idx[0] + 3] = t + + # Generate higher order PT terms + current_dipoles = ind_dipoles.clone() + for order in range(1, n_orders): + # Calculate induced field from current dipoles + efield_induced = T @ current_dipoles + + # Update dipoles: μ_n = α * (E_fixed + E_induced) + current_dipoles = torch.repeat_interleave(polarizabilities, 3) * (efield_static + efield_induced) + pt_dipoles.append(current_dipoles.clone()) + + # Combine PT orders with extrapolation coefficients + ind_dipoles = torch.zeros_like(ind_dipoles) + for k in range(n_orders): + ind_dipoles += opt_coeffs[k] * pt_dipoles[k] + + else: # mutual + # Mutual polarization using conjugate gradient + precondition_m = torch.repeat_interleave(polarizabilities, 3) + residual = efield_static - A @ ind_dipoles + z = torch.einsum("i,i->i", precondition_m, residual) + p = torch.clone(z) + + converged_iter = 60 + for iter_num in range(60): + alpha = torch.dot(residual, z) / (p @ A @ p) + ind_dipoles = ind_dipoles + alpha * p + + prev_residual = torch.clone(residual) + prev_z = torch.clone(z) + + residual = residual - alpha * A @ p + + rms_residual = torch.sqrt(torch.dot(residual, residual) / len(residual)) + if torch.dot(residual, residual) < 1e-7: + converged_iter = iter_num + 1 + break + + z = torch.einsum("i,i->i", precondition_m, residual) + beta = torch.dot(z, residual) / torch.dot(prev_z, prev_residual) + p = z + beta * p + + # Reshape induced dipoles back to (N, 3) for energy calculations + ind_dipoles_3d = ind_dipoles.reshape(system.n_particles, 3) + + # Calculate polarization energy + # TholeDipole uses the same formula for ALL polarization types: -0.5 * μ · E_fixed + # This works because: + # - For Direct: μ = α * E_fixed, so U = -0.5 * α * |E_fixed|² + # - For Mutual: At SCF convergence, this gives the correct variational energy + # - For Extrapolated: The extrapolated dipoles approximate the SCF result + coul_energy += -0.5 * torch.dot(ind_dipoles, efield_static) + + return coul_energy diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py index a8ac45a..4cee816 100644 --- a/smee/potentials/nonbonded.py +++ b/smee/potentials/nonbonded.py @@ -4,6 +4,7 @@ import math import typing +import numpy as np import openff.units import torch @@ -796,6 +797,163 @@ def compute_dexp_energy( return energy +def _compute_dampedexp6810_lrc( + system: smee.TensorSystem, + potential: smee.TensorPotential, + rs: torch.Tensor | None, + rc: torch.Tensor | None, + volume: torch.Tensor, +) -> torch.Tensor: + """Computes the long range dispersion correction due to the double exponential + potential, possibly with a switching function.""" + + raise NotImplementedError + + +@smee.potentials.potential_energy_fn( + smee.PotentialType.VDW, smee.EnergyFn.VDW_DAMPEDEXP6810 +) +def compute_dampedexp6810_energy( + system: smee.TensorSystem, + potential: smee.TensorPotential, + conformer: torch.Tensor, + box_vectors: torch.Tensor | None = None, + pairwise: PairwiseDistances | None = None, +) -> torch.Tensor: + """Compute the potential energy [kcal / mol] of the vdW interactions using the + DampedExp6810 potential. + + Notes: + * No cutoff function will be applied if the system is not periodic. + + Args: + system: The system to compute the energy for. + potential: The potential energy function to evaluate. + conformer: The conformer [Å] to evaluate the potential at with + ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``. + box_vectors: The box vectors [Å] of the system with ``shape=(n_confs, 3, 3)`` + or ``shape=(3, 3)`` if the system is periodic, or ``None`` otherwise. + pairwise: Pre-computed distances between each pair of particles + in the system. + + Returns: + The evaluated potential energy [kcal / mol]. + """ + box_vectors = None if not system.is_periodic else box_vectors + + cutoff = potential.attributes[potential.attribute_cols.index(smee.CUTOFF_ATTRIBUTE)] + + pairwise = ( + pairwise + if pairwise is not None + else compute_pairwise(system, conformer, box_vectors, cutoff) + ) + + if system.is_periodic and not torch.isclose(pairwise.cutoff, cutoff): + raise ValueError("the pairwise cutoff does not match the potential.") + + parameters = smee.potentials.broadcast_parameters(system, potential) + pair_scales = compute_pairwise_scales(system, potential) + + pairs_1d = smee.utils.to_upper_tri_idx( + pairwise.idxs[:, 0], pairwise.idxs[:, 1], len(parameters) + ) + pair_scales = pair_scales[pairs_1d] + + rho_column = potential.parameter_cols.index("rho") + beta_column = potential.parameter_cols.index("beta") + c6_column = potential.parameter_cols.index("c6") + c8_column = potential.parameter_cols.index("c8") + c10_column = potential.parameter_cols.index("c10") + + rho_a = parameters[pairwise.idxs[:, 0], rho_column] + rho_b = parameters[pairwise.idxs[:, 1], rho_column] + beta_a = parameters[pairwise.idxs[:, 0], beta_column] + beta_b = parameters[pairwise.idxs[:, 1], beta_column] + c6_a = parameters[pairwise.idxs[:, 0], c6_column] + c6_b = parameters[pairwise.idxs[:, 1], c6_column] + c8_a = parameters[pairwise.idxs[:, 0], c8_column] + c8_b = parameters[pairwise.idxs[:, 1], c8_column] + c10_a = parameters[pairwise.idxs[:, 0], c10_column] + c10_b = parameters[pairwise.idxs[:, 1], c10_column] + + rho = 0.5 * (rho_a + rho_b) + beta = 2.0 * beta_a * beta_b / (beta_a + beta_b) + c6 = smee.utils.geometric_mean(c6_a, c6_b) + c8 = smee.utils.geometric_mean(c8_a, c8_b) + c10 = smee.utils.geometric_mean(c10_a, c10_b) + + if potential.exceptions is not None: + exception_idxs, exceptions = smee.potentials.broadcast_exceptions( + system, potential, pairwise.idxs[:, 0], pairwise.idxs[:, 1] + ) + + rho = rho.clone() # prevent in-place modification + beta = beta.clone() + c6 = c6.clone() + c8 = c8.clone() + c10 = c10.clone() + + rho[exception_idxs] = exceptions[:, rho_column] + beta[exception_idxs] = exceptions[:, beta_column] + c6[exception_idxs] = exceptions[:, c6_column] + c8[exception_idxs] = exceptions[:, c8_column] + c10[exception_idxs] = exceptions[:, c10_column] + + force_at_zero = potential.attributes[ + potential.attribute_cols.index("force_at_zero") + ] + + x = pairwise.distances + + invR = 1.0 / x + br = beta * x + expbr = torch.exp(-beta * x) + + ttdamp6_sum = ( + 1.0 + br + br**2 / 2 + br**3 / 6 + br**4 / 24 + br**5 / 120 + br**6 / 720 + ) + ttdamp8_sum = ttdamp6_sum + br**7 / 5040 + br**8 / 40320 + ttdamp10_sum = ttdamp8_sum + br**9 / 362880 + br**10 / 3628800 + + ttdamp6 = 1.0 - expbr * ttdamp6_sum + ttdamp8 = 1.0 - expbr * ttdamp8_sum + ttdamp10 = 1.0 - expbr * ttdamp10_sum + + repulsion = force_at_zero * 1.0 / beta * torch.exp(-beta * (x - rho)) + energies = ( + repulsion + - ttdamp6 * c6 * x**-6 + - ttdamp8 * c8 * x**-8 + - ttdamp10 * c10 * x**-10 + ) + + # Apply exclusion scaling factors + energies *= pair_scales + + if not system.is_periodic: + return energies.sum(-1) + + switch_fn, switch_width = _compute_switch_fn(potential, pairwise) + energies *= switch_fn + + energy = energies.sum(-1) + + energy += _compute_dampedexp6810_lrc( + system, + potential.to(precision="double"), + switch_width.double(), + pairwise.cutoff.double(), + torch.det(box_vectors), + ) + + return energy + + +# Import compute_multipole_energy from the new multipole module +from smee.potentials.multipole import compute_multipole_energy + + def _compute_pme_exclusions( system: smee.TensorSystem, potential: smee.TensorPotential ) -> torch.Tensor: diff --git a/smee/tests/convertors/openff/test_nonbonded.py b/smee/tests/convertors/openff/test_nonbonded.py index 5dbce55..e9f4fcb 100644 --- a/smee/tests/convertors/openff/test_nonbonded.py +++ b/smee/tests/convertors/openff/test_nonbonded.py @@ -7,6 +7,7 @@ import smee import smee.converters from smee.converters.openff.nonbonded import ( + convert_dampedexp6810, convert_dexp, convert_electrostatics, convert_vdw, @@ -312,3 +313,24 @@ def test_convert_dexp(ethanol, test_data_dir): assert potential.type == "vdW" assert potential.fn == smee.EnergyFn.VDW_DEXP + + +def test_convert_dampedexp6810(ethanol, test_data_dir): + ff = openff.toolkit.ForceField( + str(test_data_dir / "PHAST-H2CNO-nonpolar-2.0.0.offxml"), load_plugins=True + ) + + interchange = openff.interchange.Interchange.from_smirnoff( + ff, ethanol.to_topology() + ) + vdw_collection = interchange.collections["DampedExp6810"] + + potential, parameter_maps = convert_dampedexp6810( + [vdw_collection], [ethanol.to_topology()], [None] + ) + + assert potential.attribute_cols[-1] == "force_at_zero" + assert potential.parameter_cols == ("rho", "beta", "c6", "c8", "c10") + + assert potential.type == "vdW" + assert potential.fn == smee.EnergyFn.VDW_DAMPEDEXP6810 diff --git a/smee/tests/convertors/openmm/test_openmm.py b/smee/tests/convertors/openmm/test_openmm.py index 7f9606c..18719cc 100644 --- a/smee/tests/convertors/openmm/test_openmm.py +++ b/smee/tests/convertors/openmm/test_openmm.py @@ -291,6 +291,45 @@ def test_convert_to_openmm_system_dexp_periodic(test_data_dir): ) +def test_convert_to_openmm_system_damped6810_periodic(test_data_dir): + ff = openff.toolkit.ForceField( + str(test_data_dir / "PHAST-H2CNO-nonpolar-2.0.0.offxml"), load_plugins=True + ) + top = openff.toolkit.Topology() + + interchanges = [] + + n_copies_per_mol = [5, 5] + + for smiles, n_copies in zip(["OCCO", "O"], n_copies_per_mol): + mol = openff.toolkit.Molecule.from_smiles(smiles) + mol.generate_conformers(n_conformers=1) + + interchange = openff.interchange.Interchange.from_smirnoff( + ff, mol.to_topology() + ) + interchanges.append(interchange) + + for _ in range(n_copies): + top.add_molecule(mol) + + tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges) + tensor_system = smee.TensorSystem(tensor_tops, n_copies_per_mol, True) + + coords, _ = smee.mm.generate_system_coords( + tensor_system, None, smee.mm.GenerateCoordsConfig() + ) + box_vectors = numpy.eye(3) * 20.0 * openmm.unit.angstrom + + top.box_vectors = box_vectors + + interchange_top = openff.interchange.Interchange.from_smirnoff(ff, top) + + _compare_smee_and_interchange( + tensor_ff, tensor_system, interchange_top, coords, box_vectors + ) + + def test_convert_to_openmm_topology(): formaldehyde_interchange = openff.interchange.Interchange.from_smirnoff( openff.toolkit.ForceField("openff-2.0.0.offxml"), diff --git a/smee/tests/data/PHAST-H2CNO-2.0.0.offxml b/smee/tests/data/PHAST-H2CNO-2.0.0.offxml new file mode 100644 index 0000000..9e53d83 --- /dev/null +++ b/smee/tests/data/PHAST-H2CNO-2.0.0.offxml @@ -0,0 +1,388 @@ + + + Adam Hogan + 2023-04-26 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/smee/tests/data/PHAST-H2CNO-nonpolar-2.0.0.offxml b/smee/tests/data/PHAST-H2CNO-nonpolar-2.0.0.offxml new file mode 100644 index 0000000..85be63b --- /dev/null +++ b/smee/tests/data/PHAST-H2CNO-nonpolar-2.0.0.offxml @@ -0,0 +1,387 @@ + + + Adam Hogan + 2023-04-26 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/smee/tests/potentials/test_nonbonded.py b/smee/tests/potentials/test_nonbonded.py index 456c75f..8c59679 100644 --- a/smee/tests/potentials/test_nonbonded.py +++ b/smee/tests/potentials/test_nonbonded.py @@ -2,10 +2,13 @@ import math import numpy +import openff import openmm.unit import pytest import torch +from tholedipoleplugin import TholeDipoleForce + import smee import smee.converters import smee.converters.openmm @@ -18,19 +21,23 @@ _compute_lj_lrc, _compute_pme_exclusions, compute_coulomb_energy, + compute_dampedexp6810_energy, compute_dexp_energy, compute_lj_energy, compute_pairwise, compute_pairwise_scales, prepare_lrc_types, ) +from smee.potentials.multipole import compute_multipole_energy def _compute_openmm_energy( - system: smee.TensorSystem, - coords: torch.Tensor, - box_vectors: torch.Tensor | None, - potential: smee.TensorPotential, + system: smee.TensorSystem, + coords: torch.Tensor, + box_vectors: torch.Tensor | None, + potential: smee.TensorPotential, + polarization_type: str | None = None, + return_omm_forces: bool | None = None, ) -> torch.Tensor: coords = coords.numpy() * openmm.unit.angstrom @@ -40,6 +47,19 @@ def _compute_openmm_energy( omm_forces = smee.converters.convert_to_openmm_force(potential, system) omm_system = smee.converters.openmm.create_openmm_system(system, None) + # Handle polarization type for TholeDipoleForce + if polarization_type is not None: + for omm_force in omm_forces: + if isinstance(omm_force, TholeDipoleForce): + if polarization_type == "direct": + omm_force.setPolarizationType(TholeDipoleForce.Direct) + elif polarization_type == "mutual": + omm_force.setPolarizationType(TholeDipoleForce.Mutual) + elif polarization_type == "extrapolated": + omm_force.setPolarizationType(TholeDipoleForce.Extrapolated) + else: + raise ValueError(f"Unknown polarization_type: {polarization_type}") + for omm_force in omm_forces: omm_system.addForce(omm_force) @@ -47,7 +67,9 @@ def _compute_openmm_energy( omm_system.setDefaultPeriodicBoxVectors(*box_vectors) omm_integrator = openmm.VerletIntegrator(1.0 * openmm.unit.femtoseconds) - omm_context = openmm.Context(omm_system, omm_integrator) + # Use Reference platform for TholeDipole (OpenCL may segfault) + platform = openmm.Platform.getPlatformByName('Reference') + omm_context = openmm.Context(omm_system, omm_integrator, platform) if box_vectors is not None: omm_context.setPeriodicBoxVectors(*box_vectors) @@ -57,7 +79,31 @@ def _compute_openmm_energy( omm_energy = omm_context.getState(getEnergy=True).getPotentialEnergy() omm_energy = omm_energy.value_in_unit(openmm.unit.kilocalories_per_mole) - return torch.tensor(omm_energy, dtype=torch.float64) + # Get induced dipoles from TholeDipoleForce + try: + thole_force = None + for force in omm_forces: + if isinstance(force, TholeDipoleForce): + thole_force = force + break + + if thole_force: + induced_dipoles = thole_force.getInducedDipoles(omm_context) + + # Convert from e·nm to e·Å (multiply by 10) + conversion_factor = 10.0 + induced_dipoles_angstrom = [[d * conversion_factor for d in dipole] for dipole in induced_dipoles] + print(f"\nOpenMM induced dipoles (e·Å):") + for i, dipole in enumerate(induced_dipoles_angstrom): + print(f" Particle {i}: [{dipole[0]:.10f}, {dipole[1]:.10f}, {dipole[2]:.10f}]") + + except Exception as e: + print(f"Could not get induced dipoles: {e}") + + if return_omm_forces: + return torch.tensor(omm_energy, dtype=torch.float64), omm_forces + else: + return torch.tensor(omm_energy, dtype=torch.float64) def _parameter_key_to_idx(potential: smee.TensorPotential, key: str): @@ -295,7 +341,7 @@ def test_compute_xxx_lrc_with_exceptions(lrc_fn, convert_fn): rs = torch.tensor(8.0) rc = torch.tensor(9.0) - volume = torch.tensor(18.0**3) + volume = torch.tensor(18.0 ** 3) lrc_no_exceptions = lrc_fn(system, vdw_potential_no_exceptions, rs, rc, volume) @@ -320,7 +366,7 @@ def test_compute_xxx_lrc_with_exceptions(lrc_fn, convert_fn): ], ) def test_compute_xxx_energy_periodic( - energy_fn, convert_fn, etoh_water_system, with_exceptions + energy_fn, convert_fn, etoh_water_system, with_exceptions ): tensor_sys, tensor_ff, coords, box_vectors = etoh_water_system @@ -378,17 +424,17 @@ def _expected_energy_lj_exceptions(params: dict[str, smee.tests.utils.LJParam]): sqrt_2 = math.sqrt(2) expected_energy = 4.0 * ( - params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6) - + params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6) - # - + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6) - + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6) - # - + params["hh"].eps - * ((params["hh"].sig / sqrt_2) ** 12 - (params["hh"].sig / sqrt_2) ** 6) - # - + params["oa"].eps - * ((params["oa"].sig / sqrt_2) ** 12 - (params["oa"].sig / sqrt_2) ** 6) + params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6) + + params["oh"].eps * (params["oh"].sig ** 12 - params["oh"].sig ** 6) + # + + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6) + + params["ah"].eps * (params["ah"].sig ** 12 - params["ah"].sig ** 6) + # + + params["hh"].eps + * ((params["hh"].sig / sqrt_2) ** 12 - (params["hh"].sig / sqrt_2) ** 6) + # + + params["oa"].eps + * ((params["oa"].sig / sqrt_2) ** 12 - (params["oa"].sig / sqrt_2) ** 6) ) return expected_energy @@ -410,15 +456,15 @@ def _dexp(eps, sig, dist): sqrt_2 = math.sqrt(2) expected_energy = ( - _dexp(params["oh"].eps, params["oh"].sig, 1.0) - + _dexp(params["oh"].eps, params["oh"].sig, 1.0) - # - + _dexp(params["ah"].eps, params["ah"].sig, 1.0) - + _dexp(params["ah"].eps, params["ah"].sig, 1.0) - # - + _dexp(params["hh"].eps, params["hh"].sig, sqrt_2) - # - + _dexp(params["oa"].eps, params["oa"].sig, sqrt_2) + _dexp(params["oh"].eps, params["oh"].sig, 1.0) + + _dexp(params["oh"].eps, params["oh"].sig, 1.0) + # + + _dexp(params["ah"].eps, params["ah"].sig, 1.0) + + _dexp(params["ah"].eps, params["ah"].sig, 1.0) + # + + _dexp(params["hh"].eps, params["hh"].sig, sqrt_2) + # + + _dexp(params["oa"].eps, params["oa"].sig, sqrt_2) ) return expected_energy @@ -428,9 +474,9 @@ def _dexp(eps, sig, dist): [ (compute_lj_energy, lambda p: p, _expected_energy_lj_exceptions), ( - compute_dexp_energy, - smee.tests.utils.convert_lj_to_dexp, - _expected_energy_dexp_exceptions, + compute_dexp_energy, + smee.tests.utils.convert_lj_to_dexp, + _expected_energy_dexp_exceptions, ), ], ) @@ -530,3 +576,1039 @@ def test_compute_coulomb_energy_non_periodic(): ) assert torch.isclose(energy, expected_energy, atol=1.0e-4) + + +def test_compute_dampedexp6810_energy_ne_scan_non_periodic(test_data_dir): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["[Ne]", "[Ne]"], + [1, 1], + openff.toolkit.ForceField( + str(test_data_dir / "PHAST-H2CNO-nonpolar-2.0.0.offxml"), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = torch.stack( + [ + torch.vstack( + [ + torch.tensor([0, 0, 0]), + torch.tensor([0, 0, 0]) + torch.tensor([0, 0, 1.5 + i * 0.5]), + ] + ) + for i in range(20) + ] + ) + + energies = smee.compute_energy(tensor_sys, tensor_ff, coords) + expected_energies = [] + for coord in coords: + expected_energies.append( + _compute_openmm_energy( + tensor_sys, coord, None, tensor_ff.potentials_by_type["vdW"] + ) + ) + expected_energies = torch.tensor(expected_energies) + assert torch.allclose(energies, expected_energies.float(), atol=1.0e-4) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated") + ] +) +def test_compute_multipole_energy_CC_O_non_periodic(test_data_dir, forcefield_name, polarization_type): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["CC", "O"], + [3, 2], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + # Use fixed coordinates to ensure reproducibility + coords = torch.tensor([ + [5.9731, 4.8234, 5.1358], + [5.6308, 3.4725, 5.7007], + [5.0358, 5.2467, 4.7020], + [6.2522, 5.4850, 5.9780], + [6.7136, 4.7967, 4.3256], + [4.9936, 3.6648, 6.6100], + [6.5061, 2.9131, 6.0617], + [5.0991, 2.8173, 4.9849], + [0.9326, 2.8105, 5.2711], + [0.9434, 1.3349, 5.5607], + [1.1295, 2.9339, 4.1794], + [-0.0939, 3.1853, 5.4460], + [1.7103, 3.3774, 5.7996], + [0.0123, 0.9149, 5.0849], + [0.8655, 1.0972, 6.6316], + [1.8432, 0.8172, 5.1776], + [3.2035, 0.7561, 3.0346], + [3.4468, 1.0277, 1.5757], + [4.1522, 0.3467, 3.4566], + [3.0323, 1.7263, 3.5387], + [2.4222, 0.0093, 3.2280], + [4.1461, 1.9103, 1.5332], + [2.5430, 1.3356, 1.0299], + [3.8762, 0.1647, 1.0324], + [6.3764, 1.9600, 2.9162], + [6.1056, 1.3456, 3.6328], + [6.6023, 2.8357, 3.3122], + [3.0792, 6.2544, 4.6979], + [3.5093, 6.6131, 5.5045], + [3.5237, 6.6324, 3.9016]], dtype=torch.float64) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + es_potential.parameters.requires_grad = True + + energy = compute_multipole_energy(tensor_sys, es_potential, coords.float(), None, + polarization_type=polarization_type) + energy.backward() + expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential, + polarization_type=polarization_type, return_omm_forces=True) + print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces) + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + #("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + #("PHAST-H2CNO-2.0.0.offxml", "mutual"), + #("PHAST-H2CNO-2.0.0.offxml", "extrapolated") + ] +) +def test_compute_multipole_energy_charged_ne_non_periodic(test_data_dir, forcefield_name, polarization_type): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["[Ne]", "[Ne]"], + [1, 1], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = torch.vstack([torch.tensor([0, 0, 0]), torch.tensor([0, 0, 3.0])]) + + # give each atom a charge otherwise the system is neutral + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + es_potential.parameters[0, 0] = 1 + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords, None, polarization_type=polarization_type + ) + + expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential, + polarization_type=polarization_type, return_omm_forces=True) + print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces) + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated") + ] +) +def test_compute_multipole_energy_c_xe_non_periodic(test_data_dir, forcefield_name, polarization_type): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["C", "[Xe]"], + [1, 1], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = torch.tensor( + [ + [+0.00000, +0.00000, +0.00000], + [+0.00000, +0.00000, +1.08900], + [+1.02672, +0.00000, -0.36300], + [-0.51336, -0.88916, -0.36300], + [-0.51336, +0.88916, -0.36300], + [+3.00000, +0.00000, +0.00000], + ] + ) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords, None, polarization_type=polarization_type + ) + + expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential, + polarization_type=polarization_type, return_omm_forces=True) + print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces) + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated") + ] +) +def test_compute_phast2_energy_water_conformers_non_periodic(test_data_dir, forcefield_name, polarization_type): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["O"], + [2], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = torch.tensor([[[-5.5964e-02, 8.1693e-01, -5.3445e-01], + [+2.5174e-01, -5.8659e-01, -8.1979e-01], + [+0.0000e+00, 0.0000e+00, 0.0000e+00], + [+7.6271e+00, -6.6103e-01, -5.7262e-01], + [+7.7119e+00, -4.1601e-01, 9.3098e-01], + [+7.9377e+00, 0.0000e+00, 0.0000e+00]], + [[+7.1041e-01, 4.7487e-01, -1.5602e-01], + [-4.8097e-01, 7.2769e-01, -2.2119e-01], + [+0.0000e+00, 0.0000e+00, 0.0000e+00], + [+8.1144e+00, -8.7009e-01, -3.9085e-01], + [+8.1329e+00, 9.2279e-01, -4.4597e-01], + [+7.9377e+00, 0.0000e+00, 0.0000e+00]], + [[+2.1348e-01, 3.6725e-01, 8.3273e-01], + [-5.7851e-01, -6.4377e-01, 5.4664e-01], + [+0.0000e+00, 0.0000e+00, 0.0000e+00], + [+7.2758e+00, 3.1414e-01, -5.9182e-01], + [+7.7279e+00, -5.7537e-01, 7.6088e-01], + [+7.9377e+00, 0.0000e+00, 0.0000e+00]]]) + + multipole_potential = tensor_ff.potentials_by_type["Electrostatics"] + vdw_potential = tensor_ff.potentials_by_type["vdW"] + + multipole_energy = compute_multipole_energy( + tensor_sys, multipole_potential, coords, None, polarization_type=polarization_type + ) + + multipole_expected_energy = torch.tensor( + [_compute_openmm_energy(tensor_sys, coord, None, multipole_potential, polarization_type=polarization_type) for coord in coords] + ) + + vdw_energy = compute_dampedexp6810_energy( + tensor_sys, vdw_potential, coords, None + ) + + vdw_expected_energy = torch.tensor( + [_compute_openmm_energy(tensor_sys, coord, None, vdw_potential) for coord in coords] + ) + + assert torch.allclose(multipole_energy, multipole_expected_energy, atol=1.0e-3) + assert torch.allclose(vdw_energy, vdw_expected_energy, atol=1.0e-4) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated") + ] +) +@pytest.mark.parametrize("smiles", ["CC", "CCC", "CCCC", "CCCCC"]) +def test_compute_multipole_energy_isolated_non_periodic(test_data_dir, forcefield_name, polarization_type, smiles): + + print(f"\n{forcefield_name} - {polarization_type} - {smiles}") + + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + [smiles], + [1], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords, _ = smee.mm.generate_system_coords(tensor_sys, None) + coords = torch.tensor(coords.value_in_unit(openmm.unit.angstrom)) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + es_potential.parameters.requires_grad = True + + energy = compute_multipole_energy(tensor_sys, es_potential, coords.float(), None, + polarization_type=polarization_type) + energy.backward() + expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, None, es_potential, + polarization_type=polarization_type, return_omm_forces=True) + print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces) + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated") + ] +) +def test_compute_multipole_energy_CC_O_periodic(test_data_dir, forcefield_name, polarization_type): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["CC", "O"], + [10, 15], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = True + + # Use lower density to ensure box size >= 2*cutoff (18.0 Å) + config = smee.mm.GenerateCoordsConfig( + target_density=0.4 * openmm.unit.gram / openmm.unit.milliliter, + scale_factor=1.3, + padding=3.0 * openmm.unit.angstrom, + ) + coords, box_vectors = smee.mm.generate_system_coords(tensor_sys, None, config) + coords = torch.tensor(coords.value_in_unit(openmm.unit.angstrom)) + box_vectors = torch.tensor(box_vectors.value_in_unit(openmm.unit.angstrom)) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + es_potential.parameters.requires_grad = True + + energy = compute_multipole_energy(tensor_sys, es_potential, coords.float(), box_vectors.float(), + polarization_type=polarization_type) + energy.backward() + expected_energy, omm_forces = _compute_openmm_energy(tensor_sys, coords, box_vectors, es_potential, + polarization_type=polarization_type, return_omm_forces=True) + print_debug_info_multipole(energy, expected_energy, tensor_sys, es_potential, omm_forces) + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +def print_debug_info_multipole(energy: torch.Tensor, + expected_energy: torch.Tensor, + tensor_sys: smee.TensorSystem, + es_potential: smee.TensorPotential, + omm_forces: list[openmm.Force]): + print(f"Energy\nSMEE {energy} OpenMM {expected_energy}") + + print(f"SMEE Parameters {es_potential.parameters}") + + for idx, topology in enumerate(tensor_sys.topologies): + print(f"SMEE Topology {idx}") + print(f"Assignment Matrix {topology.parameters[es_potential.type].assignment_matrix.to_dense()}") + + thole_force = None + for force in omm_forces: + if isinstance(force, TholeDipoleForce): + thole_force = force + break + + print(thole_force) + + +# ============================================================================= +# Water Model Tests - Testing SMEE with water force field from water_fitting +# ============================================================================= + +def _get_water_model_forcefield(): + """Load the water model force field.""" + import pathlib + water_ff_path = pathlib.Path(__file__).parent.parent.parent.parent / "water_fitting" / "water_model.offxml" + if not water_ff_path.exists(): + pytest.skip(f"Water model not found at {water_ff_path}") + return openff.toolkit.ForceField(str(water_ff_path), load_plugins=True) + + +def _get_water_dimer_coords(): + """Return coordinates for a water dimer in Angstroms. + + First water at origin, second water displaced along z-axis. + Standard water geometry: O-H bond length ~0.9572 Å, H-O-H angle ~104.5° + """ + # Water 1 (near origin) + o1 = [0.0, 0.0, 0.0] + h1a = [0.9572, 0.0, 0.0] + h1b = [-0.2399, 0.9270, 0.0] + + # Water 2 (displaced by ~2.8 Å along z-axis - typical H-bond distance) + z_offset = 2.8 + o2 = [0.0, 0.0, z_offset] + h2a = [0.9572, 0.0, z_offset] + h2b = [-0.2399, 0.9270, z_offset] + + coords = torch.tensor([o1, h1a, h1b, o2, h2a, h2b], dtype=torch.float64) + return coords + + +def _get_water_trimer_coords(): + """Return coordinates for a water trimer in Angstroms. + + Three water molecules in a triangular arrangement. + """ + import math + + # Water 1 at origin + o1 = [0.0, 0.0, 0.0] + h1a = [0.9572, 0.0, 0.0] + h1b = [-0.2399, 0.9270, 0.0] + + # Water 2 displaced along x + x_offset = 3.0 + o2 = [x_offset, 0.0, 0.0] + h2a = [x_offset + 0.9572, 0.0, 0.0] + h2b = [x_offset - 0.2399, 0.9270, 0.0] + + # Water 3 displaced to form triangle + angle = math.pi / 3 # 60 degrees + r = 3.0 + o3 = [r * math.cos(angle), r * math.sin(angle), 0.0] + h3a = [o3[0] + 0.9572, o3[1], 0.0] + h3b = [o3[0] - 0.2399, o3[1] + 0.9270, 0.0] + + coords = torch.tensor([o1, h1a, h1b, o2, h2a, h2b, o3, h3a, h3b], dtype=torch.float64) + return coords + + +@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"]) +def test_water_dimer_energy(polarization_type): + """Test water dimer energy matches OpenMM/TholeDipoleForce.""" + force_field = _get_water_model_forcefield() + + water1 = openff.toolkit.Molecule.from_smiles("O") + water2 = openff.toolkit.Molecule.from_smiles("O") + topology = openff.toolkit.Topology.from_molecules([water1, water2]) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + coords = _get_water_dimer_coords() + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Compute SMEE energy + energy = compute_multipole_energy( + system, es_potential, coords.float(), None, polarization_type=polarization_type + ) + + # Compute OpenMM reference energy + expected_energy = _compute_openmm_energy( + system, coords, None, es_potential, polarization_type=polarization_type + ) + + print(f"\nWater dimer ({polarization_type}):") + print(f" SMEE Energy: {energy.item():.6f} kcal/mol") + print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol") + print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol") + + assert torch.allclose(energy, expected_energy, atol=1e-4) + + +@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"]) +def test_water_dimer_forces(polarization_type): + """Test water dimer forces via autograd match finite difference.""" + force_field = _get_water_model_forcefield() + + water1 = openff.toolkit.Molecule.from_smiles("O") + water2 = openff.toolkit.Molecule.from_smiles("O") + topology = openff.toolkit.Topology.from_molecules([water1, water2]) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + coords = _get_water_dimer_coords().float() + coords.requires_grad = True + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Compute SMEE energy and forces + energy = compute_multipole_energy( + system, es_potential, coords, None, polarization_type=polarization_type + ) + energy.backward() + smee_forces = -coords.grad.detach() + + # Compute numerical forces via finite difference + # Use h=1e-2 which is optimal for float32 precision (smaller h causes precision loss) + h = 1e-2 + numerical_forces = torch.zeros_like(coords) + + for i in range(coords.shape[0]): + for j in range(3): + coords_plus = coords.detach().clone() + coords_plus[i, j] += h + e_plus = compute_multipole_energy( + system, es_potential, coords_plus, None, polarization_type=polarization_type + ) + + coords_minus = coords.detach().clone() + coords_minus[i, j] -= h + e_minus = compute_multipole_energy( + system, es_potential, coords_minus, None, polarization_type=polarization_type + ) + + numerical_forces[i, j] = -(e_plus - e_minus) / (2 * h) + + print(f"\nWater dimer forces ({polarization_type}):") + print(f" Max force difference (autograd vs numerical): {(smee_forces - numerical_forces).abs().max():.2e}") + + assert torch.allclose(smee_forces, numerical_forces, atol=1e-3) + + +@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"]) +def test_water_trimer_energy(polarization_type): + """Test water trimer energy matches OpenMM/TholeDipoleForce.""" + force_field = _get_water_model_forcefield() + + waters = [openff.toolkit.Molecule.from_smiles("O") for _ in range(3)] + topology = openff.toolkit.Topology.from_molecules(waters) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + coords = _get_water_trimer_coords() + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Compute SMEE energy + energy = compute_multipole_energy( + system, es_potential, coords.float(), None, polarization_type=polarization_type + ) + + # Compute OpenMM reference energy + expected_energy = _compute_openmm_energy( + system, coords, None, es_potential, polarization_type=polarization_type + ) + + print(f"\nWater trimer ({polarization_type}):") + print(f" SMEE Energy: {energy.item():.6f} kcal/mol") + print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol") + print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol") + + assert torch.allclose(energy, expected_energy, atol=1e-4) + + +def test_single_water_zero_energy(): + """Test that a single isolated water molecule has zero intermolecular energy.""" + force_field = _get_water_model_forcefield() + + water = openff.toolkit.Molecule.from_smiles("O") + topology = openff.toolkit.Topology.from_molecules([water]) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + + # Standard water geometry + coords = torch.tensor([ + [0.0, 0.0, 0.0], + [0.9572, 0.0, 0.0], + [-0.2399, 0.9270, 0.0], + ], dtype=torch.float32) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Compute SMEE energy + energy = compute_multipole_energy( + system, es_potential, coords, None, polarization_type="direct" + ) + + # Single molecule should have zero intermolecular energy + # (intramolecular terms are excluded via covalent maps) + print(f"\nSingle water molecule energy: {energy.item():.6e} kcal/mol") + + assert torch.allclose(energy, torch.tensor(0.0, dtype=energy.dtype), atol=1e-6) + + +@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"]) +def test_water_dimer_multiple_conformers(polarization_type): + """Test water dimer with multiple conformers (batched computation).""" + force_field = _get_water_model_forcefield() + + water1 = openff.toolkit.Molecule.from_smiles("O") + water2 = openff.toolkit.Molecule.from_smiles("O") + topology = openff.toolkit.Topology.from_molecules([water1, water2]) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + + # Create multiple conformers with different separations + base_coords = _get_water_dimer_coords() + conformers = [] + for z_offset in [2.5, 3.0, 3.5, 4.0]: + coords = base_coords.clone() + # Shift second water + coords[3:, 2] = z_offset + conformers.append(coords) + + coords_batch = torch.stack(conformers).float() + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Compute SMEE energies for batch + energies = compute_multipole_energy( + system, es_potential, coords_batch, None, polarization_type=polarization_type + ) + + # Compute OpenMM reference energies individually + expected_energies = torch.tensor([ + _compute_openmm_energy(system, coords, None, es_potential, polarization_type=polarization_type) + for coords in conformers + ]) + + print(f"\nWater dimer multiple conformers ({polarization_type}):") + for i, (e, exp_e) in enumerate(zip(energies, expected_energies)): + print(f" Conformer {i}: SMEE={e.item():.6f}, OpenMM={exp_e.item():.6f}, diff={abs(e.item()-exp_e.item()):.2e}") + + assert torch.allclose(energies, expected_energies, atol=1e-4) + + +@pytest.mark.parametrize("polarization_type", ["direct", "mutual", "extrapolated"]) +def test_water_dimer_axis_resolution(polarization_type): + """Test that axis atoms are correctly resolved from SMIRKS indices to topology indices.""" + force_field = _get_water_model_forcefield() + + water1 = openff.toolkit.Molecule.from_smiles("O") + water2 = openff.toolkit.Molecule.from_smiles("O") + topology = openff.toolkit.Topology.from_molecules([water1, water2]) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + param_map = tensor_top.parameters[es_potential.type] + assigned_params = param_map.assignment_matrix @ es_potential.parameters + + print(f"\nAxis resolution test ({polarization_type}):") + print("Checking that axis atoms are actual topology indices (not SMIRKS indices):") + + # For 6 atoms (2 waters), axis atoms should be 0-5 (topology indices) + # NOT 1-3 (SMIRKS indices) + for i in range(6): + axis_type = int(assigned_params[i, 13]) + atom_z = int(assigned_params[i, 14]) + atom_x = int(assigned_params[i, 15]) + print(f" Atom {i}: axisType={axis_type}, atomZ={atom_z}, atomX={atom_x}") + + # Verify axis atoms are valid topology indices + if atom_z >= 0: + assert atom_z < 6, f"atomZ={atom_z} exceeds topology size (6 atoms)" + if atom_x >= 0: + assert atom_x < 6, f"atomX={atom_x} exceeds topology size (6 atoms)" + + # Verify we can compute energy without errors + coords = _get_water_dimer_coords() + energy = compute_multipole_energy( + system, es_potential, coords.float(), None, polarization_type=polarization_type + ) + assert torch.isfinite(energy) + + +@pytest.mark.parametrize("n_waters", [2, 4, 8]) +def test_water_cluster_energy(n_waters): + """Test water clusters of various sizes.""" + force_field = _get_water_model_forcefield() + + waters = [openff.toolkit.Molecule.from_smiles("O") for _ in range(n_waters)] + topology = openff.toolkit.Topology.from_molecules(waters) + + interchange = openff.interchange.Interchange.from_smirnoff(force_field, topology) + tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) + + system = smee.TensorSystem([tensor_top], [1], is_periodic=False) + + # Generate coordinates for water cluster + import numpy as np + np.random.seed(42) + + coords_list = [] + spacing = 3.0 # Å between water molecules + + for i in range(n_waters): + # Place water molecules in a grid + x = (i % 2) * spacing + y = ((i // 2) % 2) * spacing + z = (i // 4) * spacing + + # Standard water geometry + coords_list.append([x, y, z]) # O + coords_list.append([x + 0.9572, y, z]) # H + coords_list.append([x - 0.2399, y + 0.9270, z]) # H + + coords = torch.tensor(coords_list, dtype=torch.float64) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Compute SMEE energy + energy = compute_multipole_energy( + system, es_potential, coords.float(), None, polarization_type="mutual" + ) + + # Compute OpenMM reference energy + expected_energy = _compute_openmm_energy( + system, coords, None, es_potential, polarization_type="mutual" + ) + + print(f"\nWater cluster (n={n_waters}):") + print(f" SMEE Energy: {energy.item():.6f} kcal/mol") + print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol") + print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol") + + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +# ============================================================================= +# Ammonia Tests - Testing ThreeFold axis type and different molecular geometry +# ============================================================================= + +def _get_ammonia_dimer_coords(): + """Return coordinates for an ammonia dimer in Angstroms. + + Ammonia geometry: N-H bond length ~1.012 Å, H-N-H angle ~106.7° + Tetrahedral-like geometry with lone pair. + """ + import math + + # Ammonia 1 at origin + # N at origin, 3 H atoms in tetrahedral-like arrangement + n1 = [0.0, 0.0, 0.0] + # H atoms roughly 1.012 Å from N, ~106.7° H-N-H angle + h1a = [0.9377, 0.0, -0.3816] + h1b = [-0.4689, 0.8121, -0.3816] + h1c = [-0.4689, -0.8121, -0.3816] + + # Ammonia 2 displaced along z-axis + z_offset = 3.5 + n2 = [0.0, 0.0, z_offset] + h2a = [0.9377, 0.0, z_offset - 0.3816] + h2b = [-0.4689, 0.8121, z_offset - 0.3816] + h2c = [-0.4689, -0.8121, z_offset - 0.3816] + + coords = torch.tensor([n1, h1a, h1b, h1c, n2, h2a, h2b, h2c], dtype=torch.float64) + return coords + + +def _get_ammonia_trimer_coords(): + """Return coordinates for an ammonia trimer in Angstroms.""" + import math + + # Ammonia 1 at origin + n1 = [0.0, 0.0, 0.0] + h1a = [0.9377, 0.0, -0.3816] + h1b = [-0.4689, 0.8121, -0.3816] + h1c = [-0.4689, -0.8121, -0.3816] + + # Ammonia 2 displaced along x + x_offset = 3.5 + n2 = [x_offset, 0.0, 0.0] + h2a = [x_offset + 0.9377, 0.0, -0.3816] + h2b = [x_offset - 0.4689, 0.8121, -0.3816] + h2c = [x_offset - 0.4689, -0.8121, -0.3816] + + # Ammonia 3 displaced to form triangle + angle = math.pi / 3 + r = 3.5 + n3 = [r * math.cos(angle), r * math.sin(angle), 0.0] + h3a = [n3[0] + 0.9377, n3[1], -0.3816] + h3b = [n3[0] - 0.4689, n3[1] + 0.8121, -0.3816] + h3c = [n3[0] - 0.4689, n3[1] - 0.8121, -0.3816] + + coords = torch.tensor([ + n1, h1a, h1b, h1c, + n2, h2a, h2b, h2c, + n3, h3a, h3b, h3c + ], dtype=torch.float64) + return coords + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"), + ] +) +def test_ammonia_dimer_energy(test_data_dir, forcefield_name, polarization_type): + """Test ammonia dimer energy matches OpenMM/TholeDipoleForce. + + Ammonia uses ThreeFold axis type for N and ZOnly for H atoms. + """ + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["N", "N"], + [1, 1], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = _get_ammonia_dimer_coords() + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords.float(), None, polarization_type=polarization_type + ) + + expected_energy = _compute_openmm_energy( + tensor_sys, coords, None, es_potential, polarization_type=polarization_type + ) + + print(f"\nAmmonia dimer ({forcefield_name}, {polarization_type}):") + print(f" SMEE Energy: {energy.item():.6f} kcal/mol") + print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol") + print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol") + + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"), + ] +) +def test_ammonia_dimer_forces(test_data_dir, forcefield_name, polarization_type): + """Test ammonia dimer forces via autograd match finite difference.""" + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["N", "N"], + [1, 1], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = _get_ammonia_dimer_coords().float() + coords.requires_grad = True + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords, None, polarization_type=polarization_type + ) + energy.backward() + smee_forces = -coords.grad.detach() + + # Use h=1e-2 for optimal float32 finite difference precision + h = 1e-2 + numerical_forces = torch.zeros_like(coords) + + for i in range(coords.shape[0]): + for j in range(3): + coords_plus = coords.detach().clone() + coords_plus[i, j] += h + e_plus = compute_multipole_energy( + tensor_sys, es_potential, coords_plus, None, polarization_type=polarization_type + ) + + coords_minus = coords.detach().clone() + coords_minus[i, j] -= h + e_minus = compute_multipole_energy( + tensor_sys, es_potential, coords_minus, None, polarization_type=polarization_type + ) + + numerical_forces[i, j] = -(e_plus - e_minus) / (2 * h) + + print(f"\nAmmonia dimer forces ({forcefield_name}, {polarization_type}):") + print(f" Max force difference (autograd vs numerical): {(smee_forces - numerical_forces).abs().max():.2e}") + + assert torch.allclose(smee_forces, numerical_forces, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-nonpolar-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"), + ] +) +def test_ammonia_trimer_energy(test_data_dir, forcefield_name, polarization_type): + """Test ammonia trimer energy matches OpenMM/TholeDipoleForce.""" + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["N"], + [3], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + coords = _get_ammonia_trimer_coords() + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords.float(), None, polarization_type=polarization_type + ) + + expected_energy = _compute_openmm_energy( + tensor_sys, coords, None, es_potential, polarization_type=polarization_type + ) + + print(f"\nAmmonia trimer ({forcefield_name}, {polarization_type}):") + print(f" SMEE Energy: {energy.item():.6f} kcal/mol") + print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol") + print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol") + + assert torch.allclose(energy, expected_energy, atol=1e-3) + + +@pytest.mark.parametrize( + "forcefield_name", + ["PHAST-H2CNO-nonpolar-2.0.0.offxml", "PHAST-H2CNO-2.0.0.offxml"] +) +def test_single_ammonia_zero_energy(test_data_dir, forcefield_name): + """Test that a single isolated ammonia molecule has zero intermolecular energy.""" + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["N"], + [1], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + # Standard ammonia geometry + coords = torch.tensor([ + [0.0, 0.0, 0.0], # N + [0.9377, 0.0, -0.3816], # H + [-0.4689, 0.8121, -0.3816], # H + [-0.4689, -0.8121, -0.3816], # H + ], dtype=torch.float32) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords, None, polarization_type="direct" + ) + + print(f"\nSingle ammonia molecule energy ({forcefield_name}): {energy.item():.6e} kcal/mol") + + assert torch.allclose(energy, torch.tensor(0.0, dtype=energy.dtype), atol=1e-6) + + +@pytest.mark.parametrize( + "forcefield_name,polarization_type", + [ + ("PHAST-H2CNO-2.0.0.offxml", "direct"), + ("PHAST-H2CNO-2.0.0.offxml", "mutual"), + ("PHAST-H2CNO-2.0.0.offxml", "extrapolated"), + ] +) +def test_ammonia_axis_types(test_data_dir, forcefield_name, polarization_type): + """Test that ammonia axis types are correctly assigned.""" + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["N"], + [2], + openff.toolkit.ForceField( + str(test_data_dir / forcefield_name), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + # Get assigned parameters for the single topology (4 atoms) + topology = tensor_sys.topologies[0] + param_map = topology.parameters[es_potential.type] + assigned_params = param_map.assignment_matrix @ es_potential.parameters + + n_atoms_per_mol = topology.n_particles + print(f"\nAmmonia axis types ({forcefield_name}, {polarization_type}):") + print(f" Atoms per molecule: {n_atoms_per_mol}") + + # Check axis types for each atom in the topology template + for i in range(n_atoms_per_mol): + axis_type = int(assigned_params[i, 13]) + atom_z = int(assigned_params[i, 14]) + atom_x = int(assigned_params[i, 15]) + atom_y = int(assigned_params[i, 16]) + + atom_type = "N" if i == 0 else "H" + print(f" Atom {i} ({atom_type}): axisType={axis_type}, atomZ={atom_z}, atomX={atom_x}, atomY={atom_y}") + + # Verify axis atoms are valid topology indices (or -1) + if atom_z >= 0: + assert atom_z < n_atoms_per_mol, f"atomZ={atom_z} exceeds topology size" + if atom_x >= 0: + assert atom_x < n_atoms_per_mol, f"atomX={atom_x} exceeds topology size" + if atom_y >= 0: + assert atom_y < n_atoms_per_mol, f"atomY={atom_y} exceeds topology size" + + # Verify we can compute energy without errors + coords = _get_ammonia_dimer_coords() + energy = compute_multipole_energy( + tensor_sys, es_potential, coords.float(), None, polarization_type=polarization_type + ) + assert torch.isfinite(energy) + + +@pytest.mark.parametrize("n_ammonia", [2, 4, 6]) +def test_ammonia_cluster_energy(test_data_dir, n_ammonia): + """Test ammonia clusters of various sizes.""" + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles( + ["N"], + [n_ammonia], + openff.toolkit.ForceField( + str(test_data_dir / "PHAST-H2CNO-2.0.0.offxml"), load_plugins=True + ), + ) + tensor_sys.is_periodic = False + + # Generate coordinates for ammonia cluster + coords_list = [] + spacing = 4.0 # Å between molecules + + for i in range(n_ammonia): + # Place ammonia molecules in a grid + x = (i % 2) * spacing + y = ((i // 2) % 2) * spacing + z = (i // 4) * spacing + + # Standard ammonia geometry + coords_list.append([x, y, z]) # N + coords_list.append([x + 0.9377, y, z - 0.3816]) # H + coords_list.append([x - 0.4689, y + 0.8121, z - 0.3816]) # H + coords_list.append([x - 0.4689, y - 0.8121, z - 0.3816]) # H + + coords = torch.tensor(coords_list, dtype=torch.float64) + + es_potential = tensor_ff.potentials_by_type["Electrostatics"] + + energy = compute_multipole_energy( + tensor_sys, es_potential, coords.float(), None, polarization_type="mutual" + ) + + expected_energy = _compute_openmm_energy( + tensor_sys, coords, None, es_potential, polarization_type="mutual" + ) + + print(f"\nAmmonia cluster (n={n_ammonia}):") + print(f" SMEE Energy: {energy.item():.6f} kcal/mol") + print(f" OpenMM Energy: {expected_energy.item():.6f} kcal/mol") + print(f" Difference: {abs(energy.item() - expected_energy.item()):.2e} kcal/mol") + + assert torch.allclose(energy, expected_energy, atol=1e-3)