diff --git a/src/mlip/data/configs.py b/src/mlip/data/configs.py index f9a692f..d746278 100644 --- a/src/mlip/data/configs.py +++ b/src/mlip/data/configs.py @@ -126,6 +126,7 @@ class GraphDatasetBuilderConfig(pydantic.BaseModel): to ``True``, the models assume ``"zero"`` atomic energies as can be set in the model hyperparameters. avg_num_neighbors: The pre-computed average number of neighbors. + avg_num_nodes: The pre-computed average number of nodes. avg_r_min_angstrom: The pre-computed average minimum distance between nodes. """ @@ -140,4 +141,5 @@ class GraphDatasetBuilderConfig(pydantic.BaseModel): use_formation_energies: bool = False avg_num_neighbors: Optional[float] = None + avg_num_nodes: Optional[float] = None avg_r_min_angstrom: Optional[float] = None diff --git a/src/mlip/data/dataset_info.py b/src/mlip/data/dataset_info.py index c261a86..653217d 100644 --- a/src/mlip/data/dataset_info.py +++ b/src/mlip/data/dataset_info.py @@ -17,6 +17,7 @@ from typing import Optional import jraph +import numpy as np import pydantic from ase import Atom @@ -40,6 +41,7 @@ class DatasetInfo(pydantic.BaseModel): cutoff_distance_angstrom: The graph cutoff distance that was used in the dataset in Angstrom. avg_num_neighbors: The mean number of neighbors an atom has in the dataset. + avg_num_nodes: The mean number of nodes a structure has in the dataset. avg_r_min_angstrom: The mean minimum edge distance for a structure in the dataset. scaling_mean: The mean used for the rescaling of the dataset values, the @@ -51,6 +53,7 @@ class DatasetInfo(pydantic.BaseModel): atomic_energies_map: dict[int, float] cutoff_distance_angstrom: float avg_num_neighbors: float = 1.0 + avg_num_nodes: float = 1.0 avg_r_min_angstrom: Optional[float] = None scaling_mean: float = 0.0 scaling_stdev: float = 1.0 @@ -62,6 +65,7 @@ def __str__(self): return ( f"Atomic Energies: {atomic_energies_map_with_symbols}, " f"Avg. num. neighbors: {self.avg_num_neighbors:.2f}, " + f"Avg. num. nodes: {self.avg_num_nodes:.2f}, " f"Avg. r_min: {self.avg_r_min_angstrom:.2f}, " f"Graph cutoff distance: {self.cutoff_distance_angstrom}" ) @@ -72,6 +76,7 @@ def compute_dataset_info_from_graphs( cutoff_distance_angstrom: float, z_table: AtomicNumberTable, avg_num_neighbors: Optional[float] = None, + avg_num_nodes: Optional[float] = None, avg_r_min_angstrom: Optional[float] = None, ) -> DatasetInfo: """Computes the dataset info from graphs, typically training set graphs. @@ -84,6 +89,8 @@ def compute_dataset_info_from_graphs( map keys. avg_num_neighbors: The optionally pre-computed average number of neighbors. If provided, we skip recomputing this. + avg_num_nodes: The optionally pre-computed average number of nodes. If + provided, we skip recomputing this. avg_r_min_angstrom: The optionally pre-computed average miminum radius. If provided, we skip recomputing this. 
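The new ``avg_num_nodes`` field follows the same pattern as ``avg_num_neighbors``: it can be pre-computed once and passed through ``GraphDatasetBuilderConfig`` so the builder skips the recomputation in the hunk below. A minimal sketch of such a pre-computation (the helper name is illustrative and not part of this diff):

    import numpy as np
    import jraph

    def precompute_avg_num_nodes(graphs: list[jraph.GraphsTuple]) -> float:
        # Mean number of atoms per structure, i.e. the quantity the hunk below
        # computes when the config leaves avg_num_nodes unset.
        return float(np.mean([int(n) for g in graphs for n in np.asarray(g.n_node)]))
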
@@ -99,6 +106,10 @@ def compute_dataset_info_from_graphs( logger.debug("Computing average number of neighbors...") avg_num_neighbors = compute_avg_num_neighbors(graphs) logger.debug("Average number of neighbors: %.1f", avg_num_neighbors) + if avg_num_nodes is None: + logger.debug("Computing average number of nodes...") + avg_num_nodes = np.mean([i.item() for g in graphs for i in g.n_node]) + logger.debug("Average number of nodes: %.1f", avg_num_nodes) if avg_r_min_angstrom is None: logger.debug("Computing average min neighbor distance...") avg_r_min_angstrom = compute_avg_min_neighbor_distance(graphs) @@ -119,6 +130,7 @@ def compute_dataset_info_from_graphs( atomic_energies_map=atomic_energies_map, cutoff_distance_angstrom=cutoff_distance_angstrom, avg_num_neighbors=avg_num_neighbors, + avg_num_nodes=avg_num_nodes, avg_r_min_angstrom=avg_r_min_angstrom, scaling_mean=0.0, scaling_stdev=1.0, diff --git a/src/mlip/data/graph_dataset_builder.py b/src/mlip/data/graph_dataset_builder.py index bdff854..2091ce2 100644 --- a/src/mlip/data/graph_dataset_builder.py +++ b/src/mlip/data/graph_dataset_builder.py @@ -158,6 +158,7 @@ def prepare_datasets(self) -> None: self._config.graph_cutoff_angstrom, z_table, self._config.avg_num_neighbors, + self._config.avg_num_nodes, self._config.avg_r_min_angstrom, ) diff --git a/src/mlip/models/__init__.py b/src/mlip/models/__init__.py index 0474a43..442e40a 100644 --- a/src/mlip/models/__init__.py +++ b/src/mlip/models/__init__.py @@ -17,3 +17,6 @@ from mlip.models.nequip.models import Nequip from mlip.models.predictor import ForceFieldPredictor from mlip.models.visnet.models import Visnet +from mlip.models.liten.models import Liten +from mlip.models.so3krates.models import So3krates +from mlip.models.equiformer_v2.models import EquiformerV2 diff --git a/src/mlip/models/cutoff.py b/src/mlip/models/cutoff.py new file mode 100644 index 0000000..0ba4b10 --- /dev/null +++ b/src/mlip/models/cutoff.py @@ -0,0 +1,101 @@ +# Copyright 2025 Zhongguancun Academy + +""" +This module contains all cutoff / radial envelope functions seen in all models. +It can be used to refactor Mace, Visnet and Nequip. 
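+
+A minimal usage sketch (illustrative; none of the modules below create parameters,
+so an empty variable collection can be passed to ``apply``):
+
+    import jax.numpy as jnp
+
+    lengths = jnp.array([0.5, 3.0, 6.0])
+    envelope_cls = parse_cutoff("cosine")              # -> CosineCutoff
+    env = envelope_cls(cutoff=5.0).apply({}, lengths)  # zero beyond the cutoff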
+""" + +from enum import Enum + +import jax +import jax.numpy as jnp +import flax.linen as nn +import e3nn_jax as e3nn + + +class SoftCutoff(nn.Module): + """Soft envelope radial envelope function.""" + cutoff: float + arg_multiplicator: float = 2.0 + value_at_origin: float = 1.2 + + @nn.compact + def __call__(self, length): + return e3nn.soft_envelope( + length, + self.cutoff, + arg_multiplicator=self.arg_multiplicator, + value_at_origin=self.value_at_origin, + ) + + +class PolynomialCutoff(nn.Module): + """Polynomial radial envelope function from the MACE torch version.""" + cutoff: float + p: int = 5 + + @nn.compact + def __call__(self, length: jax.Array): + a = - (self.p + 1.0) * (self.p + 2.0) / 2.0 + b = self.p * (self.p + 2.0) + c = - self.p * (self.p + 1.0) / 2 + + x_norm = length / self.cutoff + envelope = 1.0 + jnp.pow(x_norm, self.p) * ( + a + x_norm * (b + x_norm * c) + ) + return envelope * (length < self.cutoff) + + +class CosineCutoff(nn.Module): + """Behler-style cosine cutoff function.""" + cutoff: float + + @nn.compact + def __call__(self, length: jax.Array) -> jax.Array: + cutoffs = 0.5 * (jnp.cos(length * jnp.pi / self.cutoff) + 1.0) + return cutoffs * (length < self.cutoff) + + +class PhysCutoff(nn.Module): + """Cutoff function used in PhysNet.""" + cutoff: float + + @nn.compact + def __call__(self, length: jax.Array) -> jax.Array: + x_norm = length / self.cutoff + cutoffs = 1 - 6 * x_norm ** 5 + 15 * x_norm ** 4 - 10 * x_norm ** 3 + return cutoffs * (length < self.cutoff) + + +class ExponentialCutoff(nn.Module): + """Exponential cutoff function used in SpookyNet.""" + cutoff: float + + @nn.compact + def __call__(self, length: jax.Array) -> jax.Array: + # TODO(bhcao): Check if this is numerically stable. + cutoffs = jnp.exp(-length ** 2 / ((self.cutoff - length) * (self.cutoff + length))) + return cutoffs * (length < self.cutoff) + +# --- Options --- + + +class CutoffFunction(Enum): + POLYNOMIAL = "polynomial" + SOFT = "soft" + COSINE = "cosine" + PHYS = "phys" + EXPONENTIAL = "exponential" + + +def parse_cutoff(cutoff: CutoffFunction | str) -> type[nn.Module]: + cutoff_function_map = { + CutoffFunction.POLYNOMIAL: PolynomialCutoff, + CutoffFunction.SOFT: SoftCutoff, + CutoffFunction.COSINE: CosineCutoff, + CutoffFunction.PHYS: PhysCutoff, + CutoffFunction.EXPONENTIAL: ExponentialCutoff, + } + assert set(CutoffFunction) == set(cutoff_function_map.keys()) + return cutoff_function_map[CutoffFunction(cutoff)] diff --git a/src/mlip/models/equiformer_v2/activations.py b/src/mlip/models/equiformer_v2/activations.py new file mode 100644 index 0000000..b26d20e --- /dev/null +++ b/src/mlip/models/equiformer_v2/activations.py @@ -0,0 +1,95 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. 
+ +import jax +import jax.numpy as jnp +import flax.linen as nn + +from mlip.models.equiformer_v2.transform import get_s2grid_mats +from mlip.models.equiformer_v2.utils import get_expand_index + + +class SmoothLeakyReLU(nn.Module): + """Smooth Leaky ReLU activation.""" + negative_slope: float = 0.2 + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + x1 = ((1 + self.negative_slope) / 2) * x + x2 = ((1 - self.negative_slope) / 2) * x * (2 * nn.sigmoid(x) - 1) + return x1 + x2 + +# --- Vector activation functions --- + + +class GateActivation(nn.Module): + """Apply gate for vector and silu for scalar.""" + lmax: int + mmax: int + num_channels: int + m_prime: bool = False + + @nn.compact + def __call__(self, gating_scalars: jax.Array, input_tensors: jax.Array) -> jax.Array: + """ + `gating_scalars`: shape [N, lmax * num_channels] + `input_tensors`: shape [N, (lmax + 1) ** 2, num_channels] + """ + expand_index = get_expand_index( + self.lmax, self.mmax, vector_only=True, m_prime=self.m_prime + ) + + gating_scalars = nn.sigmoid(gating_scalars) + gating_scalars = gating_scalars.reshape( + gating_scalars.shape[0], self.lmax, self.num_channels + )[:, expand_index] + + input_tensors_scalars = nn.silu(input_tensors[:, 0:1]) + input_tensors_vectors = input_tensors[:, 1:] * gating_scalars + output_tensors = jnp.concat( + (input_tensors_scalars, input_tensors_vectors), axis=1 + ) + + return output_tensors + + +class S2Activation(nn.Module): + """Apply silu on sphere function.""" + lmax: int + mmax: int + resolution: int + m_prime: bool = False + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + so3_grid = get_s2grid_mats( + self.lmax, self.mmax, resolution=self.resolution, m_prime=self.m_prime + ) + + x_grid = so3_grid.to_grid(inputs) + x_grid = nn.silu(x_grid) + outputs = so3_grid.from_grid(x_grid) + return outputs + + +class SeparableS2Activation(nn.Module): + """Apply silu on sphere function for vector and silu directly for scalar.""" + lmax: int + mmax: int + resolution: int + m_prime: bool = False + + @nn.compact + def __call__(self, input_scalars: jax.Array, input_tensors: jax.Array) -> jax.Array: + output_scalars = nn.silu(input_scalars) + output_tensors = S2Activation( + self.lmax, self.mmax, self.resolution, self.m_prime + )(input_tensors) + outputs = jnp.concat( + ( + output_scalars[:, None], + output_tensors[:, 1 : output_tensors.shape[1]], + ), + axis=1, + ) + return outputs diff --git a/src/mlip/models/equiformer_v2/blocks.py b/src/mlip/models/equiformer_v2/blocks.py new file mode 100644 index 0000000..c55d6bd --- /dev/null +++ b/src/mlip/models/equiformer_v2/blocks.py @@ -0,0 +1,286 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. + +import jax +import jax.numpy as jnp +import flax.linen as nn +from flax.linen import initializers +from flax.typing import PRNGKey + +from mlip.models.options import parse_activation +from mlip.models.equiformer_v2.utils import get_expand_index, get_mapping_coeffs +from mlip.models.equiformer_v2.transform import WignerMats + + +class MLP(nn.Module): + """MLP with layer norm.""" + + features: tuple[int, ...] 
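+    # Output width of each Dense layer; LayerNorm and the activation are applied
+    # between consecutive layers but not after the final one (see __call__ below).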
+ activation: str = 'silu' + use_bias: bool = True + use_layer_norm: bool = True + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + x = inputs + for i, feat in enumerate(self.features): + x = nn.Dense(feat, use_bias=self.use_bias)(x) + if i != len(self.features) - 1: + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = parse_activation(self.activation)(x) + return x + + +class SO3Linear(nn.Module): + """EquiformerV2 linear layer.""" + + lmax: int + features: int + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + expand_index = get_expand_index(self.lmax) + + kernel = self.param( + 'kernel', + initializers.lecun_normal(), + ((self.lmax + 1), inputs.shape[-1], self.features), + ) + bias = self.param('bias', initializers.zeros_init(), (self.features,)) + + kernel_expanded = kernel[expand_index] # [(L_max + 1) ** 2, C_in, C_out] + out = jnp.einsum( + "bmi, mio -> bmo", inputs, kernel_expanded + ) # [N, (L_max + 1) ** 2, C_out] + out = out.at[:, 0:1, :].add(bias.reshape(1, 1, self.features)) + + return out + + +class SO2mConvolution(nn.Module): + """ + SO(2) Conv: Perform an SO(2) convolution on features corresponding to +- m + + Args: + m (int): Order of the spherical harmonic coefficients + channels (int): Number of output channels used during the SO(2) conv + lmax (int): Maximum degree of the spherical harmonics + """ + + m: int + channels: int + lmax: int + + @nn.compact + def __call__(self, feats_m: jax.Array) -> tuple[jax.Array, jax.Array]: + num_edges = len(feats_m) + + out_channels = 2 * (self.lmax - self.m + 1) * self.channels + + feats_m = nn.Dense(out_channels, use_bias=False)(feats_m) + feats_r, feats_i = jnp.split(feats_m, 2, axis=2) + feats_m_r = feats_r[:, 0] - feats_i[:, 1] + feats_m_i = feats_r[:, 1] + feats_i[:, 0] + + return ( + feats_m_r.reshape(num_edges, -1, self.channels), + feats_m_i.reshape(num_edges, -1, self.channels), + ) + + +class SO2Convolution(nn.Module): + """ + SO(2) Block: Perform SO(2) convolutions for all m (orders) + + Args: + lmax (int): Maximum degree of the spherical harmonics + mmax (int): Maximum order of the spherical harmonics + output_channels (int): Number of output channels used during the SO(2) conv + internal_weights (bool): If True, not using radial function to multiply inputs features + edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, + [hidden_channels, hidden_channels]. + extra_scalar_channels (int): If not None, return `out` and `extra_scalar_features`. + """ + lmax: int + mmax: int + output_channels: int + internal_weights: bool = True + edge_channels_list: tuple[int, ...] | None = None + extra_scalar_channels: int | None = None + + @nn.compact + def __call__( + self, + edge_feats: jax.Array, # in m primary order + edge_embeds: jax.Array | None, + ) -> jax.Array | tuple[jax.Array, jax.Array]: + num_edges = len(edge_embeds) + + m_size = get_mapping_coeffs(self.lmax, self.mmax).m_size + + # radial function + if not self.internal_weights: + assert self.edge_channels_list is not None, "`edge_channels_list` must be provided." + assert edge_embeds is not None, "`edge_embeds` must be provided." 
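+            # The radial MLP emits one scalar weight per (m coefficient, channel) pair;
+            # slices of this vector rescale each m block of the edge features below.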
+ edge_embeds = MLP( + self.edge_channels_list + (edge_feats.shape[-1] * sum(m_size),) + )(edge_embeds) + + m0_out_channels = (self.lmax + 1) * self.output_channels + if self.extra_scalar_channels is not None: + m0_out_channels += self.extra_scalar_channels + + # Compute m=0 coefficients separately since they only have real values (no imaginary) + # `feats` means `egde_feats` in the following code for simplicity + feats_0 = edge_feats[:, :m_size[0]].reshape(num_edges, -1) + if not self.internal_weights: + feats_0 *= edge_embeds[:, :feats_0.shape[-1]] + offset_rad = feats_0.shape[-1] + + feats_0 = nn.Dense(m0_out_channels)(feats_0) + + if self.extra_scalar_channels is not None: + feats_extra, feats_0 = jnp.split( + feats_0, [self.extra_scalar_channels], axis=-1 + ) + + # x[:, 0 : self.mappingReduced.m_size[0]] = feats_0 + feats_out = [feats_0.reshape(num_edges, -1, self.output_channels)] + + # Compute the values for the m > 0 coefficients + offset = m_size[0] + for m in range(1, self.mmax + 1): + # Get the m order coefficients, shape: [N, 2, m_size[m] * sphere_channels] + feats_m = edge_feats[:, offset : 2*m_size[m]+offset].reshape(num_edges, 2, -1) + offset += 2 * m_size[m] + + if not self.internal_weights: + feats_m *= edge_embeds[:, None, offset_rad : feats_m.shape[-1] + offset_rad] + offset_rad += feats_m.shape[-1] + + # x[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = feats_m + feats_out.extend(SO2mConvolution(m, self.output_channels, self.lmax)(feats_m)) + + edge_feats = jnp.concat(feats_out, axis=1) + + if self.extra_scalar_channels is not None: + return edge_feats, feats_extra + return edge_feats + + +class EdgeDegreeEmbedding(nn.Module): + """ + + Args: + lmax (int): Maximum degree of the spherical harmonics + mmax (int): Maximum order of the spherical harmonics + sphere_channels (int): Number of spherical channels + edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, + [hidden_channels, hidden_channels]. The last one will be used as hidden size when + `use_atom_edge_embedding` is `True`. + use_atom_edge_embedding (bool): Whether to use atomic embedding along with relative distance + for edge scalar features + num_species (int): Maximum number of atomic numbers + rescale_factor (float): Rescale the sum aggregation + """ + lmax: int + mmax: int + sphere_channels: int + edge_channels_list: tuple[int, ...] + use_atom_edge_embedding: bool = False + num_species: int | None = None + rescale_factor: float = 5.0 + + @nn.compact + def __call__( + self, + node_species: jax.Array, + edge_embeds: jax.Array, + senders: jax.Array, + receivers: jax.Array, + wigner_mats: WignerMats, + ) -> jax.Array: + num_nodes = node_species.shape[0] + + mapping_coeffs = get_mapping_coeffs(self.lmax, self.mmax) + m_size_0 = mapping_coeffs.m_size[0] + m_size_pad = mapping_coeffs.num_coefficients - m_size_0 + + if self.use_atom_edge_embedding: + assert self.num_species is not None, "num_species must be provided" + # I have changed the layer order and merged the two embedding layers into one. + node_embeds = nn.Embed( + self.num_species, 2 * self.edge_channels_list[-1], + embedding_init=initializers.normal(stddev=0.001), # Why? 
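+                # stddev=0.001 starts these embeddings near zero, so the atomic
+                # contribution to the edge scalar features is initially small.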
+ )(node_species) + senders_embeds, receivers_embeds = jnp.split(node_embeds, 2, axis=-1) + edge_embeds = jnp.concat( + (edge_embeds, senders_embeds[senders], receivers_embeds[receivers]), axis=-1 + ) + + feats_m_0 = MLP( + self.edge_channels_list + (m_size_0 * self.sphere_channels,) + )(edge_embeds).reshape(-1, m_size_0, self.sphere_channels) + + feats_m_pad = jnp.zeros( + (edge_embeds.shape[0], m_size_pad, self.sphere_channels), dtype=feats_m_0.dtype + ) + # edge_feats: [n_edges, (lmax + 1) ^ 2, num_channels], m primary + edge_feats = jnp.concat((feats_m_0, feats_m_pad), axis=1) + + edge_feats = wigner_mats.rotate_inv(edge_feats) + # NOTE: In eSEN, there is a edge_envelope, however, it seems EquiFormerV2 does not use a + # edge_envelope at all. + + # Compute the sum of the incoming neighboring messages for each target node + node_feats = jax.ops.segment_sum(edge_feats, receivers, num_nodes) / self.rescale_factor + + return node_feats + + +def _drop_path(inputs: jax.Array, rate: float, rng: PRNGKey) -> jax.Array: + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + keep_prob = 1 - rate + # work with diff dim tensors, not just 2D ConvNets + shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) + random_tensor = keep_prob + jax.random.uniform(rng, shape, dtype=inputs.dtype) + random_tensor = jnp.floor(random_tensor) # binarize + output = (inputs / keep_prob) * random_tensor + return output + + +class GraphDropPath(nn.Module): + """Consider batch for graph inputs when dropping paths.""" + + rate: float + deterministic: bool = True + + @nn.compact + def __call__( + self, + inputs: jax.Array, + n_node: jax.Array, + ) -> jax.Array: + if (self.rate == 0.0) or self.deterministic: + return inputs + + # Prevent gradient NaNs in 1.0 edge-case. + if self.rate == 1.0: + return jnp.zeros_like(inputs) + + rng = self.make_rng('dropout') + + batch_size = len(n_node) + # work with diff dim tensors, not just 2D ConvNets + shape = (batch_size,) + (1,) * (inputs.ndim - 1) + ones = jnp.ones(shape, dtype=inputs.dtype) + drop = _drop_path(ones, self.rate, rng) + + # create pyg batch from n_node + output_size = n_node.shape[0] + num_elements = inputs.shape[0] + batch = jnp.repeat(jnp.arange(output_size), n_node, total_repeat_length=num_elements) + + out = inputs * drop[batch] + return out diff --git a/src/mlip/models/equiformer_v2/config.py b/src/mlip/models/equiformer_v2/config.py new file mode 100644 index 0000000..de71433 --- /dev/null +++ b/src/mlip/models/equiformer_v2/config.py @@ -0,0 +1,87 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. + +import pydantic + +from mlip.models.equiformer_v2.utils import AttnActType, FeedForwardType +from mlip.typing import PositiveInt, NonNegativeInt +from mlip.models.equiformer_v2.layernorm import LayerNormType + + +class EquiformerV2Config(pydantic.BaseModel): + """The configuration / hyperparameters of the EquiformerV2 model. + + Attributes: + num_layers: Number of EquiformerV2 layers. Default is 9. + lmax: Maximum degree of the spherical harmonics (1 to 10). + mmax: Maximum order of the spherical harmonics (0 to lmax). + sphere_channels: Number of spherical channels. Default is 128. + num_edge_channels: Number of channels for the edge invariant features. Default is 128. + atom_edge_embedding: Whether to use / share atomic embedding along with relative distance. 
+ Options are "none", "isolated" (default) and "shared". + num_rbf: Number of basis functions used in the embedding block. Default is 600. + attn_hidden_channels: Number of hidden channels used during SO(2) graph attention. Use 64 + or 96 (not necessarily). + num_heads: Number of heads in the attention block. Default is 8. + attn_alpha_channels: Number of channels for alpha vector in each attention head. + attn_value_channels: Number of channels for value vector in each attention head. + ffn_hidden_channels: Number of hidden channels used during feedforward network. + norm_type: Type of normalization layer. Options are "layer_norm", "layer_norm_sh" (default) + and "rms_norm_sh". + grid_resolution: Resolution of SO3Grid used in Activation. Examples are 18, 16, 14, None + (default, decided automatically). + use_m_share_rad: Whether all m components within a type-L vector of one channel share + radial function weights. + use_attn_renorm: Whether to re-normalize attention weights. + ff_type: Type of feedforward network. Options are "gate", "grid", "grid_sep" (default), + "s2" and "s2_sep". See :class:`~mlip.models.equiformer_v2.utils.FeedForwardType` + for its corresponding options in original EquiformerV2 repo. + attn_act_type: Type of activation function used in the attention block. Options are "gate", + "s2_sep" (default) and "s2". See + :class:`~mlip.models.equiformer_v2.utils.AttnActType` for its corresponding + options in original EquiformerV2 repo. + alpha_drop: Dropout rate for attention weights. Use 0.0 or 0.1 (default). + drop_path_rate: Graph drop path rate. Use 0.0 or 0.05 (default). + avg_num_nodes: The mean number of atoms per graph. If `None`, use the value from the + dataset info. Default is value from IS2RE (100k). + avg_num_neighbors: The mean number of neighbors for atoms. If `None`, use the value + from the dataset info. Default is value from IS2RE (100k). It is + used to rescale messages by this value. + atomic_energies: How to treat the atomic energies. If set to ``None`` (default) or the + string ``"average"``, then the average atomic energies stored in the + dataset info are used. It can also be set to the string ``"zero"`` which + means not to use any atomic energies in the model. Lastly, one can also + pass an atomic energies dictionary via this parameter different from the + one in the dataset info, that is used. + num_species: The number of elements (atomic species descriptors) allowed. + If ``None`` (default), infer the value from the atomic energies + map in the dataset info. + direct_force: Whether to predict forces using a direct force head instead of using + auto-grad. Default is `False`. 
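+
+    Example (illustrative values only, smaller than the defaults below)::
+
+        config = EquiformerV2Config(num_layers=2, lmax=4, mmax=2, sphere_channels=64)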
+ """ + + num_layers: PositiveInt = 9 + lmax: PositiveInt = 6 + mmax: NonNegativeInt = 2 + sphere_channels: PositiveInt = 128 + num_edge_channels: PositiveInt = 128 + atom_edge_embedding: str = 'isolated' + num_rbf: PositiveInt = 600 + attn_hidden_channels: PositiveInt = 64 + num_heads: PositiveInt = 8 + attn_alpha_channels: PositiveInt = 64 + attn_value_channels: PositiveInt = 16 + ffn_hidden_channels: PositiveInt = 128 + norm_type: LayerNormType = LayerNormType.LAYER_NORM_SH + grid_resolution: PositiveInt | None = 18 # Original default: None + use_m_share_rad: bool = False + use_attn_renorm: bool = True + ff_type: FeedForwardType = FeedForwardType.GRID_SEP + attn_act_type: AttnActType = AttnActType.S2_SEP + alpha_drop: float = 0.1 + drop_path_rate: float = 0.05 + avg_num_neighbors: float | None = None # Original: 23.395238876342773 (OC20) + avg_num_nodes: float | None = None # Original: 77.81317 (OC20) + atomic_energies: str | dict[int, float] | None = None + num_species: PositiveInt | None = None + direct_force: bool = False diff --git a/src/mlip/models/equiformer_v2/layernorm.py b/src/mlip/models/equiformer_v2/layernorm.py new file mode 100644 index 0000000..d142f0e --- /dev/null +++ b/src/mlip/models/equiformer_v2/layernorm.py @@ -0,0 +1,254 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. + +from collections.abc import Callable +from enum import Enum + +import jax +import jax.numpy as jnp +import flax.linen as nn +from flax.linen import initializers +from flax.typing import Dtype + +from mlip.models.equiformer_v2.utils import get_expand_index + + +class LayerNormArray(nn.Module): + lmax: int + eps: float = 1e-5 + affine: bool = True + normalization: str = "component" + + @nn.compact + def __call__(self, node_input: jax.Array) -> jax.Array: + """ + Assume input is of shape [N, sphere_basis, C] + """ + + if self.affine: + affine_weight = self.param( + 'affine_weight', initializers.ones, (self.lmax + 1, node_input.shape[-1]) + ) + + out = [] + + for lval in range(self.lmax + 1): + start_idx = lval**2 + length = 2 * lval + 1 + + feature = node_input[:, start_idx : start_idx+length] + + # For scalars, first compute and subtract the mean + if lval == 0: + feature -= jnp.mean(feature, axis=2, keepdims=True) + + # Then compute the rescaling factor (norm of each feature vector) + # Rescaling of the norms themselves based on the option "normalization" + if self.normalization == "norm": + feature_norm = jnp.sum(jnp.pow(feature, 2), axis=1, keepdims=True) # [N, 1, C] + elif self.normalization == "component": + feature_norm = jnp.mean(jnp.pow(feature, 2), axis=1, keepdims=True) # [N, 1, C] + else: + raise ValueError(f"Unknown normalization option: {self.normalization}") + + feature_norm = jnp.mean(feature_norm, axis=2, keepdims=True) # [N, 1, 1] + feature_norm = jnp.pow(feature_norm + self.eps, -0.5) + + if self.affine: + feature_norm *= affine_weight[None, lval:lval+1] # [N, 1, C] + + feature *= feature_norm + + if self.affine and lval == 0: + feature += self.param( + 'affine_bias', initializers.zeros, (node_input.shape[-1],) + )[None, None] + + out.append(feature) + + out = jnp.concat(out, axis=1) + return out + + +def _get_balance_degree_weight(lmax: int, dtype: Dtype, skip_l0: bool = False) -> jax.Array: + start = 1 if skip_l0 else 0 + + balance_degree_weight = jnp.zeros(((lmax + 1) ** 2 - start, 1), dtype=dtype) + for lval in range(start, lmax + 1): + start_idx = lval**2 - 
start + length = 2 * lval + 1 + balance_degree_weight = balance_degree_weight.at[ + start_idx : (start_idx + length), : + ].set(1.0 / length) + + return balance_degree_weight / (lmax + 1 - start) + + +class LayerNormArraySphericalHarmonics(nn.Module): + """ + 1. Normalize over L = 0. + 2. Normalize across all m components from degrees L > 0. + 3. Do not normalize separately for different L (L > 0). + """ + + lmax: int + eps: float = 1e-5 + affine: bool = True + normalization: str = "component" + std_balance_degrees: bool = True + + @nn.compact + def __call__(self, node_input: jax.Array) -> jax.Array: + """ + Assume input is of shape [N, sphere_basis, C] + """ + + # for L = 0 + feature = node_input[:, :1] + feature = nn.LayerNorm(self.eps, use_bias=self.affine, use_scale=self.affine)(feature) + + if self.lmax == 0: + return feature + + if self.affine: + affine_weight = self.param( + 'affine_weight', initializers.ones, (self.lmax, node_input.shape[-1]) + ) + + out = [feature] + + # for L > 0 + feature = node_input[:, 1:] + + # Then compute the rescaling factor (norm of each feature vector) + # Rescaling of the norms themselves based on the option "normalization" + if self.normalization == "norm": + assert not self.std_balance_degrees + feature_norm = jnp.sum( + jnp.pow(feature, 2), axis=1, keepdims=True + ) # [N, 1, C] + elif self.normalization == "component": + if self.std_balance_degrees: + balance_degree_weight = _get_balance_degree_weight( + self.lmax, node_input.dtype, skip_l0=True + ) + # [N, (L_max + 1)**2 - 1, C], without L = 0 + feature_norm = jnp.einsum( + "nic, ia -> nac", jnp.pow(feature, 2), balance_degree_weight + ) # [N, 1, C] + else: + feature_norm = jnp.mean( + jnp.pow(feature, 2), axis=1, keepdims=True + ) # [N, 1, C] + else: + raise ValueError(f"Unknown normalization option: {self.normalization}") + + feature_norm = jnp.mean(feature_norm, axis=2, keepdims=True) # [N, 1, 1] + feature_norm = jnp.pow(feature_norm + self.eps, -0.5) + + for lval in range(1, self.lmax + 1): + start_idx = lval**2 + length = 2 * lval + 1 + # [N, (2L + 1), C] + feature = node_input[:, start_idx : start_idx+length] + feature_scale = feature_norm + if self.affine: + feature_scale *= affine_weight[None, lval-1:lval] # [N, 1, C] + out.append(feature * feature_scale) + + out = jnp.concat(out, axis=1) + return out + + +class RMSNormArraySphericalHarmonicsV2(nn.Module): + """ + 1. Normalize across all m components from degrees L >= 0. + 2. Expand weights and multiply with normalized feature to prevent slicing and concatenation. 
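+
+    With ``normalization="component"`` and ``std_balance_degrees=True``, and after the
+    optional l = 0 centering, the single scale applied to every coefficient of a node
+    summarises to (notation follows the code below)::
+
+        scale = (eps + mean_c sum_{l,m} x_{l,m,c}^2 / ((2l + 1) * (lmax + 1))) ** -0.5
+
+    so every degree l contributes equally regardless of its 2l + 1 components.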
+ """ + + lmax: int + eps: float = 1e-5 + affine: bool = True + normalization: str = "component" + centering: bool = True + std_balance_degrees: bool = True + + @nn.compact + def __call__(self, node_input: jax.Array) -> jax.Array: + """ + Assume input is of shape [N, sphere_basis, C] + """ + feature = node_input + + if self.centering: + feature_l0 = feature[:, 0:1] + feature_l0_mean = jnp.mean(feature_l0, axis=2, keepdims=True) # [N, 1, 1] + feature = jnp.concat( + (feature_l0 - feature_l0_mean, feature[:, 1:feature.shape[1]]), axis=1 + ) + + # for L >= 0 + if self.normalization == "norm": + assert not self.std_balance_degrees + feature_norm = jnp.sum( + jnp.pow(feature, 2), axis=1, keepdims=True + ) # [N, 1, C] + elif self.normalization == "component": + if self.std_balance_degrees: + balance_degree_weight = _get_balance_degree_weight( + self.lmax, node_input.dtype, skip_l0=False + ) + feature_norm = jnp.einsum( + "nic, ia -> nac", jnp.pow(feature, 2), balance_degree_weight + ) # [N, 1, C] + else: + feature_norm = jnp.mean( + jnp.pow(feature, 2), axis=1, keepdims=True + ) # [N, 1, C] + else: + raise ValueError(f"Unknown normalization option: {self.normalization}") + + feature_norm = jnp.mean(feature_norm, axis=2, keepdims=True) # [N, 1, 1] + feature_norm = jnp.pow(feature_norm + self.eps, -0.5) + + if self.affine: + feature_norm *= self.param( + 'affine_weight', initializers.ones, (self.lmax + 1, node_input.shape[-1]) + )[None, get_expand_index(self.lmax)] # [N, (L_max + 1)**2, C] + + out = feature * feature_norm + + if self.affine and self.centering: + out = out.at[:, 0:1, :].add(self.param( + 'affine_bias', initializers.zeros, (node_input.shape[-1],) + )[None, None]) + + return out + + +# --- Normalization options --- + + +class LayerNormType(Enum): + """Options for the LayerNorm of the EquiformerV2 model.""" + + LAYER_NORM = "layer_norm" + LAYER_NORM_SH = "layer_norm_sh" + RMS_NORM_SH = "rms_norm_sh" + + +def parse_layernorm( + norm_type: LayerNormType | str, + lmax: int, + eps: float = 1e-5, + affine: bool = True, + normalization: str = "component", +) -> Callable: + assert normalization in ["norm", "component"] + norm_type_map = { + LayerNormType.LAYER_NORM: LayerNormArray, + LayerNormType.LAYER_NORM_SH: LayerNormArraySphericalHarmonics, + LayerNormType.RMS_NORM_SH: RMSNormArraySphericalHarmonicsV2, + } + norm_class = norm_type_map[LayerNormType(norm_type)] + return norm_class(lmax, eps, affine, normalization) diff --git a/src/mlip/models/equiformer_v2/models.py b/src/mlip/models/equiformer_v2/models.py new file mode 100644 index 0000000..211fdf5 --- /dev/null +++ b/src/mlip/models/equiformer_v2/models.py @@ -0,0 +1,432 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. 
+ +import flax.linen as nn +import jax +import jax.numpy as jnp + +from mlip.data.dataset_info import DatasetInfo +from mlip.models.mlip_network import MLIPNetwork +from mlip.models.atomic_energies import get_atomic_energies +from mlip.models.radial_basis import GaussianBasis +from mlip.models.equiformer_v2.config import EquiformerV2Config +from mlip.models.equiformer_v2.blocks import ( + SO3Linear, EdgeDegreeEmbedding, GraphDropPath +) +from mlip.models.equiformer_v2.layernorm import LayerNormType, parse_layernorm +from mlip.models.equiformer_v2.transform import ( + WignerMats, get_wigner_mats +) +from mlip.models.equiformer_v2.transformer_block import ( + SO2EquivariantGraphAttention, + FeedForwardNetwork, +) +from mlip.models.equiformer_v2.utils import ( + AttnActType, + FeedForwardType, + get_mapping_coeffs, +) +from mlip.utils.safe_norm import safe_norm + + +class EquiformerV2(MLIPNetwork): + """The EquiformerV2 model flax module. It is derived from the + :class:`~mlip.models.mlip_network.MLIPNetwork` class. + + References: + * Yi-Lun Liao, Brandon Wood, Abhishek Das and Tess Smidt. EquiformerV2: + Improved Equivariant Transformer for Scaling to Higher-Degree + Representations. International Conference on Learning Representations (ICLR), + January 2024. URL: https://openreview.net/forum?id=mCOBKZmrzD. + + Attributes: + config: Hyperparameters / configuration for the EquiformerV2 model, see + :class:`~mlip.models.equiformer_v2.config.EquiformerV2Config`. + dataset_info: Hyperparameters dictated by the dataset + (e.g., cutoff radius or average number of neighbors). + """ + + + Config = EquiformerV2Config + + config: EquiformerV2Config + dataset_info: DatasetInfo + + @nn.compact + def __call__( + self, + edge_vectors: jax.Array, + node_species: jax.Array, + senders: jax.Array, + receivers: jax.Array, + n_node: jax.Array | None = None, # For dropout, can be None for eval + training: bool = False, + ) -> jax.Array: + r_max = self.dataset_info.cutoff_distance_angstrom + + avg_num_neighbors = self.config.avg_num_neighbors + if avg_num_neighbors is None: + avg_num_neighbors = self.dataset_info.avg_num_neighbors + + avg_num_nodes = self.config.avg_num_nodes + if avg_num_nodes is None: + avg_num_nodes = self.dataset_info.avg_num_nodes + + num_species = self.config.num_species + if num_species is None: + num_species = len(self.dataset_info.atomic_energies_map) + + equiformer_kargs = dict( + avg_num_neighbors=avg_num_neighbors, + num_layers=self.config.num_layers, + lmax=self.config.lmax, + mmax=self.config.mmax, + sphere_channels=self.config.sphere_channels, + num_edge_channels=self.config.num_edge_channels, + atom_edge_embedding=self.config.atom_edge_embedding, + num_rbf=self.config.num_rbf, + attn_hidden_channels=self.config.attn_hidden_channels, + num_heads=self.config.num_heads, + attn_alpha_channels=self.config.attn_alpha_channels, + attn_value_channels=self.config.attn_value_channels, + ffn_hidden_channels=self.config.ffn_hidden_channels, + norm_type=self.config.norm_type, + grid_resolution=self.config.grid_resolution, + use_m_share_rad=self.config.use_m_share_rad, + use_attn_renorm=self.config.use_attn_renorm, + attn_act_type=self.config.attn_act_type, + ff_type=self.config.ff_type, + alpha_drop=self.config.alpha_drop, + drop_path_rate=self.config.drop_path_rate, + avg_num_nodes=avg_num_nodes, + cutoff=r_max, + num_species=num_species, + direct_force=self.config.direct_force, + deterministic=not training, + ) + + equiformer_model = EquiformerV2Block(**equiformer_kargs) + backbone_outputs 
= equiformer_model( + edge_vectors, node_species, senders, receivers, n_node + ) + + if self.config.direct_force: + node_energies, forces = backbone_outputs + else: + node_energies = backbone_outputs + + mean = self.dataset_info.scaling_mean + std = self.dataset_info.scaling_stdev + node_energies = mean + std * node_energies + + atomic_energies_ = get_atomic_energies( + self.dataset_info, self.config.atomic_energies, num_species + ) + atomic_energies_ = jnp.asarray(atomic_energies_) + node_energies += atomic_energies_[node_species] # [n_nodes, ] + + if self.config.direct_force: + return jnp.concat([node_energies[:, None], std * forces], axis=-1) + return node_energies + + +class EquiformerV2Block(nn.Module): + """ + Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon + S2 activation. + """ + + avg_num_neighbors: float + num_layers: int + lmax: int + mmax: int + sphere_channels: int + num_species: int + num_edge_channels: int + atom_edge_embedding: str + attn_hidden_channels: int + num_heads: int + attn_alpha_channels: int + attn_value_channels: int + ffn_hidden_channels: int + norm_type: LayerNormType + grid_resolution: int + use_m_share_rad: bool + use_attn_renorm: bool + attn_act_type: AttnActType + ff_type: FeedForwardType + alpha_drop: float + drop_path_rate: float + avg_num_nodes: float + num_rbf: int = 600 + cutoff: float = 5.0 + direct_force: bool = False + deterministic: bool = True + + def setup(self): + # Weights for message initialization + self.sphere_embedding = nn.Embed(self.num_species, self.sphere_channels) + + # Function used to measure the distances between atoms + self.distance_expansion = GaussianBasis( + self.cutoff, + self.num_rbf, + trainable=False, + rbf_width=2.0, + ) + + # Sizes of radial functions (2 hidden channels, input and output are ignored) + edge_channels_list = [self.num_edge_channels] * 2 + + # Atom edge embedding + self.edge_embedding = None + if self.atom_edge_embedding == 'shared': + self.edge_embedding = nn.Embed(self.num_species, 2 * self.num_edge_channels) + + # Edge-degree embedding + self.edge_degree_embedding = EdgeDegreeEmbedding( + self.lmax, + self.mmax, + self.sphere_channels, + edge_channels_list, + self.atom_edge_embedding == 'isolated', + num_species=self.num_species, + rescale_factor=self.avg_num_neighbors, + ) + + # Initialize the blocks for each layer of EquiformerV2 + self.layers = [ + EquiformerV2Layer( + self.lmax, + self.mmax, + self.grid_resolution, + self.sphere_channels, + self.attn_hidden_channels, + self.num_heads, + self.attn_alpha_channels, + self.attn_value_channels, + self.ffn_hidden_channels, + self.sphere_channels, + self.num_species, + edge_channels_list, + self.atom_edge_embedding == 'isolated', + self.use_m_share_rad, + self.use_attn_renorm, + self.attn_act_type, + self.ff_type, + self.norm_type, + self.alpha_drop, + self.drop_path_rate, + self.deterministic, + ) + for _ in range(self.num_layers) + ] + + # Output blocks for energy and forces + self.norm = parse_layernorm(self.norm_type, self.lmax) + self.energy_block = FeedForwardNetwork( + self.lmax, + self.ffn_hidden_channels, + 1, + self.grid_resolution, + self.ff_type, + ) + if self.direct_force: + self.force_block = SO2EquivariantGraphAttention( + self.lmax, + self.mmax, + self.grid_resolution, + self.attn_hidden_channels, + self.num_heads, + self.attn_alpha_channels, + self.attn_value_channels, + 1, + self.num_species, + edge_channels_list, + self.atom_edge_embedding == 'isolated', + self.use_m_share_rad, + 
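+                # Direct-force head: the same SO(2) attention block with a single
+                # output channel and attention dropout disabled (alpha_drop=0.0 below).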
self.use_attn_renorm, + self.attn_act_type, + alpha_drop=0.0, + deterministic=self.deterministic, + ) + + def __call__( + self, + edge_vectors: jax.Array, # [n_edges, 3] + node_species: jax.Array, # [n_nodes] int between 0 and num_species-1 + senders: jax.Array, # [n_edges] + receivers: jax.Array, # [n_edges] + n_node: jax.Array, # [batch_size] + ) -> jax.Array | tuple[jax.Array, jax.Array]: + num_atoms = len(node_species) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + if not self.deterministic: + rng = self.make_rng('rotation') + rot_gamma = jax.random.uniform( + rng, shape=len(edge_vectors), maxval=2 * jnp.pi, + dtype=edge_vectors.dtype + ) + else: + rot_gamma = jnp.zeros(len(edge_vectors), dtype=edge_vectors.dtype) + + mapping_coeffs = get_mapping_coeffs(self.lmax, self.mmax) + wigner_mats = get_wigner_mats( + self.lmax, self.mmax, edge_vectors, rot_gamma, mapping_coeffs.perm + ) + + # Initialize the l = 0, m = 0 coefficients + node_feats_0 = self.sphere_embedding(node_species)[:, None] + node_feats_m_pad = jnp.zeros( + [num_atoms, (self.lmax + 1) ** 2 - 1, node_feats_0.shape[-1]], + dtype=edge_vectors.dtype, + ) + node_feats = jnp.concat((node_feats_0, node_feats_m_pad), axis=1) + + # Edge encoding (distance and atom edge) + edge_distances = safe_norm(edge_vectors, axis=-1) + edge_embeds = self.distance_expansion(edge_distances) + if self.edge_embedding is not None: + node_embeds = self.edge_embedding(node_species) + senders_embeds, receivers_embeds = jnp.split(node_embeds, 2, axis=-1) + edge_embeds = jnp.concat( + (edge_embeds, senders_embeds[senders], receivers_embeds[receivers]), axis=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + node_species, edge_embeds, senders, receivers, wigner_mats + ) + node_feats = node_feats + edge_degree + + for layer in self.layers: + node_feats = layer( + node_feats, + node_species, + edge_embeds, + senders, + receivers, + wigner_mats, + n_node=n_node, # for GraphDropPath + ) + + # Final layer norm + node_feats = self.norm(node_feats) + + node_energies = self.energy_block(node_feats) + node_energies = node_energies[:, 0, 0] / self.avg_num_nodes + + if self.direct_force: + forces = self.force_block( + node_feats, + node_species, + edge_embeds, + senders, + receivers, + wigner_mats, + ) + return node_energies, forces[:, 1:4, 0] + + return node_energies + + +class EquiformerV2Layer(nn.Module): + lmax: int + mmax: int + resolution: int + sphere_channels: int + attn_hidden_channels: int + num_heads: int + attn_alpha_channels: int + attn_value_channels: int + ffn_hidden_channels: int + output_channels: int + num_species: int + edge_channels_list: tuple[int, ...] 
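+    # Hidden sizes of the invariant radial MLP, forwarded to the SO(2) attention block.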
+ use_atom_edge_embedding: bool = True + use_m_share_rad: bool = False + use_attn_renorm: bool = True + attn_act_type: AttnActType = AttnActType.S2_SEP + ff_type: FeedForwardType = FeedForwardType.GRID_SEP + norm_type: LayerNormType = LayerNormType.RMS_NORM_SH + alpha_drop: float = 0.0 + drop_path_rate: float = 0.0 + deterministic: bool = False # Randomness of rotation matrix + + def setup(self): + self.norm_1 = parse_layernorm(self.norm_type, self.lmax) + + self.graph_attn = SO2EquivariantGraphAttention( + self.lmax, + self.mmax, + self.resolution, + hidden_channels=self.attn_hidden_channels, + num_heads=self.num_heads, + attn_alpha_channels=self.attn_alpha_channels, + attn_value_channels=self.attn_value_channels, + output_channels=self.sphere_channels, + num_species=self.num_species, + edge_channels_list=self.edge_channels_list, + use_atom_edge_embedding=self.use_atom_edge_embedding, + use_m_share_rad=self.use_m_share_rad, + use_attn_renorm=self.use_attn_renorm, + attn_act_type=self.attn_act_type, + alpha_drop=self.alpha_drop, + deterministic=self.deterministic, + ) + + self.drop_path = GraphDropPath( + self.drop_path_rate, self.deterministic + ) if self.drop_path_rate > 0.0 else None + + self.norm_2 = parse_layernorm(self.norm_type, self.lmax) + + self.ffn = FeedForwardNetwork( + self.lmax, + hidden_channels=self.ffn_hidden_channels, + output_channels=self.output_channels, + resolution=self.resolution, + ff_type=self.ff_type, + ) + + self.ffn_shortcut = None + if self.sphere_channels != self.output_channels: + self.ffn_shortcut = SO3Linear(self.lmax, self.output_channels) + + def __call__( + self, + node_feats: jax.Array, + node_species: jax.Array, + edge_embeds: jax.Array, + senders: jax.Array, + receivers: jax.Array, + wigner_mats: WignerMats, + n_node: jax.Array | None = None, + ) -> jax.Array: + # Attention block + node_feats_res = node_feats + node_feats = self.norm_1(node_feats) + node_feats = self.graph_attn( + node_feats, node_species, edge_embeds, senders, receivers, wigner_mats + ) + + if self.drop_path is not None: + node_feats = self.drop_path(node_feats, n_node) + + node_feats = node_feats + node_feats_res + + # FFN block + node_feats_res = node_feats + node_feats = self.norm_2(node_feats) + node_feats = self.ffn(node_feats) + + if self.drop_path is not None: + node_feats = self.drop_path(node_feats, n_node) + + if self.ffn_shortcut is not None: + node_feats_res = self.ffn_shortcut(node_feats_res) + + node_feats = node_feats + node_feats_res + + return node_feats diff --git a/src/mlip/models/equiformer_v2/transform.py b/src/mlip/models/equiformer_v2/transform.py new file mode 100644 index 0000000..bef481e --- /dev/null +++ b/src/mlip/models/equiformer_v2/transform.py @@ -0,0 +1,315 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. + +""" +In e3nn @ 0.4.0, the Wigner-D matrix is computed using Jd, while in e3nn @ 0.5.0, +it is computed using generators and matrix_exp causing a significant slowdown. +However, in e3nn_jax, `_wigner_D_from_angles` uses Jd for l <= 11 and matrix_exp +for l > 11, so it is well-optimized and there is no need for reimplement. 
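+
+For l <= 11 the helpers below use the standard Jd decomposition, i.e.
+
+    D^l(alpha, beta, gamma) = Y^l(alpha) @ (J_l @ Y^l(beta) @ J_l) @ Y^l(gamma)
+
+where Y^l is the rotation about the y-axis assembled in `_rot_y` from Chebyshev
+recurrences and J_l are the fixed matrices from `e3nn_jax._src.J.Jd`.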
+""" + +from functools import cache + +from e3nn_jax._src.J import Jd +from e3nn_jax._src.s2grid import ( + _spherical_harmonics_s2grid, _normalization, _expand_matrix, _rollout_sh +) +from flax.typing import Dtype +from flax.struct import dataclass +import jax +import jax.numpy as jnp + +from mlip.models.equiformer_v2.utils import get_order_mask, get_rescale_mat, get_mapping_coeffs + + +def _chebyshev(cos_x: jax.Array, sin_x: jax.Array, lmax: int) -> tuple[jax.Array, jax.Array]: + """Calculate cos(nx) and sin(nx) using Chebyshev polynomials. + + Args: + cos_x (jax.Array): Cosine of the angle x. + sin_x (jax.Array): Sine of the angle x. + lmax (int): Maximum degree of the representation. + + Returns: + Tuple of arrays (cos_nx, sin_nx) with shape of (..., lmax) and (..., lmax). + """ + + if lmax == 1: + return cos_x[..., None], sin_x[..., None] + + cos_2x = 2 * cos_x * cos_x - 1 + sin_2x = 2 * cos_x * sin_x + + if lmax == 2: + return jnp.stack([cos_x, cos_2x], axis=-1), jnp.stack([sin_x, sin_2x], axis=-1) + + init_carry = (jnp.stack([cos_2x, sin_2x]), jnp.stack([cos_x, sin_x])) + + def body(carry, _): + prev, prev2 = carry + out = 2 * cos_x * prev - prev2 + carry = (out, prev) + return carry, out + + _, results = jax.lax.scan(body, init_carry, length=lmax - 2) + results = results.transpose(*range(1, len(results.shape)), 0) + cos_all = jnp.concat([cos_x[..., None], cos_2x[..., None], results[0]], axis=-1) + sin_all = jnp.concat([sin_x[..., None], sin_2x[..., None], results[1]], axis=-1) + + return cos_all, sin_all + + +def _rot_y(cos_x: jax.Array, sin_x: jax.Array, lmax: int) -> list[jax.Array]: + """Rotational matrix around y-axis by angle phi. + + Args: + cos_x (jax.Array): Cosine of the angle. + sin_x (jax.Array, optional): Sine of the angle. + lmax (int): Maximum degree of representation to return. + """ + cos_all, sin_all = _chebyshev(cos_x, sin_x, lmax) + cos_all = jnp.concat([cos_all[..., ::-1], jnp.ones_like(cos_x)[..., None], cos_all], axis=-1) + sin_all = jnp.concat([sin_all[..., ::-1], jnp.zeros_like(sin_x)[..., None], -sin_all], axis=-1) + + rot_mat_list = [] + for l in range(lmax + 1): + rot_mat = jnp.zeros(cos_x.shape + (2 * l + 1, 2 * l + 1), dtype=cos_x.dtype) + inds = jnp.arange(0, 2 * l + 1, 1) + rev_inds = jnp.arange(2 * l, -1, -1) + rot_mat = rot_mat.at[..., inds, rev_inds].set(sin_all[..., lmax-l:lmax+l+1]) + rot_mat = rot_mat.at[..., inds, inds].set(cos_all[..., lmax-l:lmax+l+1]) + rot_mat_list.append(rot_mat) + + return rot_mat_list + + +def _wigner_d_from_angles( + alpha: tuple[jax.Array, jax.Array], + beta: tuple[jax.Array, jax.Array], + gamma: tuple[jax.Array, jax.Array], + lmax: int, +) -> list[jax.Array]: + r"""The Wigner-D matrix of the real irreducible representations of :math:`SO(3)`. + + Args: + + alpha (jax.Array): Cosine and sine of the first Euler angle. + beta (jax.Array): Cosine and sine of the second Euler angle. + gamma (jax.Array): Cosine and sine of the third Euler angle. + lmax (int): The representation order of the irrep. + + Returns: + List of Wigner-D matrices from 0 to lmax. 
+ """ + + alpha_mats = _rot_y(alpha[0], alpha[1], lmax) + beta_mats = _rot_y(beta[0], beta[1], lmax) + gamma_mats = _rot_y(gamma[0], gamma[1], lmax) + + mats = [] + for l, (a, b, c) in enumerate(zip(alpha_mats, beta_mats, gamma_mats)): + if l < len(Jd): + j = Jd[l].astype(b.dtype) + b = j @ b @ j + else: + # TODO(bhcao): implement Wigner-D for l > 11 + # x = generators(l) + # b = jax.scipy.linalg.expm(b.astype(x.dtype) * x[0]).astype(b.dtype) + raise NotImplementedError("Wigner-D not implemented for l > 11") + mats.append(a @ b @ c) + + return mats + + +def _xyz_to_angles(xyz: jax.Array): + r"""The rotation :math:`R(\alpha, \beta, 0)` such that :math:`\vec r = R \vec e_y`. + + .. math:: + \vec r = R(\alpha, \beta, 0) \vec e_y + \alpha = \arctan(x/z) + \beta = \arccos(y) + + Args: + xyz (`jax.Array`): array of shape :math:`(..., 3)` + + Returns: + (tuple): tuple of `(\cos(\alpha), \sin(\alpha))`, `(\cos(\beta), \sin(\beta))`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + xyz2 = x**2 + y**2 + z**2 + len_xyz = jnp.sqrt(jnp.where(xyz2 > 1e-16, xyz2, 1e-16)) + xz2 = x**2 + z**2 + len_xz = jnp.sqrt(jnp.where(xz2 > 1e-16, xz2, 1e-16)) + + sin_alpha = jnp.clip(x / len_xz, -1, 1) + cos_alpha = jnp.clip(z / len_xz, -1, 1) + + sin_beta = jnp.clip(len_xz / len_xyz, 0, 1) + cos_beta = jnp.clip(y / len_xyz, -1, 1) + + return (cos_alpha, sin_alpha), (cos_beta, sin_beta) + + +def _get_s2grid_mat( + lmax: int, + res_beta: int, + res_alpha: int, + *, + dtype: Dtype = jnp.float32, + normalization: str = "integral", +) -> tuple[jax.Array, jax.Array]: + r"""Modified `e3nn_jax._src.s2grid.to_s2grid` and `e3nn_jax._src.s2grid.from_s2grid` to act + like `e3nn.o3.ToS2Grid` and `e3nn.o3.FromS2Grid`. + + Args: + lmax (int): Maximum degree of the spherical harmonics + res_beta (int): Number of points on the sphere in the :math:`\theta` direction + res_alpha (int): Number of points on the sphere in the :math:`\phi` direction + normalization ({'norm', 'component', 'integral'}): Normalization of the basis + + Returns: + (to_grid_mat, from_grid_mat): + Transform matrix from irreps to spherical grid and its inverse. + """ + _, _, sh_y, sha, qw = _spherical_harmonics_s2grid( + lmax, res_beta, res_alpha, quadrature="soft", dtype=dtype + ) + # sh_y: (res_beta, l, |m|) + sh_y = _rollout_sh(sh_y, lmax) + + m = jnp.asarray(_expand_matrix(range(lmax + 1)), dtype) # [l, m, i] + + # construct to_grid_mat + n_to = _normalization(lmax, normalization, dtype, "to_s2") + sh_y_to = jnp.einsum("lmj,bj,lmi,l->mbi", m, sh_y, m, n_to) # [m, b, i] + to_grid_mat = jnp.einsum("mbi,am->bai", sh_y_to, sha) # [beta, alpha, i] + + # construct from_grid_mat + n_from = _normalization(lmax, normalization, dtype, "from_s2", lmax) + sh_y_from = jnp.einsum("lmj,bj,lmi,l,b->mbi", m, sh_y, m, n_from, qw) # [m, b, i] + from_grid_mat = jnp.einsum("mbi,am->bai", sh_y_from, sha / res_alpha) # [beta, alpha, i] + return to_grid_mat, from_grid_mat + + +# There is no need to promote the dtype because it is determined by the input. 
+@dataclass +class WignerMats: + """Wigner-D matrix""" + + wigner: jax.Array + wigner_inv: jax.Array + + def rotate(self, embedding): + """Rotate the embedding, l primary -> m primary.""" + return jnp.matmul(self.wigner, embedding) + + def rotate_inv(self, embedding): + """Rotate the embedding by the inverse of rotation matrix, m primary -> to l primary.""" + return jnp.matmul(self.wigner_inv, embedding) + + +def get_wigner_mats( + lmax: int, + mmax: int, + xyz: jax.Array, + gamma: jax.Array, + perm: jax.Array, + scale: bool = True, +) -> WignerMats: + """ + Init the Wigner-D matrix for given euler angles. For continuity of derivatives, `alpha` + and `beta` are implicitly calculated through given `xyz`. Mathematically, it is + equivalent to calculate `alpha, beta = xyz_to_angles(xyz)`. + """ + mask = get_order_mask(lmax, mmax) + # Compute the re-scaling for rotating back to original frame + if scale: + rotate_inv_rescale = jnp.asarray(get_rescale_mat(lmax, mmax, dim=2), dtype=xyz.dtype) + rotate_inv_rescale = rotate_inv_rescale[None, :, mask] + + alpha, beta = _xyz_to_angles(xyz) + gamma = jnp.cos(gamma), jnp.sin(gamma) + blocks = _wigner_d_from_angles(alpha, beta, gamma, lmax) + + # Cache the Wigner-D matrices + size = (lmax + 1) ** 2 + wigner_inv = jnp.zeros([len(xyz), size, size], dtype=xyz.dtype) + start = 0 + for i, block in enumerate(blocks): + end = start + block.shape[1] + wigner_inv = wigner_inv.at[:, start:end, start:end].set((-1) ** i * block) + start = end + + # Mask the output to include only modes with m < mmax + wigner_inv = wigner_inv[:, :, mask] + wigner = wigner_inv.transpose((0, 2, 1)) + + if scale: + wigner_inv *= rotate_inv_rescale + + wigner = wigner[:, perm, :] + wigner_inv = wigner_inv[:, :, perm] + + return WignerMats(wigner, wigner_inv) + + +@dataclass +class S2GridMats: + """Scaled S2 grid matrix""" + + to_grid_mat: jax.Array + from_grid_mat: jax.Array + + def to_grid(self, embedding: jax.Array) -> jax.Array: + """Compute grid from irreps representation""" + to_grid_mat = jnp.asarray(self.to_grid_mat, dtype=embedding.dtype) + grid = jnp.einsum("bai, zic -> zbac", to_grid_mat, embedding) + return grid + + def from_grid(self, grid: jax.Array) -> jax.Array: + """Compute irreps from grid representation""" + from_grid_mat = jnp.asarray(self.from_grid_mat, dtype=grid.dtype) + embedding = jnp.einsum("bai, zbac -> zic", from_grid_mat, grid) + return embedding + + +@cache +def get_s2grid_mats( + lmax: int, + mmax: int, + normalization: str = "component", + resolution: int | None = None, + m_prime: bool = False, +) -> S2GridMats: + """Create the S2Grid matrix for given lmax and mmax.""" + mask = get_order_mask(lmax, mmax) + + if resolution is not None: + lat_resolution = resolution + long_resolution = resolution + else: + lat_resolution = 2 * (lmax + 1) + long_resolution = 2 * (mmax + 1 if lmax == mmax else mmax) + 1 + + # rescale last dimension based on mmax + rescale_matrix = get_rescale_mat(lmax, mmax) + + to_grid_mat, from_grid_mat = _get_s2grid_mat( + lmax, + lat_resolution, + long_resolution, + normalization=normalization, + ) + to_grid_mat = (to_grid_mat * rescale_matrix)[:, :, mask] + from_grid_mat = (from_grid_mat * rescale_matrix)[:, :, mask] + + if m_prime: + # This will be reused by lru_cache. 
+ perm = get_mapping_coeffs(lmax, mmax).perm + to_grid_mat = to_grid_mat[:, :, perm] + from_grid_mat = from_grid_mat[:, :, perm] + + return S2GridMats(to_grid_mat, from_grid_mat) diff --git a/src/mlip/models/equiformer_v2/transformer_block.py b/src/mlip/models/equiformer_v2/transformer_block.py new file mode 100644 index 0000000..1039b9c --- /dev/null +++ b/src/mlip/models/equiformer_v2/transformer_block.py @@ -0,0 +1,256 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. + +import jax +import jax.numpy as jnp +import flax.linen as nn +from flax.linen import initializers + +from mlip.models.equiformer_v2.transform import ( + WignerMats, + get_s2grid_mats, +) +from mlip.models.equiformer_v2.activations import ( + SmoothLeakyReLU, + GateActivation, + S2Activation, + SeparableS2Activation, +) +from mlip.models.equiformer_v2.blocks import SO3Linear, MLP, SO2Convolution +from mlip.models.equiformer_v2.utils import ( + AttnActType, + FeedForwardType, + get_expand_index, + pyg_softmax, +) + + +class SO2EquivariantGraphAttention(nn.Module): + """ + SO2EquivariantGraphAttention: Perform MLP attention + non-linear message passing + SO(2) Convolution with radial function -> S2 Activation -> SO(2) Convolution -> attention + weights and non-linear messages attention weights * non-linear messages -> Linear + + Args: + lmax (int): Maximum degree of spherical harmonics + mmax (int): Maximum degree of spherical harmonics + resolution (int): Resolution of the spherical grid + hidden_channels (int): Number of hidden channels used during the SO(2) conv + num_heads (int): Number of attention heads + attn_alpha_channels (int): Number of channels for alpha vector in each attention head + attn_value_channels (int): Number of channels for value vector in each attention head + output_channels (int): Number of output channels + num_species (int): Maximum number of atomic numbers + edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, + [input_channels, hidden_channels, hidden_channels]. The last one will be used as hidden + size when `use_atom_edge_embedding` is `True`. + use_atom_edge_embedding (bool): Whether to use atomic embedding along with relative + distance for edge scalar features + use_m_share_rad (bool): Whether all m components within a type-L vector of one channel + share radial function weights + use_attn_renorm (bool): Whether to re-normalize attention weights + attn_act_type (AttnActType): Type of attention activation function + alpha_drop (float): Dropout rate for attention weights + """ + + lmax: int + mmax: int + resolution: int + hidden_channels: int + num_heads: int + attn_alpha_channels: int + attn_value_channels: int + output_channels: int + num_species: int + edge_channels_list: tuple[int, ...] 
+ use_atom_edge_embedding: bool = True + use_m_share_rad: bool = False + use_attn_renorm: bool = True + attn_act_type: AttnActType = AttnActType.S2_SEP + alpha_drop: float = 0.0 + deterministic: bool = True + + @nn.compact + def __call__( + self, + node_feats: jax.Array, + node_species: jax.Array, + edge_embeds: jax.Array, + senders: jax.Array, + receivers: jax.Array, + wigner_mats: WignerMats, + ) -> jax.Array: + num_nodes = node_feats.shape[0] + # Compute edge scalar features (invariant to rotations) + # Uses atomic numbers and edge distance expansion as inputs + if self.use_atom_edge_embedding: + assert self.num_species is not None, "num_species must be provided" + node_embeds = nn.Embed( + self.num_species, 2 * self.edge_channels_list[-1], + embedding_init=initializers.normal(stddev=0.001), # Why? + )(node_species) + senders_embeds, receivers_embeds = jnp.split(node_embeds, 2, axis=-1) + edge_embeds = jnp.concat( + (edge_embeds, senders_embeds[senders], receivers_embeds[receivers]), axis=-1 + ) + + messages = jnp.concat((node_feats[senders], node_feats[receivers]), axis=2) + + # radial function (scale all m components within a type-L vector of one channel + # with the same weight) + if self.use_m_share_rad: + edge_embeds_weight = MLP( + self.edge_channels_list + (messages.shape[-1] * (self.lmax + 1),) + )(edge_embeds).reshape(-1, (self.lmax + 1), messages.shape[-1]) + # [E, (L_max + 1) ** 2, C] + messages *= edge_embeds_weight[:, get_expand_index(self.lmax)] + + # Rotate the irreps to align with the edge, get m primary + messages = wigner_mats.rotate(messages) + + # First SO(2)-convolution + alpha_channels = self.num_heads * self.attn_alpha_channels + if self.attn_act_type == AttnActType.GATE: + extra_scalar_channels = alpha_channels + self.lmax * self.hidden_channels + elif self.attn_act_type == AttnActType.S2_SEP: + extra_scalar_channels = alpha_channels + self.hidden_channels + else: + extra_scalar_channels = alpha_channels + + messages, scalar_extra = SO2Convolution( + self.lmax, + self.mmax, + self.hidden_channels, + internal_weights=self.use_m_share_rad, + edge_channels_list=( + self.edge_channels_list if not self.use_m_share_rad else None + ), + # for attention weights and/or gate activation + extra_scalar_channels=extra_scalar_channels + )(messages, edge_embeds) + + # Activation + if self.attn_act_type == AttnActType.GATE: + # Gate activation + scalar_alpha, scalar_gating = jnp.split(scalar_extra, [alpha_channels], axis=-1) + messages = GateActivation( + self.lmax, self.mmax, self.hidden_channels, m_prime=True + )(scalar_gating, messages) + + elif self.attn_act_type == AttnActType.S2_SEP: + scalar_alpha, scalar_gating = jnp.split(scalar_extra, [alpha_channels], axis=-1) + messages = SeparableS2Activation( + self.lmax, self.mmax, self.resolution, m_prime=True + )(scalar_gating, messages) + + else: + scalar_alpha = scalar_extra + messages = S2Activation( + self.lmax, self.mmax, self.resolution, m_prime=True + )(messages) + # x_message._grid_act(self.so3_grid, self.value_act, self.mappingReduced) + + # Second SO(2)-convolution + messages = SO2Convolution( + self.lmax, self.mmax, self.num_heads * self.attn_value_channels + )(messages, edge_embeds) + + # Attention weights + scalar_alpha = scalar_alpha.reshape(-1, self.num_heads, self.attn_alpha_channels) + if self.use_attn_renorm: + scalar_alpha = nn.LayerNorm()(scalar_alpha) + scalar_alpha = SmoothLeakyReLU()(scalar_alpha) + + # torch_geometric.nn.inits.glorot(self.alpha_dot) # Following GATv2 + alpha = jnp.einsum("bik, ki -> bi", 
scalar_alpha, self.param( + 'alpha_dot', + initializers.lecun_normal(), + (self.attn_alpha_channels, self.num_heads), + )) + alpha = pyg_softmax(alpha, receivers, num_nodes) + alpha = alpha.reshape(alpha.shape[0], 1, self.num_heads, 1) + + if self.alpha_drop != 0.0: + alpha = nn.Dropout(self.alpha_drop, deterministic=self.deterministic)(alpha) + + # Attention weights * non-linear messages + attn = messages.reshape( + messages.shape[0], + messages.shape[1], + self.num_heads, + self.attn_value_channels, + ) + attn = attn * alpha + messages = attn.reshape( + attn.shape[0], + attn.shape[1], + self.num_heads * self.attn_value_channels, + ) + + # Rotate back the irreps + messages = wigner_mats.rotate_inv(messages) + + # Compute the sum of the incoming neighboring messages for each target node + node_feats = jax.ops.segment_sum(messages, receivers, num_nodes) + + node_feats = SO3Linear(self.lmax, self.output_channels)(node_feats) + return node_feats + + +class FeedForwardNetwork(nn.Module): + """ + FeedForwardNetwork: Perform feedforward network with S2 activation or gate activation + + Args: + lmax (int): Degree (l) + hidden_channels (int): Number of hidden channels used during feedforward network + output_channels (int): Number of output channels + resolution (int): Resolution of the S2 grid + ff_type (FeedForwardType): Type of feedforward network + """ + + lmax: int + hidden_channels: int + output_channels: int + resolution: int + ff_type: FeedForwardType + + @nn.compact + def __call__(self, node_feats: jax.Array) -> jax.Array: + node_feats_orig = node_feats + node_feats = SO3Linear(self.lmax, self.hidden_channels)(node_feats) + + if self.ff_type in [FeedForwardType.GRID, FeedForwardType.GRID_SEP]: + so3_grid = get_s2grid_mats(self.lmax, self.lmax) + + node_feats_grid = so3_grid.to_grid(node_feats) + node_feats_grid = MLP( + [self.hidden_channels] * 3, use_bias=False, use_layer_norm=False, + )(node_feats_grid) + node_feats = so3_grid.from_grid(node_feats_grid) + + if self.ff_type == FeedForwardType.GRID_SEP: + gating_scalars = nn.silu( + nn.Dense(self.hidden_channels)(node_feats_orig[:, 0:1]) + ) + node_feats = jnp.concat( + (gating_scalars, node_feats[:, 1:]), axis=1 + ) + + elif self.ff_type == FeedForwardType.GATE: + gating_scalars = nn.Dense(self.lmax * self.hidden_channels)(node_feats_orig[:, 0:1]) + node_feats = GateActivation( + self.lmax, self.lmax, self.hidden_channels + )(gating_scalars, node_feats) + + elif self.ff_type == FeedForwardType.S2_SEP: + gating_scalars = nn.Dense(self.hidden_channels)(node_feats_orig[:, 0:1]) + node_feats = SeparableS2Activation( + self.lmax, self.lmax, self.resolution + )(gating_scalars, node_feats) + + else: + node_feats = S2Activation(self.lmax, self.lmax, self.resolution)(node_feats) + + node_feats = SO3Linear(self.lmax, self.output_channels)(node_feats) + return node_feats diff --git a/src/mlip/models/equiformer_v2/utils.py b/src/mlip/models/equiformer_v2/utils.py new file mode 100644 index 0000000..5c79865 --- /dev/null +++ b/src/mlip/models/equiformer_v2/utils.py @@ -0,0 +1,194 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license. + +from enum import Enum +from functools import cache + +import jax +import jax.numpy as jnp +from flax.struct import dataclass + + +class AttnActType(Enum): + """The options are as follows. Parameters not mentioned are False. 
+
+    Attributes:
+        GATE: use_gate_act=True
+        S2_SEP: use_sep_s2_act=True
+        S2: else
+    """
+    GATE = 'gate'
+    S2_SEP = 's2_sep'
+    S2 = 's2'
+
+
+class FeedForwardType(Enum):
+    """The options are as follows. Parameters not mentioned are False.
+
+    Attributes:
+        GATE: Spectral atomwise, use_gate_act=True
+        GRID: Grid atomwise, use_grid_mlp=True
+        GRID_SEP: Grid atomwise, use_grid_mlp=True, use_sep_s2_act=True
+        S2: S2 activation
+        S2_SEP: S2 activation, use_sep_s2_act=True
+    """
+    GATE = 'gate'
+    GRID = 'grid'
+    GRID_SEP = 'grid_sep'
+    S2 = 's2'
+    S2_SEP = 's2_sep'
+
+
+def pyg_softmax(src: jax.Array, index: jax.Array, num_segments: int) -> jax.Array:
+    r"""Computes a sparsely evaluated softmax, adapted from torch_geometric.
+    Given a value tensor :attr:`src`, this function first groups the values
+    along the first dimension based on the indices specified in :attr:`index`,
+    and then proceeds to compute the softmax individually for each group.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements for applying the softmax.
+        num_segments (int): The number of segments (e.g. nodes) to reduce over.
+
+    Returns:
+        The softmax-ed tensor.
+    """
+
+    src_max = jax.ops.segment_max(src, index, num_segments)
+    out = src - src_max[index]
+    out = jnp.exp(out)
+    out_sum = jax.ops.segment_sum(out, index, num_segments) + 1e-16
+    out_sum = out_sum[index]
+
+    return out / out_sum
+
+
+@cache
+def get_expand_index(
+    lmax: int, mmax: int | None = None, vector_only: bool = False, m_prime: bool = False
+) -> jax.Array:
+    """Expand coefficients from one value per degree (lmax + 1 values, or lmax values if
+    `vector_only`) to one value per element ((lmax + 1)**2 values, or (lmax + 1)**2 - 1
+    values if `vector_only`).
+
+    Args:
+        lmax: Maximum degree (l).
+        mmax (optional): Maximum order (m), defaults to `lmax`.
+        vector_only (optional): If True, only return coefficients on the vector part,
+            otherwise return coefficients on both the scalar and vector parts.
+        m_prime (optional): If True, indices are in order (l0m0, l1m0, l2m0, l1m1, l2m1, l2m2,
+            l1m-1, l2m-1, l2m-2, ...), otherwise indices are in order (l0m0, l1m-1, l1m0, l1m1,
+            l2m-2, l2m-1, l2m0, l2m1, l2m2, ...).
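+
+    Example (illustrative): with ``lmax=2``, ``mmax=2`` and default ordering, the returned
+    index array is ``[0, 1, 1, 1, 2, 2, 2, 2, 2]``, i.e. one entry per spherical-harmonic
+    element pointing back to its degree.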
+ """ + if mmax is None: + mmax = lmax + lmin = 1 if vector_only else 0 + + expand_index_list = [] + + if m_prime: + expand_index_list.append(jnp.arange(0, lmax + 1 - lmin)) + for mval in range(1, mmax + 1): + expand_index_list.extend([ + jnp.arange(mval - lmin, lmax + 1 - lmin), + jnp.arange(mval - lmin, lmax + 1 - lmin), + ]) + else: + for lval in range(lmin, lmax + 1): + length = min((2 * lval + 1), (2 * mmax + 1)) + expand_index_list.append( + jnp.ones([length], dtype=jnp.int32) * (lval - lmin) + ) + + return jnp.concat(expand_index_list) + + +@cache +def get_order_mask(lmax: int, mmax: int, lmax_emb: int = None) -> jax.Array: + """Compute the mask of orders less than or equal to `mmax` on IrrepsArray of + `(1, ..., lmax)`.""" + + if lmax_emb is None: + lmax_emb = lmax + + # Compute the degree (lval) and order (m) for each entry of the embedding + m_harmonic_list = [] + l_harmonic_list = [] + for lval in range(lmax_emb + 1): + m = jnp.arange(-lval, lval + 1) + m_harmonic_list.append(jnp.abs(m)) + l_harmonic_list.append(jnp.ones_like(m) * lval) + + m_harmonic = jnp.concat(m_harmonic_list) + l_harmonic = jnp.concat(l_harmonic_list) + + # Compute the indices of the entries to keep + # We only use a subset of m components for SO(2) convolution + return jnp.logical_and(l_harmonic <= lmax, m_harmonic <= mmax) + + +@cache +def get_rescale_mat(lmax: int, mmax: int, dim: int = 1) -> jax.Array: + """Rescale matrix for masked entries based on `mmax`.""" + + size = (lmax + 1) ** 2 + matrix = jnp.ones([size] * dim) + + if lmax != mmax: + for lval in range(lmax + 1): + if lval <= mmax: + continue + start = lval ** 2 + length = 2 * lval + 1 + rescale_factor = jnp.sqrt(length / (2 * mmax + 1)) + slices = [slice(start, start + length)] * dim + matrix = matrix.at[*slices].set(rescale_factor) + + return matrix + + +@dataclass +class MappingCoeffs: + """Holds the mapping coefficients to reduce parameters.""" + lmax: int + mmax: int + perm: jax.Array + m_size: tuple[int, ...] 
+ num_coefficients: int + + +@cache +def get_mapping_coeffs(lmax: int, mmax: int) -> MappingCoeffs: + """Return the mapping matrix from lval <--> m and size of each degree.""" + + # Compute the degree (lval) and order (m) for each entry of the embedding + m_complex_list = [] + + num_coefficients = 0 + for lval in range(lmax + 1): + mmax_ = min(mmax, lval) + m = jnp.arange(-mmax_, mmax_ + 1) + m_complex_list.append(m) + num_coefficients += len(m) + + m_complex = jnp.concat(m_complex_list, axis=0) + + # `perm` moves m components from different L to contiguous index (m_prime) + perm_list = [] + m_size = [] + + for m in range(mmax + 1): + indices = jnp.arange(len(m_complex)) + + # Real part + idx_r = indices[m_complex == m] + perm_list.append(idx_r) + + m_size.append(len(idx_r)) + + # Imaginary part + if m != 0: + idx_i = indices[m_complex == -m] + perm_list.append(idx_i) + + perm = jnp.concat(perm_list) + + return MappingCoeffs(lmax, mmax, perm, m_size, num_coefficients) diff --git a/src/mlip/models/force_field.py b/src/mlip/models/force_field.py index 0a2d59f..0164818 100644 --- a/src/mlip/models/force_field.py +++ b/src/mlip/models/force_field.py @@ -15,6 +15,7 @@ from dataclasses import dataclass import jax +from flax.typing import RNGSequences import jraph from pydantic import BaseModel from typing_extensions import Self diff --git a/src/mlip/models/liten/config.py b/src/mlip/models/liten/config.py new file mode 100644 index 0000000..c2f438d --- /dev/null +++ b/src/mlip/models/liten/config.py @@ -0,0 +1,45 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Su Qun (https://github.com/lingcon01/LiTEN) under MIT license. + +import pydantic + +from mlip.models.options import Activation, VecNormType +from mlip.typing import PositiveInt + + +class LitenConfig(pydantic.BaseModel): + """Hyperparameters for the LiTEN model. + + Attributes: + num_layers: Number of LiTEN layers. Default is 6. + num_channels: The number of channels. Default is 256. + num_heads: Number of heads in the attention block. Default is 8. + num_rbf: Number of basis functions used in the embedding block. Default is 32. + trainable_rbf: Whether to add learnable weights to each of the radial embedding + basis functions. Default is ``False``. + activation: Activation function for the output block. Options are "silu" + (default), "ssp" (which is shifted softplus), "tanh", "sigmoid", and + "swish". + vecnorm_type: The type of the vector norm. The options are "none", "max_min" + (default), and "rms". + atomic_energies: How to treat the atomic energies. If set to ``None`` (default) + or the string ``"average"``, then the average atomic energies + stored in the dataset info are used. It can also be set to the + string ``"zero"`` which means not to use any atomic energies + in the model. Lastly, one can also pass an atomic energies + dictionary via this parameter different from the one in the + dataset info, that is used. + num_species: The number of elements (atomic species descriptors) allowed. + If ``None`` (default), infer the value from the atomic energies + map in the dataset info. 
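+
+    Example (illustrative): overriding a couple of fields while keeping the remaining
+    defaults::
+
+        config = LitenConfig(num_layers=4, num_channels=128)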
+ """ + + num_layers: PositiveInt = 6 + num_channels: PositiveInt = 256 + num_heads: PositiveInt = 8 + num_rbf: PositiveInt = 32 + trainable_rbf: bool = False + activation: Activation = Activation.SILU + vecnorm_type: VecNormType = VecNormType.MAX_MIN + atomic_energies: str | dict[int, float] | None = None + num_species: PositiveInt | None = None diff --git a/src/mlip/models/liten/models.py b/src/mlip/models/liten/models.py new file mode 100644 index 0000000..72a66eb --- /dev/null +++ b/src/mlip/models/liten/models.py @@ -0,0 +1,374 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Su Qun (https://github.com/lingcon01/LiTEN) under MIT license. + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen import initializers + +from mlip.data.dataset_info import DatasetInfo +from mlip.models.cutoff import CosineCutoff +from mlip.models.radial_basis import parse_radial_basis +from mlip.models.atomic_energies import get_atomic_energies +from mlip.models.mlip_network import MLIPNetwork +from mlip.models.options import parse_activation +from mlip.models.visnet.blocks import VecLayerNorm +from mlip.models.liten.config import LitenConfig +from mlip.utils.safe_norm import safe_norm + + +class Liten(MLIPNetwork): + """The LiTEN model flax module. It is derived from the + :class:`~mlip.models.mlip_network.MLIPNetwork` class. + + References: + * Qun Su, Kai Zhu, Qiaolin Gou, Jintu Zhang, Renling Hu, Yurong Li, + Yongze Wang, Hui Zhang, Ziyi You, Linlong Jiang, Yu Kang, Jike Wang, + Chang-Yu Hsieh and Tingjun Hou. A Scalable and Quantum-Accurate + Foundation Model for Biomolecular Force Field via Linearly Tensorized + Quadrangle Attention. arXiv, Jul 2025. + URL: https://arxiv.org/abs/2507.00884. + + Attributes: + config: Hyperparameters / configuration for the LiTEN model, see + :class:`~mlip.models.liten.config.LitenConfig`. + dataset_info: Hyperparameters dictated by the dataset + (e.g., cutoff radius or average number of neighbors). 
+ """ + + Config = LitenConfig + + config: LitenConfig + dataset_info: DatasetInfo + + @nn.compact + def __call__( + self, + edge_vectors: jax.Array, + node_species: jax.Array, + senders: jax.Array, + receivers: jax.Array, + **_kwargs, # ignore any additional kwargs + ) -> jax.Array: + + r_max = self.dataset_info.cutoff_distance_angstrom + + num_species = self.config.num_species + if num_species is None: + num_species = len(self.dataset_info.atomic_energies_map) + + liten_kwargs = dict( + vecnorm_type=self.config.vecnorm_type, + num_heads=self.config.num_heads, + num_layers=self.config.num_layers, + num_channels=self.config.num_channels, + num_rbf=self.config.num_rbf, + rbf_type="expnorm", + trainable_rbf=self.config.trainable_rbf, + activation=self.config.activation, + cutoff=r_max, + num_species=num_species, + ) + + representation_model = LitenBlock(**liten_kwargs) + node_energies = representation_model( + edge_vectors, node_species, senders, receivers + ) + mean = self.dataset_info.scaling_mean + std = self.dataset_info.scaling_stdev + node_energies = mean + std * node_energies + + atomic_energies_ = get_atomic_energies( + self.dataset_info, self.config.atomic_energies, num_species + ) + atomic_energies_ = jnp.asarray(atomic_energies_) + node_energies += atomic_energies_[node_species] # [n_nodes, ] + + return node_energies + + +class LitenBlock(nn.Module): + vecnorm_type: str = "none" + num_heads: int = 8 + num_layers: int = 9 + num_channels: int = 256 + num_rbf: int = 32 + rbf_type: str = "expnorm" + trainable_rbf: bool = False + activation: str = "silu" + cutoff: float = 5.0 + num_species: int = 5 + + def setup(self) -> None: + self.node_embedding = nn.Embed(self.num_species, self.num_channels) + self.radial_embedding = parse_radial_basis(self.rbf_type)( + self.cutoff, self.num_rbf, self.trainable_rbf + ) + + self.edge_embedding = nn.Dense(self.num_channels) + + self.liten_layers = [ + LitenLayer( + num_heads=self.num_heads, + num_channels=self.num_channels, + activation=self.activation, + cutoff=self.cutoff, + vecnorm_type=self.vecnorm_type, + last_layer=i == self.num_layers - 1, + first_layer=i == 0, + ) + for i in range(self.num_layers) + ] + + self.out_norm = nn.LayerNorm(epsilon=1e-05) + self.readout_energy = nn.Sequential( + [ + nn.Dense(self.num_channels // 2), + parse_activation(self.activation), + nn.Dense(1), + ] + ) + + def __call__( + self, + edge_vectors: jax.Array, # [n_edges, 3] + node_species: jax.Array, # [n_nodes] int between 0 and num_species-1 + senders: jax.Array, # [n_edges] + receivers: jax.Array, # [n_edges] + ) -> jax.Array: + assert edge_vectors.ndim == 2 and edge_vectors.shape[1] == 3 + assert node_species.ndim == 1 + assert senders.ndim == 1 and receivers.ndim == 1 + assert edge_vectors.shape[0] == senders.shape[0] == receivers.shape[0] + + # Calculate distances + distances = safe_norm(edge_vectors, axis=-1) + + # Normalize edge vectors + edge_vectors = edge_vectors / (distances[:, None] + 1e-8) + + # Embedding Layers + node_feats = self.node_embedding(node_species) + + edge_feats = self.radial_embedding(distances) + # Cosine cutoff is seperated from radial basis function. 
+ edge_feats = edge_feats * CosineCutoff(self.cutoff)(distances) + edge_feats = self.edge_embedding(edge_feats) + + # It will be [n_nodes, 3, num_channels] + vector_feats = None + + assert self.num_channels % self.num_heads == 0, ( + f"The number of hidden channels ({self.num_channels}) " + f"must be evenly divisible by the number of " + f"attention heads ({self.num_heads})" + ) + + for layer in self.liten_layers: + node_feats, edge_feats, vector_feats = layer( + node_feats, + edge_feats, + vector_feats, + distances, + senders, + receivers, + edge_vectors, + ) + + node_feats = self.out_norm(node_feats) + node_energies = self.readout_energy(node_feats).squeeze(-1) + + return node_energies + + +class LitenLayer(nn.Module): + num_heads: int + num_channels: int + activation: str + cutoff: float + vecnorm_type: str + last_layer: bool = False + first_layer: bool = False + + def setup(self): + assert self.num_channels % self.num_heads == 0, ( + f"The number of hidden channels ({self.num_channels}) " + f"must be evenly divisible by the number of " + f"attention heads ({self.num_heads})" + ) + self.head_dim = self.num_channels // self.num_heads + + # Setting eps=1e-05 to reproduce pytorch Layernorm + self.layernorm = nn.LayerNorm(epsilon=1e-05) + self.vec_layernorm = VecLayerNorm( + num_channels=self.num_channels, + norm_type=self.vecnorm_type, + ) + self.act = parse_activation(self.activation) + self.cutoff_fn = CosineCutoff(self.cutoff) + + self.alpha = self.param( + "alpha", + initializers.xavier_uniform(), + ( + 1, + self.num_heads, + self.head_dim, + ), + ) + + self.vec_linear = nn.Dense( + features=self.num_channels * 2, + use_bias=False, + kernel_init=initializers.xavier_uniform(), + ) + self.node_linear = nn.Dense( + features=self.num_channels, + kernel_init=initializers.xavier_uniform(), + ) + self.edge_linear = nn.Dense( + features=self.num_channels, + kernel_init=initializers.xavier_uniform(), + ) + self.part_linear1 = nn.Dense( + features=self.num_channels if self.first_layer else self.num_channels * 2, + kernel_init=initializers.xavier_uniform(), + ) + self.part_linear2 = nn.Dense( + features=self.num_channels * 2 if self.last_layer else self.num_channels * 3, + kernel_init=initializers.xavier_uniform(), + ) + + if not (self.last_layer or self.first_layer): + self.cross_linear = nn.Dense( + features=self.num_channels, + use_bias=False, + kernel_init=initializers.xavier_uniform(), + ) + self.f_linear = nn.Dense( + features=self.num_channels, + kernel_init=initializers.xavier_uniform(), + ) + + def message_fn( + self, + node_feats: jax.Array, + edge_feats: jax.Array, + vector_feats: jax.Array | None, + distances: jax.Array, + senders: jax.Array, + receivers: jax.Array, + edge_vectors: jax.Array, + ): + edge_feats = self.act(self.edge_linear(edge_feats)).reshape( + -1, self.num_heads, self.head_dim + ) + node_feats = self.node_linear(node_feats).reshape(-1, self.num_heads, self.head_dim) + attn = node_feats[receivers] + node_feats[senders] + edge_feats + attn = self.act(attn) * self.alpha + attn = attn.sum(axis=-1) * self.cutoff_fn(distances)[:, None] + attn = attn[:, :, None] + + n_nodes = len(node_feats) + node_feats = node_feats[senders] * edge_feats + node_feats = (node_feats * attn).reshape(-1, self.num_channels) + + node_sca = self.act(self.part_linear1(node_feats))[:, None] # [n_edges, 1, 2*num_channels] + if self.first_layer: + vector_feats = node_sca * edge_vectors[:, :, None] + else: + node_sca1, node_sca2 = jnp.split(node_sca, 2, axis=2) + vector_feats = ( + 
vector_feats[senders] * node_sca1 + node_sca2 * edge_vectors[:, :, None] + ) + + node_feats = jax.ops.segment_sum(node_feats, receivers, num_segments=n_nodes) + vector_feats = jax.ops.segment_sum(vector_feats, receivers, num_segments=n_nodes) + + return node_feats, vector_feats + + def edge_update( + self, + vector_feats: jax.Array, # [n_nodes, 3, num_channels] + edge_feats: jax.Array, # [n_edges, num_channels] + senders: jax.Array, # [n_edges] + receivers: jax.Array, # [n_edges] + edge_vectors: jax.Array, # [n_edges, 3] + ): + vector_feats = self.cross_linear(vector_feats) + + vec_cross_i = jnp.cross(vector_feats[senders], edge_vectors[:, :, None], axis=1) + vec_cross_j = jnp.cross(vector_feats[receivers], edge_vectors[:, :, None], axis=1) + sum_phi = jnp.sum(vec_cross_i * vec_cross_j, axis=1) + + diff_edge_feats = self.act(self.f_linear(edge_feats)) * sum_phi + + return diff_edge_feats + + def node_update( + self, + node_feats: jax.Array, + vector_feats: jax.Array, + ): + vec1, vec2 = jnp.split(self.vec_linear(vector_feats), 2, axis=-1) + vec_tri = jnp.sum(vec1 * vec2, axis=1) + + norm_vec = jnp.sqrt(jnp.sum(vec2 ** 2, axis=-2) + 1e-16) + vec_qua = norm_vec ** 3 + + node_feats = self.part_linear2(node_feats) + + if self.last_layer: + sca1, sca2 = jnp.split(node_feats, 2, axis=1) + else: + sca1, sca2, sca3 = jnp.split(node_feats, 3, axis=1) + + diff_scalar = (vec_qua + vec_tri) * sca1 + sca2 + + if self.last_layer: + return diff_scalar + + diff_vector = vec1 * sca3[:, None] + return diff_scalar, diff_vector + + def __call__( + self, + node_feats: jax.Array, + edge_feats: jax.Array, + vector_feats: jax.Array, + distances: jax.Array, + senders: jax.Array, + receivers: jax.Array, + edge_vectors: jax.Array, + ) -> tuple[jax.Array, jax.Array, jax.Array]: + scalar_out = self.layernorm(node_feats) + + if not self.first_layer: + vector_feats = self.vec_layernorm(vector_feats) + + scalar_out, vector_out = self.message_fn( + scalar_out, edge_feats, vector_feats, distances, senders, receivers, edge_vectors + ) + + if not (self.last_layer or self.first_layer): + diff_edge_feats = self.edge_update( + vector_feats, edge_feats, senders, receivers, edge_vectors + ) + edge_feats = edge_feats + diff_edge_feats + + node_feats = node_feats + scalar_out + + if self.first_layer: + vector_feats = vector_out + else: + vector_feats = vector_feats + vector_out + + diff_scalar = self.node_update(node_feats, vector_feats) + + if not self.last_layer: + diff_scalar, diff_vector = diff_scalar + vector_feats = vector_feats + diff_vector + + node_feats = node_feats + diff_scalar + + return node_feats, edge_feats, vector_feats diff --git a/src/mlip/models/mace/models.py b/src/mlip/models/mace/models.py index 78a2b4a..8ea26eb 100644 --- a/src/mlip/models/mace/models.py +++ b/src/mlip/models/mace/models.py @@ -72,6 +72,7 @@ def __call__( node_species: jnp.ndarray, senders: jnp.ndarray, receivers: jnp.ndarray, + **_kwargs, # ignore any additional kwargs ) -> jnp.ndarray: e3nn.config("path_normalization", "path") diff --git a/src/mlip/models/nequip/models.py b/src/mlip/models/nequip/models.py index e87064c..63fbc27 100644 --- a/src/mlip/models/nequip/models.py +++ b/src/mlip/models/nequip/models.py @@ -61,6 +61,7 @@ def __call__( node_species: jnp.ndarray, senders: jnp.ndarray, receivers: jnp.ndarray, + **_kwargs, # ignore any additional kwargs ) -> jnp.ndarray: e3nn.config("path_normalization", "path") diff --git a/src/mlip/models/predictor.py b/src/mlip/models/predictor.py index 6b9d3b9..092cd0e 100644 --- 
a/src/mlip/models/predictor.py +++ b/src/mlip/models/predictor.py @@ -49,11 +49,13 @@ class ForceFieldPredictor(nn.Module): mlip_network: nn.Module predict_stress: bool - def __call__(self, graph: jraph.GraphsTuple) -> Prediction: + def __call__(self, graph: jraph.GraphsTuple, training: bool = False) -> Prediction: """Returns a `Prediction` dataclass of properties based on an input graph. Args: graph: The input graph. + training: Whether the model is in training mode or not. If true, rngs should be + passed for stochastic modules. Returns: The properties as a `Prediction` object including "energy" and "forces". @@ -65,7 +67,7 @@ def __call__(self, graph: jraph.GraphsTuple) -> Prediction: "See models tutorial in documentation for details." ) - prediction, minus_forces, pseudo_stress = self._compute_gradients(graph) + prediction, minus_forces, pseudo_stress = self._compute_gradients(graph, training) prediction = prediction.replace(forces=-minus_forces) if not self.predict_stress: @@ -79,8 +81,8 @@ def __call__(self, graph: jraph.GraphsTuple) -> Prediction: ) def _compute_gradients( - self, graph: jraph.GraphsTuple - ) -> tuple[Prediction, np.ndarray, np.ndarray]: + self, graph: jraph.GraphsTuple, training: bool + ) -> tuple[Prediction, np.ndarray, np.ndarray | None]: """Return a `(prediction, gradients, pseudo_stress)` triple. The `prediction` holds graph energies, and eventual optional fields. @@ -89,12 +91,37 @@ def _compute_gradients( """ # Note: strains are invariant vector fields tangent to cell strains = jnp.zeros_like(graph.globals.cell) + + # NOTE: When direct_force is enabled or predict_stress is disabled, + # there is no need to compute corresponding gradients. + direct_force = getattr(self.mlip_network.config, "direct_force", False) + if direct_force: + argnums = 1 if self.predict_stress else None + else: + argnums = (0, 1) if self.predict_stress else 0 + + # Gradient is not needed + if argnums is None: + _, prediction = self._compute_energy( + graph.nodes.positions, strains, graph, training + ) + return prediction, -prediction.forces, None # pylint: disable=E1130 + # Differentiate wrt positions and strains (not cell) - (gradients, pseudo_stress), prediction = jax.grad( - self._compute_energy, (0, 1), has_aux=True - )(graph.nodes.positions, strains, graph) + grads, prediction = jax.grad( + self._compute_energy, argnums, has_aux=True + )(graph.nodes.positions, strains, graph, training) + + if argnums == (0, 1): + minus_forces, pseudo_stress = grads + elif argnums == 0: + minus_forces = grads + pseudo_stress = None + else: + minus_forces = -prediction.forces + pseudo_stress = grads - return prediction, gradients, pseudo_stress + return prediction, minus_forces, pseudo_stress @staticmethod def _compute_stress_results( @@ -125,6 +152,7 @@ def _compute_energy( positions: np.ndarray, strains: np.ndarray, graph: jraph.GraphsTuple, + training: bool, ) -> tuple[np.ndarray, Prediction]: """Return total energy and a `Prediction` object holding graph energies. @@ -132,13 +160,21 @@ def _compute_energy( differentiation. The `Prediction` object holds graph-wise energies at this stage, and may be further populated by downstream methods. 
""" - node_energies = self._compute_node_features(positions, strains, graph) + node_features = self._compute_node_features(positions, strains, graph, training) - assert node_energies.shape == (len(positions),), ( + assert node_features.shape in [(len(positions),), (len(positions), 4)], ( f"model output needs to be an array of shape " - f"(n_nodes, ) but got {node_energies.shape}" + f"(n_nodes, ) or (n_nodes, 4), but got {node_features.shape}" ) + # When `node_energies` is a concatenation of the node-wise energies and forces. + forces = None + if getattr(self.mlip_network.config, "direct_force", False): + node_energies, forces = jnp.split(node_features, [1], axis=-1) + node_energies = node_energies.squeeze(-1) + else: + node_energies = node_features + total_energy = jnp.sum(node_energies) graph_energies = e3nn.scatter_sum( @@ -147,6 +183,7 @@ def _compute_energy( prediction = Prediction( energy=graph_energies, + forces=forces, ) return total_energy, prediction @@ -156,6 +193,7 @@ def _compute_node_features( positions: np.ndarray, strains: np.ndarray, graph: jraph.GraphsTuple, + training: bool, ) -> np.ndarray: """Evaluate node-wise outputs of `.mlip_network` on graph data. @@ -186,6 +224,8 @@ def _compute_node_features( graph.nodes.species, graph.senders, graph.receivers, + n_node=graph.n_node, # kwargs + training=training, # kwargs ) padding_mask = jraph.get_node_padding_mask(graph) padding_mask = jnp.expand_dims( diff --git a/src/mlip/models/radial_basis.py b/src/mlip/models/radial_basis.py new file mode 100644 index 0000000..d80820d --- /dev/null +++ b/src/mlip/models/radial_basis.py @@ -0,0 +1,180 @@ +# Copyright 2025 Zhongguancun Academy + +""" +This module contains all radial basis functions seen in all models. +It can be used to refactor Mace, Visnet and Nequip. +""" + +from enum import Enum + +import numpy as np +import jax +import jax.numpy as jnp +import flax.linen as nn +import e3nn_jax as e3nn + + +class ExpNormalBasis(nn.Module): + """Original ExpNormalSmearing from Visnet without cutoff function.""" + cutoff: float + num_rbf: int + trainable: bool = True + + def setup(self): + self.alpha = 5.0 / self.cutoff + means, betas = self._initial_params() + if self.trainable: + self.means = self.param( + "means", nn.initializers.constant(means), (self.num_rbf,) + ) + self.betas = self.param( + "betas", nn.initializers.constant(betas), (self.num_rbf,) + ) + else: + self.means = means + self.betas = betas + + def _initial_params(self): + start_value = jnp.exp(-self.cutoff) + means = jnp.linspace(start_value, 1, self.num_rbf) + betas = jnp.full((self.num_rbf,), (2 / self.num_rbf * (1 - start_value)) ** -2) + return means, betas + + def __call__(self, dist: jax.Array) -> jax.Array: + dist = dist[..., None] + return jnp.exp( + (-1 * self.betas) * (jnp.exp(self.alpha * (-dist)) - self.means) ** 2 + ) + + +class GaussianBasis(nn.Module): + """ + Original GaussianSmearing from Visnet without cutoff function. + It's also used in So3krates named RBF. 
+ """ + cutoff: float + num_rbf: int + trainable: bool = True + rbf_width: float = 1.0 + + def setup(self): + offset, coeff = self._initial_params() + if self.trainable: + self.offset = self.param( + "offset", nn.initializers.constant(offset), (self.num_rbf,) + ) + self.coeff = self.param("coeff", nn.initializers.constant(coeff), ()) + else: + self.offset = offset + self.coeff = coeff + + def _initial_params(self): + offset = jnp.linspace(0, self.cutoff, self.num_rbf) + coeff = -0.5 / (self.rbf_width * (offset[1] - offset[0])) ** 2 + return offset, coeff + + def __call__(self, dist: jax.Array) -> jax.Array: + dist = dist[..., None] - self.offset + return jnp.exp(self.coeff * jnp.square(dist)) + + +class BesselBasis(nn.Module): + """Bessel basis used in Mace and Nequip. This is not the same named function in So3krates.""" + cutoff: float + num_rbf: int + + @nn.compact + def __call__(self, dist: jax.Array) -> jax.Array: + return e3nn.bessel(dist, self.num_rbf, self.cutoff) + + +def log_binomial(n: int) -> jax.Array: + """ + Returns: jax.Array of shape (n+1,) + [log C(n, 0), ..., log C(n, n)] + """ + out = [] + for k in range(n + 1): + n_factorial = np.sum(np.log(np.arange(1, n + 1))) + k_factorial = np.sum(np.log(np.arange(1, k + 1))) + n_k_factorial = np.sum(np.log(np.arange(1, n - k + 1))) + out.append(n_factorial - k_factorial - n_k_factorial) + return jnp.stack(out) + +class BernsteinBasis(nn.Module): + """Bernstein polynomial basis from So3krates.""" + cutoff: float + num_rbf: int + gamma: float = 0.9448630629184640 + + @nn.compact + def __call__(self, dist: jax.Array) -> jax.Array: + b = log_binomial(self.num_rbf - 1) + k = jnp.arange(self.num_rbf) + k_rev = k[::-1] + + scaled_dist = -self.gamma * dist[..., None] + k_x = k * scaled_dist + kk_x = k_rev * jnp.log(1e-8 - jnp.expm1(scaled_dist)) + return jnp.exp(b + k_x + kk_x) + + +class PhysNetBasis(nn.Module): + """Expand distances in the basis used in PhysNet (see https://arxiv.org/abs/1902.08408)""" + cutoff: float + num_rbf: int + + @nn.compact + def __call__(self, dist: jax.Array) -> jax.Array: + exp_dist = jnp.exp(-dist)[..., None] + exp_cutoff = jnp.exp(-self.cutoff) + + offset = jnp.linspace(exp_cutoff, 1, self.num_rbf) + coeff = self.num_rbf / 2 / (1 - exp_cutoff) + return jnp.exp(-(coeff * (exp_dist - offset)) ** 2) + + +class FourierBasis(nn.Module): + """ + Expand distances in the Bessel basis (see https://arxiv.org/pdf/2003.03123.pdf). + It's also called Bessel basis in So3krates. Since we already have the BesselBasis, we have to + use another name. + """ + cutoff: float + num_rbf: int + + def setup(self): + self.offset = jnp.arange(0, self.num_rbf, 1) + + def __call__(self, dist: jax.Array) -> jax.Array: + dist = dist[..., None] + # In So3krates, safe_mask is used to avoid divide by zero. Here, we use a small epsilon. + return jnp.sin(jnp.pi / self.cutoff * self.offset * dist) / (dist + 1e-8) + +# --- Options --- + + +class RadialBasis(Enum): + GAUSS = "gauss" + EXPNORM = "expnorm" + BESSEL = "bessel" + BERNSTEIN = "bernstein" + PHYS = "phys" + FOURIER = "fourier" + + +def parse_radial_basis(basis: RadialBasis | str) -> type[nn.Module]: + """Parse `RadialBasis` parameter among available options. + + See :class:`~dipm.models.options.RadialBasis`. 
+ """ + radial_basis_map = { + RadialBasis.GAUSS: GaussianBasis, + RadialBasis.EXPNORM: ExpNormalBasis, + RadialBasis.BESSEL: BesselBasis, + RadialBasis.BERNSTEIN: BernsteinBasis, + RadialBasis.PHYS: PhysNetBasis, + RadialBasis.FOURIER: FourierBasis, + } + assert set(RadialBasis) == set(radial_basis_map.keys()) + return radial_basis_map[RadialBasis(basis)] diff --git a/src/mlip/models/so3krates/blocks.py b/src/mlip/models/so3krates/blocks.py new file mode 100644 index 0000000..06c8c3d --- /dev/null +++ b/src/mlip/models/so3krates/blocks.py @@ -0,0 +1,218 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Thorben Frank (https://github.com/thorben-frank/mlff) under MIT license. + +import flax.linen as nn +from flax.linen.initializers import constant +import jax +import jax.numpy as jnp +import e3nn_jax as e3nn + +from mlip.models.options import parse_activation + + +class MLP(nn.Module): + features: tuple[int, ...] + activation: str + use_bias: bool = True + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + x = inputs + for i, feat in enumerate(self.features): + x = nn.Dense(feat, use_bias=self.use_bias)(x) + if i != len(self.features) - 1: + x = parse_activation(self.activation)(x) + return x + + +class ResidualMLP(nn.Module): + num_blocks: int = 3 + activation: str = 'silu' + # In original So3krates, this is set to False. But using bias is better. + use_bias: bool = True + + @nn.compact + def __call__(self, inputs: jax.Array): + x = inputs + feat = x.shape[-1] + for _ in range(self.num_blocks): + x = parse_activation(self.activation)(x) + x = nn.Dense(feat, use_bias=self.use_bias)(x) + x = x + inputs + # In original So3krates, there exists a non-residual Linear. But it would + # be slightly better to include it in the residual. + # x = parse_activation(self.activation)(x) + # x = nn.Dense(feat, use_bias=self.use_bias)(x) + return x + + +class InteractionBlock(nn.Module): + @nn.compact + def __call__( + self, + node_feats: jax.Array, + chi: e3nn.IrrepsArray, + ) -> tuple[jax.Array, e3nn.IrrepsArray]: + num_features = node_feats.shape[-1] + + # Tensor product using CG coefficents has been removed for simplicity. + chi_scalar = e3nn.norm(chi, squared=True, per_irrep=True).array + + feats = jnp.concatenate([node_feats, chi_scalar], axis=-1) + feats = nn.Dense(num_features + chi.irreps.num_irreps)(feats) + + # node_feats: [n_nodes, num_features], chi_coeffs: [n_nodes, n_heads] + node_feats, chi_coeffs = jnp.split(feats, [num_features], axis=-1) + + return node_feats, chi_coeffs * chi + + +class FeatureBlock(nn.Module): + num_heads: int + rad_features: tuple[int, ...] + sph_features: tuple[int, ...] 
+ activation: str + avg_num_neighbors: float + + @nn.compact + def __call__( + self, + node_feats: jax.Array, + edge_feats: jax.Array, + chi_scalar: jax.Array, + cutoffs: jax.Array, + senders: jax.Array, + receivers: jax.Array + ) -> jax.Array: + alpha = FilterScaledAttentionMap( + num_heads=self.num_heads, + rad_features=self.rad_features, + sph_features=self.sph_features, + activation=self.activation + )(node_feats, edge_feats, chi_scalar, senders, receivers) + + alpha = alpha * cutoffs[:, None] # [n_edges, n_heads] + + head_dim = node_feats.shape[-1] // self.num_heads + v_j = nn.Dense(node_feats.shape[-1], use_bias=False)(node_feats)[senders] + v_j = v_j.reshape(-1, self.num_heads, head_dim) + + node_feats = jax.ops.segment_sum( + alpha[..., None] * v_j, receivers, num_segments=node_feats.shape[0] + ) / self.avg_num_neighbors + node_feats = node_feats.reshape(-1, head_dim * self.num_heads) + return node_feats + + +class GeometricBlock(nn.Module): + rad_features: tuple[int, ...] + sph_features: tuple[int, ...] + activation: str + avg_num_neighbors: float + + @nn.compact + def __call__( + self, + edge_sh: e3nn.IrrepsArray, + node_feats: jax.Array, + edge_feats: jax.Array, + chi_scalar: jax.Array, + cutoffs: jax.Array, + senders: jax.Array, + receivers: jax.Array + ) -> e3nn.IrrepsArray: + alpha = FilterScaledAttentionMap( + num_heads=edge_sh.irreps.num_irreps, + rad_features=self.rad_features, + sph_features=self.sph_features, + activation=self.activation + )(node_feats, edge_feats, chi_scalar, senders, receivers) + + alpha = alpha * cutoffs[:, None] + + # e3nn supports directly multiply IrrepsArray with scalars. + chi = e3nn.scatter_sum(alpha * edge_sh, dst=receivers, output_size=node_feats.shape[0]) + return chi / self.avg_num_neighbors + + +class FilterScaledAttentionMap(nn.Module): + num_heads: int + rad_features: tuple[int, ...] + sph_features: tuple[int, ...] + activation: str + + @nn.compact + def __call__( + self, + node_feats: jax.Array, + edge_feats: jax.Array, + chi_scalar: jax.Array, + senders: jax.Array, + receivers: jax.Array + ) -> jax.Array: + head_dim = node_feats.shape[-1] // self.num_heads + + # Radial spherical filter + w_ij = MLP(self.rad_features, self.activation)(edge_feats) + w_ij += MLP(self.sph_features, self.activation)(chi_scalar) + w_ij = w_ij.reshape(-1, self.num_heads, head_dim) + + # Geometric attention coefficients + q_i = nn.Dense(node_feats.shape[-1], use_bias=False)(node_feats) + q_i = q_i.reshape(-1, self.num_heads, head_dim)[receivers] + + k_j = nn.Dense(node_feats.shape[-1], use_bias=False)(node_feats) + k_j = k_j.reshape(-1, self.num_heads, head_dim)[senders] + + return (q_i * w_ij * k_j).sum(axis=-1) / jnp.sqrt(head_dim) # [n_edges, n_heads] + + +class ZBLRepulsion(nn.Module): + """Ziegler-Biersack-Littmark repulsion.""" + index_to_z: tuple[int, ...] + a0: float = 0.5291772105638411 + ke: float = 14.399645351950548 + + @nn.compact + def __call__( + self, + node_species: jax.Array, + distances: jax.Array, + cutoffs: jax.Array, + senders: jax.Array, + receivers: jax.Array, + ) -> jax.Array: + def softplus_inverse(x): + return x + jnp.log(-jnp.expm1(-x)) + + # We vectorize a/c for simplicity. 
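+        # The four (a, c) pairs below are the standard coefficients of the universal ZBL
+        # screening function; storing them via softplus_inverse keeps the trained values
+        # positive after the softplus applied below.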
+        a_init = softplus_inverse(jnp.array([3.20000, 0.94230, 0.40280, 0.20160]))
+        c_init = softplus_inverse(jnp.array([0.18180, 0.50990, 0.28020, 0.02817]))
+
+        a = nn.softplus(self.param('a', constant(a_init), (4,)))
+        c = nn.softplus(self.param('c', constant(c_init), (4,)))
+        c = c / jnp.sum(c)
+
+        p = nn.softplus(self.param('p', constant(softplus_inverse(0.23)), (1,)))
+        d = nn.softplus(self.param(
+            'd', constant(softplus_inverse(1 / (0.8854 * self.a0))), (1,)
+        ))
+
+        z = jnp.array(self.index_to_z)[node_species]
+        z_i = z[receivers]
+        z_j = z[senders]
+
+        x = self.ke * cutoffs * z_i * z_j / (distances + 1e-8)
+
+        rzd = distances * (jnp.power(z_i, p) + jnp.power(z_j, p)) * d
+
+        # ZBL screening function, shape: [n_edges]
+        y = jnp.sum(c * jnp.exp(-a * rzd[:, None]), axis=-1)
+
+        scaled_d = distances / 1.5
+        sigma_d = jnp.exp(-1. / (jnp.where(scaled_d > 1e-8, scaled_d, 1e-8)))
+        sigma_1_d = jnp.exp(-1. / (jnp.where(1 - scaled_d > 1e-8, 1 - scaled_d, 1e-8)))
+        w = sigma_1_d / (sigma_1_d + sigma_d)
+
+        energy_rep = w * x * y / 2
+        return jax.ops.segment_sum(energy_rep, receivers, num_segments=node_species.shape[0])
diff --git a/src/mlip/models/so3krates/config.py b/src/mlip/models/so3krates/config.py
new file mode 100644
index 0000000..b7ca8c6
--- /dev/null
+++ b/src/mlip/models/so3krates/config.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Zhongguancun Academy
+# Based on code initially developed by Thorben Frank (https://github.com/thorben-frank/mlff) under MIT license.
+
+import pydantic
+
+from mlip.models.cutoff import CutoffFunction
+from mlip.models.radial_basis import RadialBasis
+from mlip.models.options import Activation
+from mlip.typing import PositiveInt
+
+
+class So3kratesConfig(pydantic.BaseModel):
+    """Hyperparameters for the So3krates model.
+
+    Attributes:
+        num_layers: Number of So3krates layers. Default is 3.
+        num_channels: The number of channels. Default is 128.
+        num_heads: Number of heads in the attention block. Default is 4.
+        num_rbf: Number of basis functions used in the embedding block. Default is 32.
+        activation: Activation function for the output block. Options are "silu"
+            (default), "ssp" (which is shifted softplus), "tanh", "sigmoid", and
+            "swish".
+        radial_cutoff_fn: The type of the cutoff / radial envelope function.
+        radial_basis_fn: The type of the radial basis function.
+        chi_irreps: The irreps of the spherical harmonic coordinates (SPHCs).
+        sphc_normalization: Normalization constant for initializing spherical harmonic
+            coordinates (SPHCs). If set to ``None``, SPHCs are initialized to zero.
+        residual_mlp_1: Whether to apply a residual MLP after the first (feature +
+            geometric) update block inside each So3krates layer.
+        residual_mlp_2: Whether to apply a residual MLP after the interaction block inside
+            each So3krates layer.
+        normalization: Whether to apply LayerNorm to scalar node features before major
+            update blocks inside each So3krates layer.
+        zbl_repulsion: Whether to include an explicit Ziegler-Biersack-Littmark (ZBL)
+            short-range nuclear repulsion term in the predicted energies.
+        zbl_repulsion_shift: Constant energy shift subtracted from the ZBL repulsion
+            contribution.
+        atomic_energies: How to treat the atomic energies. If set to ``None`` (default)
+            or the string ``"average"``, then the average atomic energies
+            stored in the dataset info are used. It can also be set to the
+            string ``"zero"`` which means not to use any atomic energies
+            in the model.
Lastly, one can also pass an atomic energies + dictionary via this parameter different from the one in the + dataset info, that is used. + num_species: The number of elements (atomic species descriptors) allowed. + If ``None`` (default), infer the value from the atomic energies + map in the dataset info. + """ + + num_layers: PositiveInt = 3 + num_channels: PositiveInt = 128 + num_heads: PositiveInt = 4 + num_rbf: PositiveInt = 32 + activation: Activation = Activation.SILU + radial_cutoff_fn: CutoffFunction = CutoffFunction.PHYS + radial_basis_fn: RadialBasis = RadialBasis.BERNSTEIN + chi_irreps: str = "1e + 2e + 3e + 4e" + sphc_normalization: float | None = None + residual_mlp_1: bool = True + residual_mlp_2: bool = False + normalization: bool = True + zbl_repulsion: bool = True + zbl_repulsion_shift: float = 0.0 + atomic_energies: str | dict[int, float] | None = None + num_species: PositiveInt | None = None diff --git a/src/mlip/models/so3krates/models.py b/src/mlip/models/so3krates/models.py new file mode 100644 index 0000000..0d5e8c1 --- /dev/null +++ b/src/mlip/models/so3krates/models.py @@ -0,0 +1,270 @@ +# Copyright 2025 Zhongguancun Academy +# Based on code initially developed by Thorben Frank (https://github.com/thorben-frank/mlff) under MIT license. + +import e3nn_jax as e3nn +import flax.linen as nn +import jax +import jax.numpy as jnp + +from mlip.data.dataset_info import DatasetInfo +from mlip.models.atomic_energies import get_atomic_energies +from mlip.models.mlip_network import MLIPNetwork +from mlip.models.so3krates.blocks import ( + MLP, ResidualMLP, FeatureBlock, GeometricBlock, InteractionBlock, ZBLRepulsion +) +from mlip.models.so3krates.config import So3kratesConfig +from mlip.models.cutoff import parse_cutoff +from mlip.models.radial_basis import parse_radial_basis +from mlip.utils.safe_norm import safe_norm + + +class So3krates(MLIPNetwork): + """The So3krates model flax module. It is derived from the + :class:`~mlip.models.mlip_network.MLIPNetwork` class. + + References: + * Frank Thorben, Oliver Unke and Klaus-Robert Müller. So3krates: Equivariant + attention for interactions on arbitrary length-scales in molecular systems. + Advances in Neural Information Processing Systems, 35, Dec 2022. + URL: https://proceedings.neurips.cc/paper_files/paper/2022/hash/bcf4ca90a8d405201d29dd47d75ac896-Abstract-Conference.html + + Attributes: + config: Hyperparameters / configuration for the So3krates model, see + :class:`~mlip.models.so3krates.config.So3kratesConfig`. + dataset_info: Hyperparameters dictated by the dataset + (e.g., cutoff radius or average number of neighbors). + """ + + Config = So3kratesConfig + + config: So3kratesConfig + dataset_info: DatasetInfo + + @nn.compact + def __call__( + self, + edge_vectors: jax.Array, + node_species: jax.Array, + senders: jax.Array, + receivers: jax.Array, + **_kwargs, # ignore any additional kwargs + ) -> jax.Array: + + r_max = self.dataset_info.cutoff_distance_angstrom + + num_species = self.config.num_species + if num_species is None: + num_species = len(self.dataset_info.atomic_energies_map) + + # Is it necessary to allow users to modify here? 
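+        # For now the radial / spherical filter MLP widths are fixed functions of
+        # num_channels rather than separate hyperparameters.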
+ rad_filter_features = [self.config.num_channels, self.config.num_channels] + sph_filter_features = [self.config.num_channels // 4, self.config.num_channels] + + so3krates_kwargs = dict( + num_heads=self.config.num_heads, + num_layers=self.config.num_layers, + num_channels=self.config.num_channels, + num_rbf=self.config.num_rbf, + chi_irreps=self.config.chi_irreps, + fb_rad_filter_features=rad_filter_features, + gb_rad_filter_features=rad_filter_features, + fb_sph_filter_features=sph_filter_features, + gb_sph_filter_features=sph_filter_features, + radial_basis_fn=self.config.radial_basis_fn, + sphc_normalization=self.config.sphc_normalization, + residual_mlp_1=self.config.residual_mlp_1, + residual_mlp_2=self.config.residual_mlp_2, + normalization=self.config.normalization, + activation=self.config.activation, + cutoff=r_max, + num_species=num_species, + avg_num_neighbors=self.dataset_info.avg_num_neighbors + ) + + representation_model = So3kratesBlock(**so3krates_kwargs) + + # This will be used by the ZBL repulsion term + distances = safe_norm(edge_vectors, axis=-1) + cutoffs = parse_cutoff(self.config.radial_cutoff_fn)(r_max)(distances) + + node_energies = representation_model( + edge_vectors, distances, cutoffs, node_species, senders, receivers + ) + mean = self.dataset_info.scaling_mean + std = self.dataset_info.scaling_stdev + node_energies = mean + std * node_energies + + if self.config.zbl_repulsion: + index_to_z = tuple(sorted(self.dataset_info.atomic_energies_map.keys())) + e_rep = ZBLRepulsion(index_to_z)( + node_species, distances, cutoffs, senders, receivers + ) + node_energies += e_rep - self.config.zbl_repulsion_shift + + atomic_energies_ = get_atomic_energies( + self.dataset_info, self.config.atomic_energies, num_species + ) + atomic_energies_ = jnp.asarray(atomic_energies_) + node_energies += atomic_energies_[node_species] # [n_nodes, ] + + return node_energies + + +class So3kratesBlock(nn.Module): + num_layers: int + num_channels: int + num_species: int + num_rbf: int + chi_irreps: str + fb_rad_filter_features: tuple[int, ...] + gb_rad_filter_features: tuple[int, ...] + fb_sph_filter_features: tuple[int, ...] + gb_sph_filter_features: tuple[int, ...] + cutoff: float = 5.0 + radial_basis_fn: str = 'phys' + sphc_normalization: float | None = None + activation: str = 'silu' + num_heads: int = 4 + residual_mlp_1: bool = False + residual_mlp_2: bool = False + normalization: bool = False + # In the original So3krates repo, this scaling factor does not exist. But for deeper networks, + # this is necessary to ensure that the initial loss does not explode. + avg_num_neighbors: float = 1.0 + + @nn.compact + def __call__(self, + edge_vectors: jax.Array, + distances: jax.Array, + cutoffs: jax.Array, + node_species: jax.Array, + senders: jax.Array, + receivers: jax.Array, + ) -> jax.Array: + edge_feats = parse_radial_basis(self.radial_basis_fn)( + self.cutoff, self.num_rbf + )(distances) + + # This implementation differs from the original So3krates repo: (1) the coefficents are + # different (So3krates uses physics convention, while here uses math convention), and + # (2) the output components follow a m-ordering in So3krates, while it is in Cartesian + # here. These are equivalent. 
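+        # `edge_sh` holds the spherical harmonics of each edge direction; it is used both
+        # to initialize the SPHCs below and as the geometric basis inside every layer.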
+ edge_vectors = e3nn.IrrepsArray('1e', edge_vectors) + chi_irreps = e3nn.Irreps(self.chi_irreps) + edge_sh = e3nn.spherical_harmonics(chi_irreps, edge_vectors, True) + + # Initalize node features and spherical harmonic coordinates (SPHCs) + node_feats = nn.Embed(self.num_species, self.num_channels)(node_species) + if self.sphc_normalization is None: + chi = e3nn.zeros(chi_irreps, (node_species.shape[0],), dtype=edge_vectors.dtype) + else: + chi = e3nn.scatter_sum( + edge_sh * cutoffs[:, None], dst=receivers, output_size=node_species.shape[0] + ) / self.sphc_normalization + + for _ in range(self.num_layers): + node_feats, chi = So3kratesLayer( + self.fb_rad_filter_features, + self.gb_rad_filter_features, + self.fb_sph_filter_features, + self.gb_sph_filter_features, + self.activation, + self.num_heads, + self.residual_mlp_1, + self.residual_mlp_2, + self.normalization, + self.avg_num_neighbors + )( + node_feats=node_feats, + chi=chi, + edge_feats=edge_feats, + edge_sh=edge_sh, + cutoffs=cutoffs, + senders=senders, + receivers=receivers + ) + + node_energies = MLP([node_feats.shape[-1], 1], self.activation)( + node_feats + ).squeeze(axis=-1) + + return node_energies + + +class So3kratesLayer(nn.Module): + fb_rad_filter_features: tuple[int, ...] + gb_rad_filter_features: tuple[int, ...] + fb_sph_filter_features: tuple[int, ...] + gb_sph_filter_features: tuple[int, ...] + activation: str = 'silu' + num_heads: int = 4 + residual_mlp_1: bool = False + residual_mlp_2: bool = False + normalization: bool = False + avg_num_neighbors: float = 1.0 + + @nn.compact + def __call__( + self, + node_feats: jax.Array, + chi: e3nn.IrrepsArray, + edge_feats: jax.Array, + edge_sh: e3nn.IrrepsArray, + cutoffs: jax.Array, + senders: jax.Array, + receivers: jax.Array + ) -> tuple[jax.Array, e3nn.IrrepsArray]: + chi_ij = chi[senders] - chi[receivers] + chi_scalar = e3nn.norm(chi_ij, squared=True, per_irrep=True).array + + # first block + node_feats_pre = nn.LayerNorm()(node_feats) if self.normalization else node_feats + + diff_node_feats = FeatureBlock( + self.num_heads, + rad_features=self.fb_rad_filter_features, + sph_features=self.fb_sph_filter_features, + activation=self.activation, + avg_num_neighbors=self.avg_num_neighbors + )( + node_feats=node_feats_pre, + edge_feats=edge_feats, + chi_scalar=chi_scalar, + cutoffs=cutoffs, + senders=senders, + receivers=receivers + ) + + diff_chi = GeometricBlock( + rad_features=self.gb_rad_filter_features, + sph_features=self.gb_sph_filter_features, + activation=self.activation, + avg_num_neighbors=self.avg_num_neighbors + )( + edge_sh=edge_sh, + node_feats=node_feats_pre, + edge_feats=edge_feats, + chi_scalar=chi_scalar, + cutoffs=cutoffs, + senders=senders, + receivers=receivers + ) + + node_feats = node_feats + diff_node_feats + chi = chi + diff_chi + + # second block + if self.residual_mlp_1: + node_feats = ResidualMLP(activation=self.activation)(node_feats) + + node_feats_pre = nn.LayerNorm()(node_feats) if self.normalization else node_feats + + diff_node_feats, diff_chi = InteractionBlock()(node_feats_pre, chi) + + node_feats = node_feats + diff_node_feats + chi = chi + diff_chi + + if self.residual_mlp_2: + node_feats = ResidualMLP(activation=self.activation)(node_feats) + + return node_feats, chi diff --git a/src/mlip/models/visnet/models.py b/src/mlip/models/visnet/models.py index e379968..7610ab6 100644 --- a/src/mlip/models/visnet/models.py +++ b/src/mlip/models/visnet/models.py @@ -69,6 +69,7 @@ def __call__( node_species: jnp.ndarray, senders: jnp.ndarray, 
receivers: jnp.ndarray, + **_kwargs, # ignore any additional kwargs ) -> jnp.ndarray: r_max = self.dataset_info.cutoff_distance_angstrom diff --git a/src/mlip/training/training_step.py b/src/mlip/training/training_step.py index e638355..8e38557 100644 --- a/src/mlip/training/training_step.py +++ b/src/mlip/training/training_step.py @@ -19,6 +19,7 @@ import optax from jax import Array from jraph import GraphsTuple +from flax.typing import RNGSequences from mlip.models.predictor import ForceFieldPredictor from mlip.training.ema import EMAParameterTransformation @@ -31,7 +32,7 @@ def _training_step( training_state: TrainingState, graph: GraphsTuple, epoch_number: int, - model_loss_fun: Callable[[ModelParameters, GraphsTuple, int], Array], + model_loss_fun: Callable[[ModelParameters, GraphsTuple, int, RNGSequences], Array], optimizer: optax.GradientTransformation, ema_fun: EMAParameterTransformation, avg_n_graphs_per_batch: float, @@ -49,9 +50,10 @@ def _training_step( # Calculate gradients. grad_fun = jax.grad(model_loss_fun, argnums=0, has_aux=True) - key, _ = jax.random.split(key, 2) + key, dropout_key, rotation_key = jax.random.split(key, 3) - grads, aux_info = grad_fun(params, graph, epoch_number) + rngs = {"dropout": dropout_key, 'rotation': rotation_key} + grads, aux_info = grad_fun(params, graph, epoch_number, rngs) # Aggregrate over devices. if should_parallelize: @@ -130,9 +132,9 @@ def make_train_step( """ def model_loss( - params: ModelParameters, ref_graph: GraphsTuple, epoch: int + params: ModelParameters, ref_graph: GraphsTuple, epoch: int, rngs: RNGSequences ) -> Array: - predictions = predictor.apply(params, ref_graph) + predictions = predictor.apply(params, ref_graph, training=True, rngs=rngs) return loss_fun(predictions, ref_graph, epoch) training_step = functools.partial(