Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mlip/data/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class GraphDatasetBuilderConfig(pydantic.BaseModel):
to ``True``, the models assume ``"zero"`` atomic
energies as can be set in the model hyperparameters.
avg_num_neighbors: The pre-computed average number of neighbors.
avg_num_nodes: The pre-computed average number of nodes.
avg_r_min_angstrom: The pre-computed average minimum distance between nodes.

"""
Expand All @@ -140,4 +141,5 @@ class GraphDatasetBuilderConfig(pydantic.BaseModel):

use_formation_energies: bool = False
avg_num_neighbors: Optional[float] = None
avg_num_nodes: Optional[float] = None
avg_r_min_angstrom: Optional[float] = None
12 changes: 12 additions & 0 deletions src/mlip/data/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional

import jraph
import numpy as np
import pydantic
from ase import Atom

Expand All @@ -40,6 +41,7 @@ class DatasetInfo(pydantic.BaseModel):
cutoff_distance_angstrom: The graph cutoff distance that was
used in the dataset in Angstrom.
avg_num_neighbors: The mean number of neighbors an atom has in the dataset.
avg_num_nodes: The mean number of nodes a structure has in the dataset.
avg_r_min_angstrom: The mean minimum edge distance for a structure in the
dataset.
scaling_mean: The mean used for the rescaling of the dataset values, the
Expand All @@ -51,6 +53,7 @@ class DatasetInfo(pydantic.BaseModel):
atomic_energies_map: dict[int, float]
cutoff_distance_angstrom: float
avg_num_neighbors: float = 1.0
avg_num_nodes: float = 1.0
avg_r_min_angstrom: Optional[float] = None
scaling_mean: float = 0.0
scaling_stdev: float = 1.0
Expand All @@ -62,6 +65,7 @@ def __str__(self):
return (
f"Atomic Energies: {atomic_energies_map_with_symbols}, "
f"Avg. num. neighbors: {self.avg_num_neighbors:.2f}, "
f"Avg. num. nodes: {self.avg_num_nodes:.2f}, "
f"Avg. r_min: {self.avg_r_min_angstrom:.2f}, "
f"Graph cutoff distance: {self.cutoff_distance_angstrom}"
)
Expand All @@ -72,6 +76,7 @@ def compute_dataset_info_from_graphs(
cutoff_distance_angstrom: float,
z_table: AtomicNumberTable,
avg_num_neighbors: Optional[float] = None,
avg_num_nodes: Optional[float] = None,
avg_r_min_angstrom: Optional[float] = None,
) -> DatasetInfo:
"""Computes the dataset info from graphs, typically training set graphs.
Expand All @@ -84,6 +89,8 @@ def compute_dataset_info_from_graphs(
map keys.
avg_num_neighbors: The optionally pre-computed average number of neighbors. If
provided, we skip recomputing this.
avg_num_nodes: The optionally pre-computed average number of nodes. If
provided, we skip recomputing this.
avg_r_min_angstrom: The optionally pre-computed average miminum radius. If
provided, we skip recomputing this.

Expand All @@ -99,6 +106,10 @@ def compute_dataset_info_from_graphs(
logger.debug("Computing average number of neighbors...")
avg_num_neighbors = compute_avg_num_neighbors(graphs)
logger.debug("Average number of neighbors: %.1f", avg_num_neighbors)
if avg_num_nodes is None:
logger.debug("Computing average number of nodes...")
avg_num_nodes = np.mean([i.item() for g in graphs for i in g.n_node])
logger.debug("Average number of nodes: %.1f", avg_num_nodes)
if avg_r_min_angstrom is None:
logger.debug("Computing average min neighbor distance...")
avg_r_min_angstrom = compute_avg_min_neighbor_distance(graphs)
Expand All @@ -119,6 +130,7 @@ def compute_dataset_info_from_graphs(
atomic_energies_map=atomic_energies_map,
cutoff_distance_angstrom=cutoff_distance_angstrom,
avg_num_neighbors=avg_num_neighbors,
avg_num_nodes=avg_num_nodes,
avg_r_min_angstrom=avg_r_min_angstrom,
scaling_mean=0.0,
scaling_stdev=1.0,
Expand Down
1 change: 1 addition & 0 deletions src/mlip/data/graph_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def prepare_datasets(self) -> None:
self._config.graph_cutoff_angstrom,
z_table,
self._config.avg_num_neighbors,
self._config.avg_num_nodes,
self._config.avg_r_min_angstrom,
)

Expand Down
3 changes: 3 additions & 0 deletions src/mlip/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
from mlip.models.nequip.models import Nequip
from mlip.models.predictor import ForceFieldPredictor
from mlip.models.visnet.models import Visnet
from mlip.models.liten.models import Liten
from mlip.models.so3krates.models import So3krates
from mlip.models.equiformer_v2.models import EquiformerV2
101 changes: 101 additions & 0 deletions src/mlip/models/cutoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2025 Zhongguancun Academy

"""
This module contains all cutoff / radial envelope functions seen in all models.
It can be used to refactor Mace, Visnet and Nequip.
"""

from enum import Enum

import jax
import jax.numpy as jnp
import flax.linen as nn
import e3nn_jax as e3nn


class SoftCutoff(nn.Module):
"""Soft envelope radial envelope function."""
cutoff: float
arg_multiplicator: float = 2.0
value_at_origin: float = 1.2

@nn.compact
def __call__(self, length):
return e3nn.soft_envelope(
length,
self.cutoff,
arg_multiplicator=self.arg_multiplicator,
value_at_origin=self.value_at_origin,
)


class PolynomialCutoff(nn.Module):
"""Polynomial radial envelope function from the MACE torch version."""
cutoff: float
p: int = 5

@nn.compact
def __call__(self, length: jax.Array):
a = - (self.p + 1.0) * (self.p + 2.0) / 2.0
b = self.p * (self.p + 2.0)
c = - self.p * (self.p + 1.0) / 2

x_norm = length / self.cutoff
envelope = 1.0 + jnp.pow(x_norm, self.p) * (
a + x_norm * (b + x_norm * c)
)
return envelope * (length < self.cutoff)


class CosineCutoff(nn.Module):
"""Behler-style cosine cutoff function."""
cutoff: float

@nn.compact
def __call__(self, length: jax.Array) -> jax.Array:
cutoffs = 0.5 * (jnp.cos(length * jnp.pi / self.cutoff) + 1.0)
return cutoffs * (length < self.cutoff)


class PhysCutoff(nn.Module):
"""Cutoff function used in PhysNet."""
cutoff: float

@nn.compact
def __call__(self, length: jax.Array) -> jax.Array:
x_norm = length / self.cutoff
cutoffs = 1 - 6 * x_norm ** 5 + 15 * x_norm ** 4 - 10 * x_norm ** 3
return cutoffs * (length < self.cutoff)


class ExponentialCutoff(nn.Module):
"""Exponential cutoff function used in SpookyNet."""
cutoff: float

@nn.compact
def __call__(self, length: jax.Array) -> jax.Array:
# TODO(bhcao): Check if this is numerically stable.
cutoffs = jnp.exp(-length ** 2 / ((self.cutoff - length) * (self.cutoff + length)))
return cutoffs * (length < self.cutoff)

# --- Options ---


class CutoffFunction(Enum):
POLYNOMIAL = "polynomial"
SOFT = "soft"
COSINE = "cosine"
PHYS = "phys"
EXPONENTIAL = "exponential"


def parse_cutoff(cutoff: CutoffFunction | str) -> type[nn.Module]:
cutoff_function_map = {
CutoffFunction.POLYNOMIAL: PolynomialCutoff,
CutoffFunction.SOFT: SoftCutoff,
CutoffFunction.COSINE: CosineCutoff,
CutoffFunction.PHYS: PhysCutoff,
CutoffFunction.EXPONENTIAL: ExponentialCutoff,
}
assert set(CutoffFunction) == set(cutoff_function_map.keys())
return cutoff_function_map[CutoffFunction(cutoff)]
95 changes: 95 additions & 0 deletions src/mlip/models/equiformer_v2/activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 Zhongguancun Academy
# Based on code initially developed by Yi-Lun Liao (https://github.com/atomicarchitects/equiformer_v2) under MIT license.

import jax
import jax.numpy as jnp
import flax.linen as nn

from mlip.models.equiformer_v2.transform import get_s2grid_mats
from mlip.models.equiformer_v2.utils import get_expand_index


class SmoothLeakyReLU(nn.Module):
"""Smooth Leaky ReLU activation."""
negative_slope: float = 0.2

@nn.compact
def __call__(self, x: jax.Array) -> jax.Array:
x1 = ((1 + self.negative_slope) / 2) * x
x2 = ((1 - self.negative_slope) / 2) * x * (2 * nn.sigmoid(x) - 1)
return x1 + x2

# --- Vector activation functions ---


class GateActivation(nn.Module):
"""Apply gate for vector and silu for scalar."""
lmax: int
mmax: int
num_channels: int
m_prime: bool = False

@nn.compact
def __call__(self, gating_scalars: jax.Array, input_tensors: jax.Array) -> jax.Array:
"""
`gating_scalars`: shape [N, lmax * num_channels]
`input_tensors`: shape [N, (lmax + 1) ** 2, num_channels]
"""
expand_index = get_expand_index(
self.lmax, self.mmax, vector_only=True, m_prime=self.m_prime
)

gating_scalars = nn.sigmoid(gating_scalars)
gating_scalars = gating_scalars.reshape(
gating_scalars.shape[0], self.lmax, self.num_channels
)[:, expand_index]

input_tensors_scalars = nn.silu(input_tensors[:, 0:1])
input_tensors_vectors = input_tensors[:, 1:] * gating_scalars
output_tensors = jnp.concat(
(input_tensors_scalars, input_tensors_vectors), axis=1
)

return output_tensors


class S2Activation(nn.Module):
"""Apply silu on sphere function."""
lmax: int
mmax: int
resolution: int
m_prime: bool = False

@nn.compact
def __call__(self, inputs: jax.Array) -> jax.Array:
so3_grid = get_s2grid_mats(
self.lmax, self.mmax, resolution=self.resolution, m_prime=self.m_prime
)

x_grid = so3_grid.to_grid(inputs)
x_grid = nn.silu(x_grid)
outputs = so3_grid.from_grid(x_grid)
return outputs


class SeparableS2Activation(nn.Module):
"""Apply silu on sphere function for vector and silu directly for scalar."""
lmax: int
mmax: int
resolution: int
m_prime: bool = False

@nn.compact
def __call__(self, input_scalars: jax.Array, input_tensors: jax.Array) -> jax.Array:
output_scalars = nn.silu(input_scalars)
output_tensors = S2Activation(
self.lmax, self.mmax, self.resolution, self.m_prime
)(input_tensors)
outputs = jnp.concat(
(
output_scalars[:, None],
output_tensors[:, 1 : output_tensors.shape[1]],
),
axis=1,
)
return outputs
Loading
Loading