From 5e1aca2b7976168e3918f326cb12d3f6c3008550 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 10 Apr 2026 15:06:33 +0200 Subject: [PATCH 01/14] Add SKILL.md files for cuequivariance and cuequivariance_jax Reference documentation for AI coding assistants covering: - cuequivariance: custom group definition (Irrep subclass with @dataclass), segmented tensor products, CG coefficients, descriptors, EquivariantPolynomial - cuequivariance_jax: segmented_polynomial primitive (naive/uniform_1d), RepArray and ir_dict interfaces, NNX layers, indexing, multi-batch axes Co-Authored-By: Claude Opus 4.6 (1M context) --- cuequivariance/SKILL.md | 334 +++++++++++++++++++++++++++ cuequivariance_jax/SKILL.md | 440 ++++++++++++++++++++++++++++++++++++ 2 files changed, 774 insertions(+) create mode 100644 cuequivariance/SKILL.md create mode 100644 cuequivariance_jax/SKILL.md diff --git a/cuequivariance/SKILL.md b/cuequivariance/SKILL.md new file mode 100644 index 0000000..a6325f7 --- /dev/null +++ b/cuequivariance/SKILL.md @@ -0,0 +1,334 @@ +--- +name: cuequivariance +description: Define custom groups (Irrep subclasses), build segmented tensor products with CG coefficients, create equivariant polynomials, and use built-in descriptors (linear, tensor products, spherical harmonics). Use when working with cuequivariance group theory, irreps, or segmented polynomials. +--- + +# cuequivariance: Groups, Irreps, and Segmented Polynomials + +## Overview + +`cuequivariance` (imported as `cue`) provides two core abstractions: + +1. **Group theory**: `Irrep` subclasses define irreducible representations of Lie groups (SO3, O3, SU2, or custom). `Irreps` manages collections with multiplicities. +2. **Segmented polynomials**: `SegmentedTensorProduct` describes tensor contractions over segments of varying shape, linked by `Path` objects carrying Clebsch-Gordan coefficients. `SegmentedPolynomial` wraps multiple STPs into a polynomial with named inputs/outputs. `EquivariantPolynomial` attaches group representations to each operand. + +## Defining a custom group + +Subclass `cue.Irrep` (a frozen dataclass) and implement: + +```python +from __future__ import annotations +import dataclasses +import re +from typing import Iterator +import numpy as np +import cuequivariance as cue + +@dataclasses.dataclass(frozen=True) +class Z2(cue.Irrep): + odd: bool # dataclass field -- required for correct __eq__ and __hash__ + + # No __init__ needed -- @dataclass(frozen=True) generates it: Z2(odd=True) + + @classmethod + def regexp_pattern(cls) -> re.Pattern: + # Pattern whose first group is passed to from_string + return re.compile(r"(odd|even)") + + @classmethod + def from_string(cls, string: str) -> Z2: + return cls(odd=string == "odd") + + def __repr__(rep: Z2) -> str: + return "odd" if rep.odd else "even" + + def __mul__(rep1: Z2, rep2: Z2) -> Iterator[Z2]: + # Selection rule: which irreps appear in the tensor product rep1 x rep2 + return [Z2(odd=rep1.odd ^ rep2.odd)] + + @classmethod + def clebsch_gordan(cls, rep1: Z2, rep2: Z2, rep3: Z2) -> np.ndarray: + # Shape: (num_paths, rep1.dim, rep2.dim, rep3.dim) + if rep3 in rep1 * rep2: + return np.array([[[[1]]]]) + else: + return np.zeros((0, 1, 1, 1)) + + @property + def dim(rep: Z2) -> int: + return 1 + + def __lt__(rep1: Z2, rep2: Z2) -> bool: + # Ordering for sorting; dimension is compared first by the base class + return rep1.odd < rep2.odd + + @classmethod + def iterator(cls) -> Iterator[Z2]: + # Must yield trivial irrep first + for odd in [False, True]: + yield Z2(odd=odd) + + def discrete_generators(rep: Z2) -> np.ndarray: + # Shape: (num_generators, dim, dim) + if rep.odd: + return -np.ones((1, 1, 1)) + else: + return np.ones((1, 1, 1)) + + def continuous_generators(rep: Z2) -> np.ndarray: + # Shape: (lie_dim, dim, dim) -- Z2 is discrete, so lie_dim=0 + return np.zeros((0, rep.dim, rep.dim)) + + def algebra(self) -> np.ndarray: + # Shape: (lie_dim, lie_dim, lie_dim) -- structure constants [X_i, X_j] = A_ijk X_k + return np.zeros((0, 0, 0)) + + +# Usage: +irreps = cue.Irreps(Z2, "3x odd + 2x even") # dim=5 +``` + +### Required methods summary + +| Method | Returns | Purpose | +|--------|---------|---------| +| `regexp_pattern()` | `re.Pattern` | Parse string like `"1"`, `"0e"`, `"odd"` | +| `from_string(s)` | `Irrep` | Construct from matched string | +| `__repr__` | `str` | Canonical string form | +| `__mul__(a, b)` | `Iterator[Irrep]` | Selection rule for tensor product | +| `clebsch_gordan(a, b, c)` | `ndarray (n, d1, d2, d3)` | CG coefficients | +| `dim` (property) | `int` | Dimension of representation | +| `__lt__(a, b)` | `bool` | Ordering (dimension first, then custom) | +| `iterator()` | `Iterator[Irrep]` | All irreps, trivial first | +| `continuous_generators()` | `ndarray (lie_dim, dim, dim)` | Lie algebra generators | +| `discrete_generators()` | `ndarray (n, dim, dim)` | Finite symmetry generators | +| `algebra()` | `ndarray (lie_dim, lie_dim, lie_dim)` | Structure constants | + +### Built-in groups + +- **`cue.SO3(l)`**: 3D rotations. `l` is a non-negative integer. `dim = 2l+1`. String: `"0"`, `"1"`, `"2"`. +- **`cue.O3(l, p)`**: 3D rotations + parity. `p=+1` (even) or `p=-1` (odd). String: `"0e"`, `"1o"`, `"2e"`. +- **`cue.SU2(j)`**: Spin group. `j` is a non-negative half-integer. String: `"0"`, `"1/2"`, `"1"`. + +## Irreps and layout + +```python +irreps = cue.Irreps("SO3", "16x0 + 4x1 + 2x2") # 16 scalars, 4 vectors, 2 rank-2 +irreps.dim # 16*1 + 4*3 + 2*5 = 38 + +for mul, ir in irreps: + print(mul, ir, ir.dim) # 16 0 1, then 4 1 3, then 2 2 5 +``` + +`IrrepsLayout` controls memory order within each `(mul, ir)` block: + +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used by all descriptors** +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **used by nnx dict[Irrep, Array]** + +`IrrepsAndLayout` combines irreps with a layout into a `Rep`: + +```python +rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) +rep.dim # 10 +``` + +## Building a SegmentedTensorProduct from scratch + +The subscripts string uses Einstein notation. Operands are comma-separated, coefficient modes follow `+`. + +```python +# Matrix-vector multiply: y_i = sum_j M_ij * x_j +d = cue.SegmentedTensorProduct.from_subscripts("ij,j,i") +d.add_segment(0, (3, 4)) # operand 0: matrix segment of shape (3, 4) +d.add_segment(1, (4,)) # operand 1: vector of size 4 +d.add_segment(2, (3,)) # operand 2: output vector of size 3 +d.add_path(0, 0, 0, c=1.0) # link segments 0,0,0 with coefficient=1.0 + +poly = cue.SegmentedPolynomial.eval_last_operand(d) # last operand becomes output +[y] = poly(M_flat, x) # numpy evaluation +``` + +### Multi-segment STP (how descriptors work internally) + +Descriptors build STPs with multiple segments per operand. Each segment corresponds to an irrep block: + +```python +# Linear equivariant map: output[iv] = sum_u weight[uv] * input[iu] +d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,iv") + +# Segment for l=1: ir_dim=3, mul_in=2, mul_out=5 +s_in_0 = d.add_segment(1, (3, 2)) # input block +s_out_0 = d.add_segment(2, (3, 5)) # output block +d.add_path((2, 5), s_in_0, s_out_0, c=1.0) + +# Segment for l=0: ir_dim=1, mul_in=4, mul_out=3 +s_in_1 = d.add_segment(1, (1, 4)) +s_out_1 = d.add_segment(2, (1, 3)) +d.add_path((4, 3), s_in_1, s_out_1, c=1.0) +``` + +### Weights operand + +For weighted tensor products (subscript starting with `uvw` or `uv`), the first operand is always weights. The weight segment shape is `(mul_1, mul_2, ...)` matching the multiplicity modes. The weights operand gets `new_scalars()` irreps since weights are invariant. + +### CG coefficients as path coefficients + +```python +d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") +# For each pair of input irreps and each output irrep in the selection rule: +for cg in cue.clebsch_gordan(ir1, ir2, ir3): + # cg has shape (ir1.dim, ir2.dim, ir3.dim) + d.add_path((mul1, mul2, mul3), seg_in1, seg_in2, seg_out, c=cg) +``` + +## Using descriptors (high-level API) + +All descriptors return `cue.EquivariantPolynomial`: + +```python +# Fully connected tensor product (all input-output irrep combinations) +e = cue.descriptors.fully_connected_tensor_product( + 16 * cue.Irreps("SO3", "0 + 1 + 2"), + 16 * cue.Irreps("SO3", "0 + 1 + 2"), + 16 * cue.Irreps("SO3", "0 + 1 + 2"), +) + +# Channelwise tensor product (same-channel only, sparse) +e = cue.descriptors.channelwise_tensor_product( + 64 * cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), simplify_irreps3=True, +) + +# Elementwise tensor product (paired channels) +e = cue.descriptors.elementwise_tensor_product( + cue.Irreps("SO3", "4x0 + 4x1"), cue.Irreps("SO3", "4x0 + 4x1"), +) + +# Linear equivariant map (no second input, just weight x input) +e = cue.descriptors.linear( + cue.Irreps("SO3", "4x0 + 2x1"), + cue.Irreps("SO3", "3x0 + 5x1"), +) + +# Spherical harmonics +e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) + +# Symmetric contraction (MACE-style) +e = cue.descriptors.symmetric_contraction( + 64 * cue.Irreps("SO3", "0 + 1 + 2"), + 64 * cue.Irreps("SO3", "0 + 1"), + [0, 1, 2, 3], +) +``` + +### EquivariantPolynomial key methods + +```python +e.inputs # tuple of Rep (group representations for each input) +e.outputs # tuple of Rep +e.polynomial # the underlying SegmentedPolynomial + +# Numpy evaluation +[out] = e(weights, input1, input2) + +# Preparing for uniform_1d execution (see cuequivariance_jax SKILL.md) +e_ready = e.squeeze_modes().flatten_coefficient_modes() + +# Split an operand into per-irrep pieces (for ir_dict interface) +e_split = e.split_operand_by_irrep(1).split_operand_by_irrep(-1) + +# Scale all coefficients +e_scaled = e * 0.5 + +# Fuse compatible STPs +e_fused = e.fuse_stps() +``` + +### normalize_paths_for_operand + +Called internally by descriptors. Normalizes path coefficients so that a random input produces unit-variance output for the specified operand. Critical for numerical stability. + +## SegmentedPolynomial structure + +```python +poly = e.polynomial +poly.num_inputs # number of input operands +poly.num_outputs # number of output operands +poly.inputs # tuple of SegmentedOperand +poly.outputs # tuple of SegmentedOperand +poly.operations # tuple of (Operation, SegmentedTensorProduct) + +# Each operation maps buffers to STP operands +for op, stp in poly.operations: + print(op.buffers) # e.g., (0, 1, 2) means inputs[0], inputs[1] -> outputs[0] + print(stp.subscripts) +``` + +### SegmentedOperand + +```python +operand = poly.inputs[0] +operand.num_segments # how many segments +operand.segments # tuple of shape tuples, e.g., ((3, 4), (1, 2)) +operand.size # total flattened size (sum of products of segment shapes) +operand.ndim # number of dimensions per segment +operand.all_same_segment_shape() # True if all segments have identical shape +operand.segment_shape # the common shape (only if all_same_segment_shape) +``` + +## Custom equivariant polynomial from scratch + +```python +import numpy as np +import cuequivariance as cue + +# Build a fully-connected SO3(1)xSO3(1)->SO3(0) tensor product manually +cg = cue.clebsch_gordan(cue.SO3(1), cue.SO3(1), cue.SO3(0)) # shape (1, 3, 3, 1) + +d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") +d.add_segment(1, (3, 4)) # input1: 4x SO3(1), shape=(ir_dim, mul) +d.add_segment(2, (3, 4)) # input2: 4x SO3(1) +d.add_segment(3, (1, 16)) # output: 16x SO3(0) (4*4 fully connected) + +for c in cg: + d.add_path((4, 4, 16), 0, 0, 0, c=c) + +d = d.normalize_paths_for_operand(-1) + +poly = cue.SegmentedPolynomial.eval_last_operand(d) +ep = cue.EquivariantPolynomial( + [ + cue.IrrepsAndLayout(cue.Irreps("SO3", "4x1").new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(cue.Irreps("SO3", "4x1"), cue.ir_mul), + cue.IrrepsAndLayout(cue.Irreps("SO3", "4x1"), cue.ir_mul), + ], + [cue.IrrepsAndLayout(cue.Irreps("SO3", "16x0"), cue.ir_mul)], + poly, +) + +# Numpy evaluation +w = np.random.randn(ep.inputs[0].dim) +x = np.random.randn(ep.inputs[1].dim) +y = np.random.randn(ep.inputs[2].dim) +[out] = ep(w, x, y) +``` + +## Key file locations + +| Component | Path | +|-----------|------| +| `Irrep` base class | `cuequivariance/group_theory/representations/irrep.py` | +| `Rep` base class | `cuequivariance/group_theory/representations/rep.py` | +| `SO3` | `cuequivariance/group_theory/representations/irrep_so3.py` | +| `O3` | `cuequivariance/group_theory/representations/irrep_o3.py` | +| `SU2` | `cuequivariance/group_theory/representations/irrep_su2.py` | +| `Irreps` | `cuequivariance/group_theory/irreps_array/irreps.py` | +| `IrrepsLayout` | `cuequivariance/group_theory/irreps_array/irreps_layout.py` | +| `IrrepsAndLayout` | `cuequivariance/group_theory/irreps_array/irreps_and_layout.py` | +| `SegmentedTensorProduct` | `cuequivariance/segmented_polynomials/segmented_tensor_product.py` | +| `SegmentedPolynomial` | `cuequivariance/segmented_polynomials/segmented_polynomial.py` | +| `EquivariantPolynomial` | `cuequivariance/group_theory/equivariant_polynomial.py` | +| Descriptors | `cuequivariance/group_theory/descriptors/` | +| `fully_connected_tensor_product` etc. | `cuequivariance/group_theory/descriptors/irreps_tp.py` | +| `spherical_harmonics` | `cuequivariance/group_theory/descriptors/spherical_harmonics_.py` | +| `symmetric_contraction` | `cuequivariance/group_theory/descriptors/symmetric_contractions.py` | diff --git a/cuequivariance_jax/SKILL.md b/cuequivariance_jax/SKILL.md new file mode 100644 index 0000000..e0a5865 --- /dev/null +++ b/cuequivariance_jax/SKILL.md @@ -0,0 +1,440 @@ +--- +name: cuequivariance-jax +description: Execute equivariant polynomials in JAX using segmented_polynomial (naive/uniform_1d), equivariant_polynomial with RepArray, ir_dict with dict[Irrep, Array], and Flax NNX layers (IrrepsLinear, SphericalHarmonics). Use when writing JAX code with cuequivariance. +--- + +# cuequivariance_jax: Executing Equivariant Polynomials in JAX + +## Overview + +`cuequivariance_jax` (imported as `cuex`) executes `cuequivariance` polynomials on GPU via JAX. It provides: + +1. **Core primitive**: `cuex.segmented_polynomial()` -- JAX primitive with full AD/vmap/JIT support +2. **Two data representations** (both built on `segmented_polynomial`): + - `cuex.equivariant_polynomial()` + `RepArray` -- the original interface, a single contiguous array with representation metadata + - `cuex.ir_dict` module -- `dict[Irrep, Array]` interface, conceptually simpler, works naturally with `jax.tree` operations +3. **NNX layers**: `cuex.nnx` module -- Flax NNX `Module` wrappers using `dict[Irrep, Array]` + +## Execution methods + +| Method | Backend | Requirements | +|--------|---------|-------------| +| `"naive"` | Pure JAX | Always works, any platform | +| `"uniform_1d"` | CUDA kernel | GPU, all segments uniform shape within each operand, single mode | +| `"indexed_linear"` | CUDA kernel | GPU, linear operations with `cuex.Repeats` indexing | + +## Core primitive: segmented_polynomial + +```python +import jax +import jax.numpy as jnp +import cuequivariance as cue +import cuequivariance_jax as cuex + +# Build a descriptor +e = cue.descriptors.channelwise_tensor_product( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) +poly = e.polynomial + +batch = 64 +w = jnp.ones((poly.inputs[0].size,)) # weights (shared across batch) +x = jax.random.normal(key, (batch, poly.inputs[1].size)) # batched input 1 +y = jax.random.normal(key, (batch, poly.inputs[2].size)) # batched input 2 + +# Execute with naive method +[out] = cuex.segmented_polynomial( + poly, + [w, x, y], # inputs + [jax.ShapeDtypeStruct((batch, poly.outputs[0].size), jnp.float32)], # output spec + method="naive", +) + +# Execute with uniform_1d (GPU, requires uniform segments) +[out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((batch, poly.outputs[0].size), jnp.float32)], + method="uniform_1d", +) +``` + +### Multiple batch axes with broadcasting + +Inputs can have any number of batch axes (everything before the last axis). Standard NumPy broadcasting applies: each batch axis is either size-1 or a common size. Inputs with fewer batch dimensions are implicitly prepended with size-1 axes: + +```python +# 2 batch axes with size-1 broadcasting +w = jnp.ones((1, 10, poly.inputs[0].size)) # shared across axis 0 +x = jnp.ones((5, 10, poly.inputs[1].size)) # 5 along axis 0 +y = jnp.ones((5, 1, poly.inputs[2].size)) # shared across axis 1 + +[out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((5, 10, poly.outputs[0].size), jnp.float32)], + method="uniform_1d", +) +# out.shape == (5, 10, ...) + +# Fewer batch dims: weights with no batch axis broadcast across all +w = jnp.ones((poly.inputs[0].size,)) # 0 batch axes -> prepended as (1, 1, ...) +x = jnp.ones((5, 10, poly.inputs[1].size)) +y = jnp.ones((5, 10, poly.inputs[2].size)) + +[out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((5, 10, poly.outputs[0].size), jnp.float32)], + method="uniform_1d", +) +``` + +### Indexing (gather/scatter) + +Index arrays provide gather (for inputs) and scatter (for outputs). One index per operand (inputs + outputs), `None` means no indexing. Index arrays decouple input/output batch shapes -- the output shape is determined by the index ranges, not by the input shapes: + +```python +a = jnp.ones((1, 50, poly.inputs[0].size)) +b = jnp.ones((10, 50, poly.inputs[1].size)) +c = jnp.ones((100, 1, poly.inputs[2].size)) + +i = jax.random.randint(key, (100, 50), 0, 10) # gather b along axis 0 +j1 = jax.random.randint(key, (100, 50), 0, 11) # scatter output axis 0 +j2 = jax.random.randint(key, (100, 1), 0, 12) # scatter output axis 1 + +[out] = cuex.segmented_polynomial( + poly, [a, b, c], + [jax.ShapeDtypeStruct((11, 12, poly.outputs[0].size), jnp.float32)], + indices=[None, np.s_[i, :], None, np.s_[j1, j2]], + method="uniform_1d", +) +# out.shape == (11, 12, ...) -- determined by index ranges, not input shapes +``` + +### Gradients + +Fully differentiable -- supports `jax.grad`, `jax.jacobian`, `jax.jvp`, `jax.vmap`: + +```python +def loss(w, x, y): + [out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((batch, poly.outputs[0].size), jnp.float32)], + method="naive", + ) + return jnp.sum(out ** 2) + +grad_w = jax.grad(loss, 0)(w, x, y) +``` + +## RepArray interface: equivariant_polynomial + +The original interface. Wraps `segmented_polynomial` with `RepArray` -- a single contiguous array with representation metadata: + +```python +e = cue.descriptors.fully_connected_tensor_product( + 4 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + 4 * cue.Irreps("SO3", "0 + 1"), +) + +inputs = [ + cuex.randn(jax.random.key(i), rep, (batch,), jnp.float32) + for i, rep in enumerate(e.inputs) +] + +# Returns a RepArray with representation metadata +out = cuex.equivariant_polynomial(e, inputs, method="naive") +out.array # the raw jax.Array +out.reps # dict mapping axes to Rep objects +``` + +## ir_dict interface + +An alternative to `RepArray`. Uses `dict[Irrep, Array]` where each value has shape `(..., multiplicity, irrep_dim)`. Conceptually simpler: works naturally with `jax.tree` operations and is the standard representation for NNX layers. + +### Preparing a polynomial for ir_dict + +Descriptors produce `EquivariantPolynomial` with dense operands. To use `ir_dict`, split operands by irrep: + +```python +e = cue.descriptors.channelwise_tensor_product( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + simplify_irreps3=True, +) + +# Split irreps-typed operands into per-irrep pieces +# Order matters: split inner operands first to preserve operand indices +poly = ( + e.split_operand_by_irrep(2) # split input 2 + .split_operand_by_irrep(1) # split input 1 + .split_operand_by_irrep(-1) # split output + .polynomial +) +# After split: all operands have uniform segment shapes (required for uniform_1d) +``` + +### Executing with segmented_polynomial_uniform_1d + +```python +from einops import rearrange + +num_edges, num_nodes = 100, 30 + +# Weights: reshape to (batch, num_segments, segment_size) +w_flat = jax.random.normal(key, (num_edges, poly.inputs[0].size)) +w = rearrange(w_flat, "e (s m) -> e s m", s=poly.inputs[0].num_segments) + +# Node features: dict[Irrep, Array] reshaped to (nodes, ir.dim, mul) for ir_mul layout +node_feats = { + cue.SO3(0): jnp.ones((num_nodes, 32, 1)), # 32x scalar + cue.SO3(1): jnp.ones((num_nodes, 32, 3)), # 32x vector +} +x1 = jax.tree.map(lambda v: rearrange(v, "n m i -> n i m"), node_feats) + +# Spherical harmonics: (edges, ir.dim) -- no multiplicity dimension +sph = { + cue.SO3(0): jnp.ones((num_edges, 1)), + cue.SO3(1): jnp.ones((num_edges, 3)), +} + +# Build output template +senders = jax.random.randint(key, (num_edges,), 0, num_nodes) +receivers = jax.random.randint(key, (num_edges,), 0, num_nodes) +irreps_out = e.outputs[0].irreps +out_template = { + ir: jax.ShapeDtypeStruct( + (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(irreps_out, poly.outputs) +} + +# Execute with gather (senders) and scatter (receivers) +y = cuex.ir_dict.segmented_polynomial_uniform_1d( + poly, + [w, x1, sph], + out_template, + input_indices=[None, senders, None], + output_indices=receivers, + name="tensor_product", +) +# y is dict[Irrep, Array] with accumulated results at receiver nodes +``` + +### ir_dict utility functions + +```python +# Validate dict matches irreps +cuex.ir_dict.assert_mul_ir_dict(irreps, x) # asserts shape (..., mul, ir.dim) + +# Convert flat array <-> dict +d = cuex.ir_dict.flat_to_dict(irreps, flat_array) # layout="mul_ir" default +d = cuex.ir_dict.flat_to_dict(irreps, flat_array, layout="ir_mul") +flat = cuex.ir_dict.dict_to_flat(irreps, d) + +# Arithmetic +z = cuex.ir_dict.irreps_add(x, y) +z = cuex.ir_dict.irreps_zeros_like(x) + +# Create template dict +template = cuex.ir_dict.mul_ir_dict(irreps, jax.ShapeDtypeStruct(shape, dtype)) +``` + +## NNX layers + +### IrrepsLinear + +Equivariant linear layer using `dict[Irrep, Array]`: + +```python +from flax import nnx + +linear = cuex.nnx.IrrepsLinear( + irreps_in=cue.Irreps(cue.SO3, "4x0 + 2x1").regroup(), # must be regrouped + irreps_out=cue.Irreps(cue.SO3, "3x0 + 5x1").regroup(), + scale=1.0, + dtype=jnp.float32, + rngs=nnx.Rngs(0), +) + +# Input/output: dict[Irrep, Array] with shape (batch, mul, ir.dim) +x = { + cue.SO3(0): jnp.ones((batch, 4, 1)), + cue.SO3(1): jnp.ones((batch, 2, 3)), +} +y = linear(x) +# y[cue.SO3(0)].shape == (batch, 3, 1) +# y[cue.SO3(1)].shape == (batch, 5, 3) +``` + +Implementation uses `jnp.einsum("uv,...ui->...vi", w, x[ir])` per irrep with `1/sqrt(mul_in)` normalization. + +### SphericalHarmonics + +```python +sh = cuex.nnx.SphericalHarmonics(max_degree=3, eps=0.0) + +vectors = jax.random.normal(key, (batch, 3)) # 3D vectors +y = sh(vectors) +# y[cue.O3(0, 1)].shape == (batch, 1, 1) # L=0 +# y[cue.O3(1, -1)].shape == (batch, 1, 3) # L=1 +# y[cue.O3(2, 1)].shape == (batch, 1, 5) # L=2 +# y[cue.O3(3, -1)].shape == (batch, 1, 7) # L=3 +``` + +### IrrepsNormalize + +```python +norm = cuex.nnx.IrrepsNormalize(eps=1e-6, scale=1.0, skip_scalars=True) +y = norm(x) # normalizes non-scalar irreps by RMS over ir.dim, averaged over mul +``` + +### MLP (scalar only) + +```python +mlp = cuex.nnx.MLP( + layer_sizes=[64, 128, 64], + activation=jax.nn.silu, + output_activation=False, + dtype=jnp.float32, + rngs=nnx.Rngs(0), +) +y = mlp(x_scalar) # standard dense MLP with 1/sqrt(fan_in) normalization +``` + +### IrrepsIndexedLinear + +For species-indexed linear layers (different weights per atom type): + +```python +indexed_linear = cuex.nnx.IrrepsIndexedLinear( + irreps_in=cue.Irreps(cue.O3, "8x0e").regroup(), + irreps_out=cue.Irreps(cue.O3, "16x0e").regroup(), + num_indices=50, # number of species + scale=1.0, + dtype=jnp.float32, + rngs=nnx.Rngs(0), +) + +# num_index_counts: how many atoms of each species +species_counts = jnp.array([3, 4, 3, ...]) # sum = batch_size +y = indexed_linear(x, species_counts) +``` + +Uses `method="indexed_linear"` internally with `cuex.Repeats`. + +## Preparing polynomials for uniform_1d + +The `uniform_1d` CUDA kernel requires: +1. All segments within each operand have **the same shape** +2. A **single mode** in the subscripts (after preprocessing) + +### From EquivariantPolynomial to uniform_1d-ready + +For `equivariant_polynomial()` (RepArray interface): + +```python +e = cue.descriptors.channelwise_tensor_product(...) +e = e.squeeze_modes().flatten_coefficient_modes() +# If still >1 mode: e = e.flatten_modes(["u", "w"]) +out = cuex.equivariant_polynomial(e, inputs, method="uniform_1d") +``` + +For `ir_dict` (dict[Irrep, Array] interface): + +```python +e = cue.descriptors.channelwise_tensor_product(..., simplify_irreps3=True) +poly = ( + e.split_operand_by_irrep(2) # split input 2 + .split_operand_by_irrep(1) # split input 1 + .split_operand_by_irrep(-1) # split output + .polynomial +) +# Now each operand has uniform segments -> use ir_dict.segmented_polynomial_uniform_1d +``` + +### Why split_operand_by_irrep works + +Before splitting, a polynomial has dense operands like `32x0+32x1` with segments `((1,32), (3,32))` -- non-uniform shapes. After `split_operand_by_irrep`, each piece has a single irrep type, so all segments within that operand are the same shape. The pytree structure of `dict[Irrep, Array]` maps naturally to the split operands. + +## RepArray + +Representation-aware JAX array: + +```python +rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) +x = cuex.RepArray(rep, jnp.ones((batch, rep.dim))) +x = cuex.randn(jax.random.key(0), rep, (batch,), jnp.float32) + +x.array # raw jax.Array +x.reps # {axis: Rep} +x.irreps # Irreps (if last axis is IrrepsAndLayout) +``` + +## Complete GNN message-passing example + +This pattern is used in NequIP, MACE, and similar equivariant GNN models: + +```python +class MessagePassing(nnx.Module): + def __init__(self, irreps_in, irreps_sh, irreps_out, epsilon, *, name, dtype, rngs): + e = ( + cue.descriptors.channelwise_tensor_product( + irreps_in, irreps_sh, irreps_out, True + ) + * epsilon + ) + self.weight_numel = e.inputs[0].dim + self.irreps_out = e.outputs[0].irreps + self.poly = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + def __call__(self, weights, node_feats, sph, senders, receivers, num_nodes): + # weights: (num_edges, weight_numel) + w = rearrange(weights, "e (s m) -> e s m", s=self.poly.inputs[0].num_segments) + # node_feats: dict[Irrep, Array] with (nodes, mul, ir.dim) + x1 = jax.tree.map(lambda v: rearrange(v, "n m i -> n i m"), node_feats) + # sph: dict[Irrep, Array] with (edges, 1, ir.dim) or (edges, ir.dim) + x2 = jax.tree.map(lambda v: rearrange(v, "e 1 i -> e i"), sph) + + out_template = { + ir: jax.ShapeDtypeStruct( + (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(self.irreps_out, self.poly.outputs) + } + + y = cuex.ir_dict.segmented_polynomial_uniform_1d( + self.poly, [w, x1, x2], out_template, + input_indices=[None, senders, None], + output_indices=receivers, + name="tensor_product", + ) + # Rearrange output back to (nodes, mul, ir.dim) for downstream layers + return { + ir: rearrange(v, "n (i s) m -> n (s m) i", i=ir.dim) + for ir, v in y.items() + } +``` + +## Key file locations + +| Component | Path | +|-----------|------| +| `segmented_polynomial` primitive | `cuequivariance_jax/segmented_polynomials/segmented_polynomial.py` | +| `uniform_1d` backend | `cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py` | +| `naive` backend | `cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py` | +| `indexed_linear` backend | `cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py` | +| `equivariant_polynomial` | `cuequivariance_jax/equivariant_polynomial.py` | +| `ir_dict` module | `cuequivariance_jax/ir_dict.py` | +| `nnx` module | `cuequivariance_jax/nnx.py` | +| `RepArray` | `cuequivariance_jax/rep_array/rep_array_.py` | +| `Repeats` / utilities | `cuequivariance_jax/segmented_polynomials/utils.py` | +| NequIP example | `cuequivariance_jax/examples/nequip_nnx.py` | +| MACE example | `cuequivariance_jax/examples/mace_nnx.py` | From 04f4d462326492fdd9f08c0cb8114c1007f4c16d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 10 Apr 2026 15:21:54 +0200 Subject: [PATCH 02/14] Move SKILL.md into packages and add `python -m` CLI to print them Move SKILL.md files from repo root into the Python package directories so they ship with pip install. Add __main__.py so users can retrieve the skill content with: python -m cuequivariance skill python -m cuequivariance_jax skill Co-Authored-By: Claude Opus 4.6 (1M context) --- cuequivariance/{ => cuequivariance}/SKILL.md | 0 cuequivariance/cuequivariance/__main__.py | 29 +++++++++++++++++++ .../{ => cuequivariance_jax}/SKILL.md | 0 .../cuequivariance_jax/__main__.py | 29 +++++++++++++++++++ 4 files changed, 58 insertions(+) rename cuequivariance/{ => cuequivariance}/SKILL.md (100%) create mode 100644 cuequivariance/cuequivariance/__main__.py rename cuequivariance_jax/{ => cuequivariance_jax}/SKILL.md (100%) create mode 100644 cuequivariance_jax/cuequivariance_jax/__main__.py diff --git a/cuequivariance/SKILL.md b/cuequivariance/cuequivariance/SKILL.md similarity index 100% rename from cuequivariance/SKILL.md rename to cuequivariance/cuequivariance/SKILL.md diff --git a/cuequivariance/cuequivariance/__main__.py b/cuequivariance/cuequivariance/__main__.py new file mode 100644 index 0000000..83e8da6 --- /dev/null +++ b/cuequivariance/cuequivariance/__main__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from pathlib import Path + + +def main(): + if len(sys.argv) >= 2 and sys.argv[1] == "skill": + skill_path = Path(__file__).parent / "SKILL.md" + print(skill_path.read_text()) + else: + print("Usage: python -m cuequivariance skill") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/cuequivariance_jax/SKILL.md b/cuequivariance_jax/cuequivariance_jax/SKILL.md similarity index 100% rename from cuequivariance_jax/SKILL.md rename to cuequivariance_jax/cuequivariance_jax/SKILL.md diff --git a/cuequivariance_jax/cuequivariance_jax/__main__.py b/cuequivariance_jax/cuequivariance_jax/__main__.py new file mode 100644 index 0000000..b458579 --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/__main__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from pathlib import Path + + +def main(): + if len(sys.argv) >= 2 and sys.argv[1] == "skill": + skill_path = Path(__file__).parent / "SKILL.md" + print(skill_path.read_text()) + else: + print("Usage: python -m cuequivariance_jax skill") + sys.exit(1) + + +if __name__ == "__main__": + main() From 80981425998341cad66d02e83cfb9cf32b9ef512 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 10 Apr 2026 15:34:13 +0200 Subject: [PATCH 03/14] Fix docs: output shape bug, add ir_dict tutorial, add missing API refs - Fix jax/poly.rst: output ShapeDtypeStruct was (3,3) for a flat size-9 operand, causing a warning. Use (-1,) to infer from descriptor. - Add jax/ir_dict.rst tutorial: split_operand_by_irrep, segmented_polynomial_uniform_1d, gather/scatter indexing, utilities. - Add ir_dict, nnx, and Repeats sections to cuequivariance_jax API docs. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../cuequivariance_jax/SKILL.md | 10 +- docs/api/cuequivariance_jax.rst | 41 ++++ docs/tutorials/jax/index.rst | 1 + docs/tutorials/jax/ir_dict.rst | 191 ++++++++++++++++++ docs/tutorials/jax/poly.rst | 3 +- 5 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 docs/tutorials/jax/ir_dict.rst diff --git a/cuequivariance_jax/cuequivariance_jax/SKILL.md b/cuequivariance_jax/cuequivariance_jax/SKILL.md index e0a5865..d61d2d7 100644 --- a/cuequivariance_jax/cuequivariance_jax/SKILL.md +++ b/cuequivariance_jax/cuequivariance_jax/SKILL.md @@ -173,7 +173,7 @@ poly = ( .split_operand_by_irrep(-1) # split output .polynomial ) -# After split: all operands have uniform segment shapes (required for uniform_1d) +# After split: each operand has a single irrep type, mapping naturally to dict[Irrep, Array] ``` ### Executing with segmented_polynomial_uniform_1d @@ -352,12 +352,14 @@ poly = ( .split_operand_by_irrep(-1) # split output .polynomial ) -# Now each operand has uniform segments -> use ir_dict.segmented_polynomial_uniform_1d +# Each operand has a single irrep type -> maps naturally to dict[Irrep, Array] ``` -### Why split_operand_by_irrep works +### Why split_operand_by_irrep matters -Before splitting, a polynomial has dense operands like `32x0+32x1` with segments `((1,32), (3,32))` -- non-uniform shapes. After `split_operand_by_irrep`, each piece has a single irrep type, so all segments within that operand are the same shape. The pytree structure of `dict[Irrep, Array]` maps naturally to the split operands. +Without splitting, a dense operand like `32x0+32x1` requires all irreps packed into a single contiguous buffer. After `split_operand_by_irrep`, each irrep gets its own separate buffer passed to the CUDA kernel via FFI. The buffers no longer need to be contiguous with each other. + +This is especially useful when the polynomial is preceded or followed by per-irrep linear layers (like `IrrepsLinear`). With split operands, no transpose or copy is needed between the linear layers and the polynomial — the `dict[Irrep, Array]` flows directly through the pipeline. ## RepArray diff --git a/docs/api/cuequivariance_jax.rst b/docs/api/cuequivariance_jax.rst index 567da6f..103589b 100644 --- a/docs/api/cuequivariance_jax.rst +++ b/docs/api/cuequivariance_jax.rst @@ -47,6 +47,38 @@ Tensor Products equivariant_polynomial segmented_polynomial +ir_dict +------- + +Utilities for working with ``dict[Irrep, Array]`` representation, an alternative to ``RepArray``. + +.. autosummary:: + :toctree: generated/ + :template: function_template.rst + + ir_dict.segmented_polynomial_uniform_1d + ir_dict.assert_mul_ir_dict + ir_dict.mul_ir_dict + ir_dict.flat_to_dict + ir_dict.dict_to_flat + ir_dict.irreps_add + ir_dict.irreps_zeros_like + +NNX Layers +---------- + +Flax NNX modules using ``dict[Irrep, Array]`` representation. + +.. autosummary:: + :toctree: generated/ + :template: class_template.rst + + nnx.IrrepsLinear + nnx.SphericalHarmonics + nnx.IrrepsNormalize + nnx.MLP + nnx.IrrepsIndexedLinear + Extra Modules ------------- @@ -62,6 +94,15 @@ Extra Modules spherical_harmonics +Utilities +--------- + +.. autosummary:: + :toctree: generated/ + :template: class_template.rst + + Repeats + Triangle -------- diff --git a/docs/tutorials/jax/index.rst b/docs/tutorials/jax/index.rst index 2c3c2bd..aa0cd74 100644 --- a/docs/tutorials/jax/index.rst +++ b/docs/tutorials/jax/index.rst @@ -20,4 +20,5 @@ JAX Examples :maxdepth: 1 poly + ir_dict diff --git a/docs/tutorials/jax/ir_dict.rst b/docs/tutorials/jax/ir_dict.rst new file mode 100644 index 0000000..0f1c131 --- /dev/null +++ b/docs/tutorials/jax/ir_dict.rst @@ -0,0 +1,191 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +The ``ir_dict`` Interface +========================= + +The :mod:`cuequivariance_jax.ir_dict` module provides an alternative to :class:`~cuequivariance_jax.RepArray` for working with equivariant data. Instead of a single contiguous array, features are stored as ``dict[Irrep, Array]`` where each value has shape ``(..., multiplicity, irrep_dim)``. + +This representation works naturally with ``jax.tree`` operations and is used by the :mod:`cuequivariance_jax.nnx` layers. + +From Descriptor to ``ir_dict`` +------------------------------ + +Descriptors produce :class:`~cuequivariance.EquivariantPolynomial` objects with dense operands (e.g., ``32x0+32x1``). A dense operand requires all irreps to be packed into a single contiguous buffer. By splitting each operand by irrep with :meth:`~cuequivariance.EquivariantPolynomial.split_operand_by_irrep`, each irrep gets its own separate buffer. This relaxes the memory layout constraint: the buffers for different irreps no longer need to be contiguous with each other. + +This is especially useful when the polynomial is preceded or followed by linear layers that act independently on each irrep (like :class:`~cuequivariance_jax.nnx.IrrepsLinear`). With split operands, there is no need for any transpose or copy between the linear layers and the polynomial — the ``dict[Irrep, Array]`` flows directly through the pipeline. + +.. jupyter-execute:: + + import jax + import jax.numpy as jnp + from einops import rearrange + import cuequivariance as cue + import cuequivariance_jax as cuex + + # Build a channelwise tensor product + e = cue.descriptors.channelwise_tensor_product( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + simplify_irreps3=True, + ) + print("Before split:") + print(e) + +.. jupyter-execute:: + + # Split operands by irrep + # Order: split inner operands first to preserve indices + e_split = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + ) + poly = e_split.polynomial + + print("After split:") + print(e_split) + print() + for i, op in enumerate(poly.inputs): + print(f" Input {i}: num_segments={op.num_segments}, uniform={op.all_same_segment_shape()}") + for i, op in enumerate(poly.outputs): + print(f" Output {i}: num_segments={op.num_segments}, uniform={op.all_same_segment_shape()}") + + +Executing with ``segmented_polynomial_uniform_1d`` +-------------------------------------------------- + +The :func:`~cuequivariance_jax.ir_dict.segmented_polynomial_uniform_1d` function handles the flattening/unflattening between the ``dict[Irrep, Array]`` pytree structure and the flat arrays that the kernel expects. + +Each input array has shape ``(..., num_segments, *segment_shape)``. For the weight operand, we reshape the flat weights into this form. For ``dict[Irrep, Array]`` operands, each value is one leaf of the pytree. + +.. jupyter-execute:: + + batch = 16 + + # Weights: reshape flat -> (batch, num_segments, segment_size) + w_flat = jax.random.normal(jax.random.key(0), (batch, poly.inputs[0].size)) + w = rearrange(w_flat, "b (s m) -> b s m", s=poly.inputs[0].num_segments) + print(f"Weights: {w.shape} (batch, num_segments, segment_size)") + + # Inputs as dict[Irrep, Array] + # Shape convention: (batch, ir.dim, mul) for ir_mul layout + node_feats = { + cue.SO3(0): jax.random.normal(jax.random.key(1), (batch, 32, 1)), + cue.SO3(1): jax.random.normal(jax.random.key(2), (batch, 32, 3)), + } + # Rearrange from (batch, mul, ir.dim) to (batch, ir.dim, mul) for ir_mul layout + x = jax.tree.map(lambda v: rearrange(v, "b m i -> b i m"), node_feats) + print(f"Input l=0: {x[cue.SO3(0)].shape} (batch, ir.dim, mul)") + print(f"Input l=1: {x[cue.SO3(1)].shape} (batch, ir.dim, mul)") + + # Second input (e.g. spherical harmonics): (batch, ir.dim) + sph = { + cue.SO3(0): jax.random.normal(jax.random.key(3), (batch, 1)), + cue.SO3(1): jax.random.normal(jax.random.key(4), (batch, 3)), + } + + # Build output template: one entry per split output + irreps_out = e.outputs[0].irreps + out_template = { + ir: jax.ShapeDtypeStruct( + (batch, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(irreps_out, poly.outputs) + } + print(f"Output template: { {str(k): v.shape for k, v in out_template.items()} }") + +.. jupyter-execute:: + + # Execute + y = cuex.ir_dict.segmented_polynomial_uniform_1d( + poly, [w, x, sph], out_template, + ) + + for ir, v in y.items(): + print(f" Output {ir}: {v.shape}") + + +Indexing (Gather/Scatter) +------------------------- + +In graph neural networks, features live on nodes and edges with different batch sizes. Index arrays handle the gather (for inputs) and scatter (for outputs): + +.. jupyter-execute:: + + num_edges, num_nodes = 100, 30 + + w = jax.random.normal(jax.random.key(0), (num_edges, poly.inputs[0].size)) + w = rearrange(w, "e (s m) -> e s m", s=poly.inputs[0].num_segments) + + node_feats = { + cue.SO3(0): jax.random.normal(jax.random.key(1), (num_nodes, 1, 32)), + cue.SO3(1): jax.random.normal(jax.random.key(2), (num_nodes, 3, 32)), + } + + sph = { + cue.SO3(0): jax.random.normal(jax.random.key(3), (num_edges, 1)), + cue.SO3(1): jax.random.normal(jax.random.key(4), (num_edges, 3)), + } + + senders = jax.random.randint(jax.random.key(5), (num_edges,), 0, num_nodes) + receivers = jax.random.randint(jax.random.key(6), (num_edges,), 0, num_nodes) + + out_template = { + ir: jax.ShapeDtypeStruct( + (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(irreps_out, poly.outputs) + } + + # Gather node features at senders, scatter results to receivers + y = cuex.ir_dict.segmented_polynomial_uniform_1d( + poly, + [w, node_feats, sph], + out_template, + input_indices=[None, senders, None], + output_indices=receivers, + ) + + for ir, v in y.items(): + print(f" Output {ir}: {v.shape}") + + +Utility Functions +----------------- + +The ``ir_dict`` module provides helpers for converting between flat arrays and ``dict[Irrep, Array]``: + +.. jupyter-execute:: + + irreps = cue.Irreps(cue.SO3, "4x0 + 2x1") + + # Flat array -> dict + flat = jnp.ones((8, irreps.dim)) + d = cuex.ir_dict.flat_to_dict(irreps, flat) + print(f"flat_to_dict: l=0 {d[cue.SO3(0)].shape}, l=1 {d[cue.SO3(1)].shape}") + + # Dict -> flat array + flat_back = cuex.ir_dict.dict_to_flat(irreps, d) + print(f"dict_to_flat: {flat_back.shape}") + + # Arithmetic + z = cuex.ir_dict.irreps_add(d, d) + print(f"irreps_add: l=0 sum={float(z[cue.SO3(0)].sum())}") + + # Validation + cuex.ir_dict.assert_mul_ir_dict(irreps, d) + print("assert_mul_ir_dict: passed") diff --git a/docs/tutorials/jax/poly.rst b/docs/tutorials/jax/poly.rst index 68043c2..19c040e 100644 --- a/docs/tutorials/jax/poly.rst +++ b/docs/tutorials/jax/poly.rst @@ -44,7 +44,8 @@ Basic Usage # 2. Define output shapes/dtypes # We need to tell JAX what the output will look like - output_structs = [jax.ShapeDtypeStruct((3, 3), jnp.float32)] + # Use -1 for the last dimension to infer from the polynomial descriptor + output_structs = [jax.ShapeDtypeStruct((-1,), jnp.float32)] # 3. Execute input_arr = jnp.ones((3,)) From 6f80ed83de1e40359cefb08a5cd9aa34c2262941 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 10 Apr 2026 17:04:05 +0200 Subject: [PATCH 04/14] ignore docs/_build/ --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index efa6755..5bdb1b0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__ docs/api/generated/ docs/public/ docs/jupyter_execute/ +docs/_build/ From 15eb7c21f0bdb277448fcff21fabff602455e8d9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 13 Apr 2026 10:01:01 +0200 Subject: [PATCH 05/14] Add module docstrings pointing AI agents to SKILL.md Makes the `python -m skill` command discoverable via `help()` or docstring inspection, which is typically the first thing an AI coding assistant does when exploring an unfamiliar library. Co-Authored-By: Claude Opus 4.6 (1M context) --- cuequivariance/cuequivariance/__init__.py | 4 ++++ cuequivariance_jax/cuequivariance_jax/__init__.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index dd92a58..19e217e 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""cuEquivariance - GPU-accelerated equivariant operations for 3D neural networks. + +For AI coding assistants: run `python -m cuequivariance skill` for detailed usage guidance. +""" import importlib.resources __version__ = ( diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index 913e273..7d9ae9d 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""cuEquivariance JAX - JAX backend for cuEquivariance equivariant operations. + +For AI coding assistants: run `python -m cuequivariance_jax skill` for detailed usage guidance. +""" import importlib.resources __version__ = ( From 88a4c1d9753a4190e3b6f8b685ae7943c92acb57 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 08:41:44 +0200 Subject: [PATCH 06/14] Add SKILL.md, __main__.py, and docstring for cuequivariance_torch Mirrors what was done for cuequivariance and cuequivariance_jax: - SKILL.md with full API reference for AI coding assistants - __main__.py CLI (`python -m cuequivariance_torch skill`) - Module docstring pointing to the CLI Co-Authored-By: Claude Opus 4.6 (1M context) --- .../cuequivariance_torch/SKILL.md | 388 ++++++++++++++++++ .../cuequivariance_torch/__init__.py | 4 + .../cuequivariance_torch/__main__.py | 29 ++ 3 files changed, 421 insertions(+) create mode 100644 cuequivariance_torch/cuequivariance_torch/SKILL.md create mode 100644 cuequivariance_torch/cuequivariance_torch/__main__.py diff --git a/cuequivariance_torch/cuequivariance_torch/SKILL.md b/cuequivariance_torch/cuequivariance_torch/SKILL.md new file mode 100644 index 0000000..002765a --- /dev/null +++ b/cuequivariance_torch/cuequivariance_torch/SKILL.md @@ -0,0 +1,388 @@ +--- +name: cuequivariance-torch +description: Execute equivariant tensor products in PyTorch using SegmentedPolynomial (naive/uniform_1d/fused_tp/indexed_linear), high-level operations (ChannelWiseTensorProduct, FullyConnectedTensorProduct, Linear, SymmetricContraction, SphericalHarmonics, Rotation), and layers (BatchNorm, FullyConnectedTensorProductConv). Use when writing PyTorch code with cuequivariance. +--- + +# cuequivariance_torch: Executing Equivariant Polynomials in PyTorch + +## Overview + +`cuequivariance_torch` (imported as `cuet`) executes `cuequivariance` polynomials on GPU via PyTorch. It provides: + +1. **Core primitive**: `cuet.SegmentedPolynomial` -- `torch.nn.Module` with multiple CUDA backends +2. **High-level operations** (`torch.nn.Module`): `ChannelWiseTensorProduct`, `FullyConnectedTensorProduct`, `Linear`, `SymmetricContraction`, `SphericalHarmonics`, `Rotation`, `Inversion` +3. **Layers**: `cuet.layers.BatchNorm`, `cuet.layers.FullyConnectedTensorProductConv` (message passing) +4. **Utilities**: `triangle_attention`, `triangle_multiplicative_update`, `attention_pair_bias` (AlphaFold2-style) +5. **Export support**: `onnx_custom_translation_table()`, `register_tensorrt_plugins()` + +## Execution methods + +| Method | Backend | Requirements | +|--------|---------|-------------| +| `"naive"` | Pure PyTorch (einsum) | Always works, any platform | +| `"uniform_1d"` | CUDA kernel | GPU, all segments uniform shape within each operand, single mode | +| `"fused_tp"` | CUDA kernel | GPU, 3- or 4-operand contractions, float32/float64 | +| `"indexed_linear"` | CUDA kernel | GPU, linear with indexed weights, sorted indices | + +## Core primitive: SegmentedPolynomial + +```python +import torch +import cuequivariance as cue +import cuequivariance_torch as cuet + +# Build a descriptor +e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) +poly = e.polynomial + +# Create the module +sp = cuet.SegmentedPolynomial(poly, method="uniform_1d") + +# Forward pass +x = torch.randn(batch, 3, device="cuda") +[output] = sp([x]) +# output.shape == (batch, 9) -- 1 + 3 + 5 +``` + +### Inputs, indexing, and scatter + +```python +e = cue.descriptors.channelwise_tensor_product( + 16 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) +poly = e.polynomial + +sp = cuet.SegmentedPolynomial(poly, method="uniform_1d") + +w = torch.randn(1, poly.inputs[0].size, device="cuda") # shared weights +x1 = torch.randn(batch, poly.inputs[1].size, device="cuda") # batched input 1 +x2 = torch.randn(batch, poly.inputs[2].size, device="cuda") # batched input 2 + +# Basic forward +[out] = sp([w, x1, x2]) + +# With input gathering (e.g., gather x1 by node index) +senders = torch.randint(0, num_nodes, (num_edges,), device="cuda") +[out] = sp([w, x1, x2], input_indices={1: senders}) + +# With output scattering (accumulate into target nodes) +receivers = torch.randint(0, num_nodes, (num_edges,), device="cuda") +[out] = sp( + [w, x1, x2], + input_indices={1: senders}, + output_indices={0: receivers}, + output_shapes={0: torch.empty(num_nodes, 1, device="cuda")}, +) +``` + +### Math dtype control + +```python +# Compute in float32 regardless of input dtype +sp = cuet.SegmentedPolynomial(poly, method="fused_tp", math_dtype=torch.float32) + +# For fused_tp, math_dtype must be float32 or float64 +# For naive, any torch.dtype works +# For uniform_1d, float32 or float64 (auto-selects float32 if input is e.g. float16) +``` + +## High-level operations + +All operations are `torch.nn.Module` subclasses. They wrap `SegmentedPolynomial` and handle layout transposition automatically. + +### Memory layout + +`IrrepsLayout` controls memory order within each `(mul, ir)` block: + +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **default, compatible with e3nn** +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used internally by descriptors** + +Operations accept `layout` (applies to all), or per-operand `layout_in1`, `layout_in2`, `layout_out`. + +### ChannelWiseTensorProduct + +Channel-wise tensor product: pairs channels of `x1` with channels of `x2`. + +```python +# With internal weights (default: shared_weights=True, internal_weights=True) +tp = cuet.ChannelWiseTensorProduct( + cue.Irreps("SO3", "32x0 + 32x1"), # irreps_in1 + cue.Irreps("SO3", "0 + 1"), # irreps_in2 + layout=cue.mul_ir, + device="cuda", + dtype=torch.float32, +) +# tp.weight_numel -- number of weight parameters +# tp.irreps_out -- output irreps (auto-computed) + +x1 = torch.randn(batch, tp.irreps_in1.dim, device="cuda") +x2 = torch.randn(batch, tp.irreps_in2.dim, device="cuda") + +out = tp(x1, x2) # uses internal weight parameter +# out.shape == (batch, tp.irreps_out.dim) + +# With external weights (shared_weights=False) +tp = cuet.ChannelWiseTensorProduct( + cue.Irreps("SO3", "32x0 + 32x1"), + cue.Irreps("SO3", "0 + 1"), + layout=cue.mul_ir, + shared_weights=False, + device="cuda", +) +w = torch.randn(batch, tp.weight_numel, device="cuda") +out = tp(x1, x2, weight=w) + +# With gather/scatter for graph neural networks +out = tp(x1, x2, weight=w, indices_1=senders, indices_out=receivers, size_out=num_nodes) +``` + +Default method: `"uniform_1d"` if segments are uniform, else `"naive"`. + +### FullyConnectedTensorProduct + +All input irrep pairs contribute to all output irreps (dense contraction). + +```python +tp = cuet.FullyConnectedTensorProduct( + cue.Irreps("O3", "4x0e + 4x1o"), # irreps_in1 + cue.Irreps("O3", "0e + 1o"), # irreps_in2 + cue.Irreps("O3", "4x0e + 4x1o"), # irreps_out + layout=cue.mul_ir, + internal_weights=True, # store weights as parameters + device="cuda", +) + +out = tp(x1, x2) # uses internal weights +# or: out = tp(x1, x2, weight=w) # external weights +``` + +Default method: `"fused_tp"`. + +### Linear + +Equivariant linear layer (weight-only, no second input). + +```python +linear = cuet.Linear( + cue.Irreps("SO3", "4x0 + 2x1"), # irreps_in + cue.Irreps("SO3", "3x0 + 5x1"), # irreps_out + layout=cue.mul_ir, + internal_weights=True, + device="cuda", +) + +out = linear(x) + +# Species-indexed weights (different weights per atom type) +linear = cuet.Linear( + irreps_in, irreps_out, + weight_classes=50, # 50 different weight sets + internal_weights=True, + device="cuda", +) +out = linear(x, weight_indices=species_indices) # species_indices: (batch,) int tensor +``` + +Default method: `"naive"`. Use `method="fused_tp"` for CUDA acceleration. + +### SymmetricContraction + +MACE-style symmetric contraction with element-indexed weights. + +```python +sc = cuet.SymmetricContraction( + cue.Irreps("O3", "32x0e + 32x1o"), # irreps_in (uniform mul required) + cue.Irreps("O3", "32x0e"), # irreps_out (uniform mul required) + contraction_degree=3, # polynomial degree + num_elements=95, # number of chemical elements + layout=cue.ir_mul, + dtype=torch.float32, + device="cuda", +) + +# indices: (batch,) int tensor selecting which element weights to use +out = sc(x, indices) +# out.shape == (batch, irreps_out.dim) +``` + +Default method: `"uniform_1d"` if segments are uniform, else `"naive"`. + +### SphericalHarmonics + +```python +sh = cuet.SphericalHarmonics( + ls=[0, 1, 2, 3], # degrees + normalize=True, # normalize input vectors + device="cuda", +) + +vectors = torch.randn(batch, 3, device="cuda") +out = sh(vectors) +# out.shape == (batch, 1 + 3 + 5 + 7) -- sum of 2l+1 +``` + +Default method: `"uniform_1d"`. + +### Rotation and Inversion + +```python +# Rotation (SO3 or O3 irreps) +rot = cuet.Rotation( + cue.Irreps("SO3", "4x0 + 2x1 + 1x2"), + layout=cue.ir_mul, + device="cuda", +) + +# Euler angles (YXY convention) +gamma = torch.tensor([0.1], device="cuda") +beta = torch.tensor([0.2], device="cuda") +alpha = torch.tensor([0.3], device="cuda") +out = rot(gamma, beta, alpha, x) + +# Helper: encode angle for rotation +encoded = cuet.encode_rotation_angle(angle, ell=3) # cos/sin encoding + +# Helper: 3D vector to Euler angles +beta, alpha = cuet.vector_to_euler_angles(vector) + +# Inversion (O3 irreps only) +inv = cuet.Inversion( + cue.Irreps("O3", "4x0e + 2x1o"), + layout=cue.ir_mul, + device="cuda", +) +out = inv(x) +``` + +## Layers + +### BatchNorm + +Batch normalization for equivariant representations (adapted from e3nn). + +```python +bn = cuet.layers.BatchNorm( + cue.Irreps("O3", "4x0e + 4x1o"), + layout=cue.mul_ir, + eps=1e-5, + momentum=0.1, + affine=True, +) + +out = bn(x) # x.shape == (batch, irreps.dim) +``` + +### FullyConnectedTensorProductConv + +Message passing layer for equivariant GNNs (DiffDock-style). + +```python +conv = cuet.layers.FullyConnectedTensorProductConv( + in_irreps=cue.Irreps("O3", "4x0e + 4x1o"), + sh_irreps=cue.Irreps("O3", "0e + 1o"), + out_irreps=cue.Irreps("O3", "4x0e + 4x1o"), + mlp_channels=[16, 32, 32], # MLP for path weights + mlp_activation=torch.nn.ReLU(), + batch_norm=True, + layout=cue.ir_mul, +) + +# graph = ((src, dst), (num_src_nodes, num_dst_nodes)) +graph = ((src, dst), (num_src_nodes, num_dst_nodes)) + +out = conv(src_features, edge_sh, edge_emb, graph, reduce="mean") +# out.shape == (num_dst_nodes, out_irreps.dim) + +# Optional: separate scalar features for efficient first-layer GEMM +out = conv(src_features, edge_sh, edge_emb, graph, + src_scalars=src_scalars, dst_scalars=dst_scalars) +``` + +## Triangle operations (AlphaFold2-style) + +Require `cuequivariance_ops_torch`. + +```python +# Triangle attention with pair bias +out = cuet.triangle_attention(q, k, v, bias, mask=mask, scale=scale) +# q, k, v: (B, N, H, Q/K, D), bias: (B, 1, H, Q, K) + +# Triangle multiplicative update +out = cuet.triangle_multiplicative_update( + x, # (B, I, J, C) + mask=mask, # (B, I, J) + precision=cuet.TriMulPrecision.DEFAULT, +) + +# Attention with pair bias (diffusion models) +out = cuet.attention_pair_bias(q, k, v, bias, mask=mask) +``` + +## ONNX and TensorRT export + +```python +# ONNX export +table = cuet.onnx_custom_translation_table() +onnx_program = torch.onnx.export(model, inputs, custom_translation_table=table) + +# TensorRT plugin registration +cuet.register_tensorrt_plugins() +``` + +## Complete GNN example + +```python +import torch +import cuequivariance as cue +import cuequivariance_torch as cuet + +class SimpleGNN(torch.nn.Module): + def __init__(self, irreps_in, irreps_sh, irreps_out): + super().__init__() + self.tp = cuet.ChannelWiseTensorProduct( + irreps_in, irreps_sh, layout=cue.mul_ir, + shared_weights=False, device="cuda", + ) + self.linear = cuet.Linear( + self.tp.irreps_out, irreps_out, + layout=cue.mul_ir, internal_weights=True, device="cuda", + ) + self.sh = cuet.SphericalHarmonics( + ls=[ir.l for _, ir in irreps_sh], normalize=True, device="cuda", + ) + + def forward(self, node_feats, edge_vec, edge_index, num_nodes): + src, dst = edge_index + edge_sh = self.sh(edge_vec) + w = torch.randn(1, self.tp.weight_numel, device=node_feats.device) + + # Message: tensor product on edges with scatter to destination nodes + messages = self.tp( + node_feats, edge_sh, weight=w, + indices_1=src, indices_2=None, + indices_out=dst, size_out=num_nodes, + ) + return self.linear(messages) +``` + +## Key file locations + +| Component | Path | +|-----------|------| +| `SegmentedPolynomial` | `cuequivariance_torch/primitives/segmented_polynomial.py` | +| `uniform_1d` backend | `cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py` | +| `naive` backend | `cuequivariance_torch/primitives/segmented_polynomial_naive.py` | +| `fused_tp` backend | `cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py` | +| `indexed_linear` backend | `cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py` | +| `ChannelWiseTensorProduct` | `cuequivariance_torch/operations/tp_channel_wise.py` | +| `FullyConnectedTensorProduct` | `cuequivariance_torch/operations/tp_fully_connected.py` | +| `Linear` | `cuequivariance_torch/operations/linear.py` | +| `SymmetricContraction` | `cuequivariance_torch/operations/symmetric_contraction.py` | +| `SphericalHarmonics` | `cuequivariance_torch/operations/spherical_harmonics.py` | +| `Rotation` / `Inversion` | `cuequivariance_torch/operations/rotation.py` | +| `BatchNorm` | `cuequivariance_torch/layers/batchnorm.py` | +| `FullyConnectedTensorProductConv` | `cuequivariance_torch/layers/tp_conv_fully_connected.py` | +| Triangle operations | `cuequivariance_torch/primitives/triangle.py` | +| Layout transposition | `cuequivariance_torch/primitives/transpose.py` | diff --git a/cuequivariance_torch/cuequivariance_torch/__init__.py b/cuequivariance_torch/cuequivariance_torch/__init__.py index 23c929e..44fcd85 100644 --- a/cuequivariance_torch/cuequivariance_torch/__init__.py +++ b/cuequivariance_torch/cuequivariance_torch/__init__.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""cuEquivariance PyTorch - PyTorch backend for cuEquivariance equivariant operations. + +For AI coding assistants: run `python -m cuequivariance_torch skill` for detailed usage guidance. +""" import importlib.resources __version__ = ( diff --git a/cuequivariance_torch/cuequivariance_torch/__main__.py b/cuequivariance_torch/cuequivariance_torch/__main__.py new file mode 100644 index 0000000..094f969 --- /dev/null +++ b/cuequivariance_torch/cuequivariance_torch/__main__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from pathlib import Path + + +def main(): + if len(sys.argv) >= 2 and sys.argv[1] == "skill": + skill_path = Path(__file__).parent / "SKILL.md" + print(skill_path.read_text()) + else: + print("Usage: python -m cuequivariance_torch skill") + sys.exit(1) + + +if __name__ == "__main__": + main() From ca0f0ac1d1318a585bb3f0f1f8aa7f7116843d8c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 10:44:08 +0200 Subject: [PATCH 07/14] Add IrDictPolynomial and _ir_dict descriptor variants Introduce a new ir_dict workflow that decouples from EquivariantPolynomial and IrrepsAndLayout. The key additions: - IrDictPolynomial: a lightweight dataclass pairing a SegmentedPolynomial (already split by irrep) with per-operand-group Irreps metadata. - split_polynomial_by_irreps: standalone helper bridging Irreps and SegmentedPolynomial.split_operand_by_size. - _ir_dict variants for all descriptors (channelwise, fully_connected, full, elementwise, linear, symmetric_contraction, spherical_harmonics). Each descriptor is refactored with a _core function shared between the existing EquivariantPolynomial path and the new IrDictPolynomial path. The NNX layers (nnx.py) no longer import RepArray, EquivariantPolynomial, IrrepsAndLayout, or spherical_harmonics. The MACE and NequIP examples are updated to use the new _ir_dict path. Co-Authored-By: Claude Opus 4.6 (1M context) --- cuequivariance/cuequivariance/__init__.py | 4 + .../cuequivariance/group_theory/__init__.py | 3 + .../group_theory/descriptors/__init__.py | 20 +- .../group_theory/descriptors/irreps_tp.py | 364 ++++++++++++++---- .../descriptors/spherical_harmonics_.py | 85 ++-- .../descriptors/symmetric_contractions.py | 138 ++++--- .../mace/symmetric_contractions.py | 67 +++- .../group_theory/ir_dict_polynomial.py | 136 +++++++ .../group_theory/ir_dict_polynomial_test.py | 337 ++++++++++++++++ cuequivariance_jax/cuequivariance_jax/nnx.py | 49 ++- cuequivariance_jax/examples/mace_nnx.py | 27 +- cuequivariance_jax/examples/nequip_nnx.py | 18 +- 12 files changed, 1023 insertions(+), 225 deletions(-) create mode 100644 cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py create mode 100644 cuequivariance/tests/group_theory/ir_dict_polynomial_test.py diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index dd92a58..b5b3f61 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -51,6 +51,8 @@ reduced_antisymmetric_tensor_product_basis, EquivariantPolynomial, EquivariantTensorProduct, # deprecated + IrDictPolynomial, + split_polynomial_by_irreps, ) from cuequivariance import segmented_polynomials as segmented_polynomials @@ -89,6 +91,8 @@ "reduced_antisymmetric_tensor_product_basis", "EquivariantPolynomial", "EquivariantTensorProduct", + "IrDictPolynomial", + "split_polynomial_by_irreps", "segmented_polynomials", "group_theory", "descriptors", diff --git a/cuequivariance/cuequivariance/group_theory/__init__.py b/cuequivariance/cuequivariance/group_theory/__init__.py index 4e221b9..f00be4b 100644 --- a/cuequivariance/cuequivariance/group_theory/__init__.py +++ b/cuequivariance/cuequivariance/group_theory/__init__.py @@ -44,6 +44,7 @@ from .equivariant_polynomial import EquivariantPolynomial from .equivariant_tensor_product import EquivariantTensorProduct +from .ir_dict_polynomial import IrDictPolynomial, split_polynomial_by_irreps __all__ = [ @@ -72,4 +73,6 @@ "reduced_antisymmetric_tensor_product_basis", "EquivariantPolynomial", "EquivariantTensorProduct", + "IrDictPolynomial", + "split_polynomial_by_irreps", ] diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py b/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py index f18c438..2de8b1c 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py @@ -15,12 +15,17 @@ from .transposition import transpose from .irreps_tp import ( full_tensor_product, + full_tensor_product_ir_dict, fully_connected_tensor_product, + fully_connected_tensor_product_ir_dict, channelwise_tensor_product, + channelwise_tensor_product_ir_dict, elementwise_tensor_product, + elementwise_tensor_product_ir_dict, linear, + linear_ir_dict, ) -from .symmetric_contractions import symmetric_contraction +from .symmetric_contractions import symmetric_contraction, symmetric_contraction_ir_dict from .rotations import ( fixed_axis_angle_rotation, y_rotation, @@ -30,16 +35,26 @@ yxy_rotation, inversion, ) -from .spherical_harmonics_ import sympy_spherical_harmonics, spherical_harmonics +from .spherical_harmonics_ import ( + sympy_spherical_harmonics, + spherical_harmonics, + spherical_harmonics_ir_dict, +) __all__ = [ "transpose", "full_tensor_product", + "full_tensor_product_ir_dict", "fully_connected_tensor_product", + "fully_connected_tensor_product_ir_dict", "channelwise_tensor_product", + "channelwise_tensor_product_ir_dict", "elementwise_tensor_product", + "elementwise_tensor_product_ir_dict", "linear", + "linear_ir_dict", "symmetric_contraction", + "symmetric_contraction_ir_dict", "fixed_axis_angle_rotation", "y_rotation", "x_rotation", @@ -49,4 +64,5 @@ "inversion", "sympy_spherical_harmonics", "spherical_harmonics", + "spherical_harmonics_ir_dict", ] diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index 992d51f..a46e4d7 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -19,6 +19,34 @@ from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep +def _fully_connected_tensor_product_core( + irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps +) -> cue.SegmentedPolynomial: + G = irreps1.irrep_class + + d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") + + for mul, ir in irreps1: + d.add_segment(1, (ir.dim, mul)) + for mul, ir in irreps2: + d.add_segment(2, (ir.dim, mul)) + for mul, ir in irreps3: + d.add_segment(3, (ir.dim, mul)) + + for (i1, (mul1, ir1)), (i2, (mul2, ir2)), (i3, (mul3, ir3)) in itertools.product( + enumerate(irreps1), enumerate(irreps2), enumerate(irreps3) + ): + if ir3 not in ir1 * ir2: + continue + + # for loop over the different solutions of the Clebsch-Gordan decomposition + for cg in G.clebsch_gordan(ir1, ir2, ir3): + d.add_path((mul1, mul2, mul3), i1, i2, i3, c=cg) + + d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d) + + def fully_connected_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps ) -> cue.EquivariantPolynomial: @@ -51,59 +79,56 @@ def fully_connected_tensor_product( Where ``61440x0`` are the 61440 weights needed to mix all the inputs with all the outputs. """ - G = irreps1.irrep_class - - d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") - - for mul, ir in irreps1: - d.add_segment(1, (ir.dim, mul)) - for mul, ir in irreps2: - d.add_segment(2, (ir.dim, mul)) - for mul, ir in irreps3: - d.add_segment(3, (ir.dim, mul)) - - for (i1, (mul1, ir1)), (i2, (mul2, ir2)), (i3, (mul3, ir3)) in itertools.product( - enumerate(irreps1), enumerate(irreps2), enumerate(irreps3) - ): - if ir3 not in ir1 * ir2: - continue - - # for loop over the different solutions of the Clebsch-Gordan decomposition - for cg in G.clebsch_gordan(ir1, ir2, ir3): - d.add_path((mul1, mul2, mul3), i1, i2, i3, c=cg) - - d = d.normalize_paths_for_operand(-1) + poly = _fully_connected_tensor_product_core(irreps1, irreps2, irreps3) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps1.new_scalars(poly.inputs[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, ) -def full_tensor_product( - irreps1: cue.Irreps, - irreps2: cue.Irreps, - irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantPolynomial: +def fully_connected_tensor_product_ir_dict( + irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps +) -> cue.IrDictPolynomial: """ - subscripts: ``lhs[iu],rhs[jv],output[kuv]`` + subscripts: ``weights[uvw],lhs[iu],rhs[jv],output[kw]`` - Construct a weightless channelwise tensor product descriptor. + Construct a fully connected tensor product as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`fully_connected_tensor_product`. .. currentmodule:: cuequivariance Args: irreps1 (Irreps): Irreps of the first operand. irreps2 (Irreps): Irreps of the second operand. - irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + irreps3 (Irreps): Irreps of the output. Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the full tensor product. + :class:`cue.IrDictPolynomial `: The fully connected tensor product + with ``input_irreps = (weight_irreps, irreps1, irreps2)`` and ``output_irreps = (irreps3,)``. """ + poly = _fully_connected_tensor_product_core(irreps1, irreps2, irreps3) + weight_irreps = irreps1.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 2, irreps2) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps1) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps1, irreps2), + output_irreps=(irreps3,), + ) + + +def _full_tensor_product_core( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]], +) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: G = irreps1.irrep_class if irreps3_filter is not None: @@ -136,28 +161,51 @@ def full_tensor_product( d = d.permute_segments(2, inv) d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d), irreps3 + + +def full_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.EquivariantPolynomial: + """ + subscripts: ``lhs[iu],rhs[jv],output[kuv]`` + + Construct a weightless channelwise tensor product descriptor. + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the full tensor product. + """ + poly, irreps3 = _full_tensor_product_core(irreps1, irreps2, irreps3_filter) return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, ) -def channelwise_tensor_product( +def full_tensor_product_ir_dict( irreps1: cue.Irreps, irreps2: cue.Irreps, - irreps3_filter=None, - simplify_irreps3: bool = False, -) -> cue.EquivariantPolynomial: + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.IrDictPolynomial: """ - subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` + subscripts: ``lhs[iu],rhs[jv],output[kuv]`` - Construct a channelwise tensor product descriptor. + Construct a weightless channelwise tensor product as an :class:`~cuequivariance.IrDictPolynomial`. - This operation is computationally sparser than the fully connected tensor product. + This is the ``ir_dict`` variant of :func:`full_tensor_product`. .. currentmodule:: cuequivariance @@ -165,11 +213,28 @@ def channelwise_tensor_product( irreps1 (Irreps): Irreps of the first operand. irreps2 (Irreps): Irreps of the second operand. irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. - simplify_irreps3 (bool, optional): If True, the irreps of the output are simplified. Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the channelwise tensor product. + :class:`cue.IrDictPolynomial `: The full tensor product + with ``input_irreps = (irreps1, irreps2)`` and ``output_irreps = (irreps3,)``. """ + poly, irreps3 = _full_tensor_product_core(irreps1, irreps2, irreps3_filter) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps2) + poly = cue.split_polynomial_by_irreps(poly, 0, irreps1) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(irreps1, irreps2), + output_irreps=(irreps3,), + ) + + +def _channelwise_tensor_product_core( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter, + simplify_irreps3: bool, +) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: G = irreps1.irrep_class if irreps3_filter is not None: @@ -215,14 +280,84 @@ def channelwise_tensor_product( d = d.permute_segments(3, [sid for _, _, sid in segments]) irreps3 = irreps3.simplify() + return cue.SegmentedPolynomial.eval_last_operand(d), irreps3 + + +def channelwise_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter=None, + simplify_irreps3: bool = False, +) -> cue.EquivariantPolynomial: + """ + subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` + + Construct a channelwise tensor product descriptor. + + This operation is computationally sparser than the fully connected tensor product. + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + simplify_irreps3 (bool, optional): If True, the irreps of the output are simplified. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the channelwise tensor product. + """ + poly, irreps3 = _channelwise_tensor_product_core( + irreps1, irreps2, irreps3_filter, simplify_irreps3 + ) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps1.new_scalars(poly.inputs[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, + ) + + +def channelwise_tensor_product_ir_dict( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter=None, +) -> cue.IrDictPolynomial: + """ + subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` + + Construct a channelwise tensor product as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`channelwise_tensor_product`. + The returned polynomial is already split by irrep and ready for use with + :func:`cuequivariance_jax.ir_dict.segmented_polynomial_uniform_1d`. + The output irreps are always simplified (each irrep appears at most once). + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + :class:`cue.IrDictPolynomial `: The channelwise tensor product + with ``input_irreps = (weight_irreps, irreps1, irreps2)`` and ``output_irreps = (irreps3,)``. + """ + poly, irreps3 = _channelwise_tensor_product_core( + irreps1, irreps2, irreps3_filter, simplify_irreps3=True + ) + weight_irreps = irreps1.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 2, irreps2) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps1) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps1, irreps2), + output_irreps=(irreps3,), ) @@ -255,24 +390,11 @@ def _align_two_irreps( return cue.Irreps(irreps1.irrep_class, l1), cue.Irreps(irreps2.irrep_class, l2) -def elementwise_tensor_product( +def _elementwise_tensor_product_core( irreps1: cue.Irreps, irreps2: cue.Irreps, - irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantPolynomial: - """ - subscripts: ``lhs[iu],rhs[ju],output[ku]`` - - Construct an elementwise tensor product descriptor. - - Args: - irreps1 (Irreps): Irreps of the first operand. - irreps2 (Irreps): Irreps of the second operand. - irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. - - Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the elementwise tensor product. - """ + irreps3_filter: Optional[Sequence[cue.Irrep]], +) -> tuple[cue.SegmentedPolynomial, cue.Irreps, cue.Irreps, cue.Irreps]: G = irreps1.irrep_class if irreps1.num_irreps != irreps2.num_irreps: @@ -300,29 +422,84 @@ def elementwise_tensor_product( irreps3 = cue.Irreps(G, irreps3) d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d), irreps1_cut, irreps2_cut, irreps3 + + +def elementwise_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.EquivariantPolynomial: + """ + subscripts: ``lhs[iu],rhs[ju],output[ku]`` + + Construct an elementwise tensor product descriptor. + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the elementwise tensor product. + """ + poly, _, _, irreps3 = _elementwise_tensor_product_core( + irreps1, irreps2, irreps3_filter + ) return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, ) -def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPolynomial: +def elementwise_tensor_product_ir_dict( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.IrDictPolynomial: """ - subscripts: ``weights[uv],input[iu],output[iv]`` + subscripts: ``lhs[iu],rhs[ju],output[ku]`` - Construct the descriptor of a linear equivariant transformation. + Construct an elementwise tensor product as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`elementwise_tensor_product`. + + Note: + The input irreps may be refined (split into smaller blocks) to align + multiplicities. The actual irreps used are available in the returned + ``input_irreps``. + + .. currentmodule:: cuequivariance Args: - irreps_in (Irreps): Irreps of the input. - irreps_out (Irreps): Irreps of the output. + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. + :class:`cue.IrDictPolynomial `: The elementwise tensor product + with ``input_irreps = (irreps1_aligned, irreps2_aligned)`` and ``output_irreps = (irreps3,)``. """ + poly, irreps1_cut, irreps2_cut, irreps3 = _elementwise_tensor_product_core( + irreps1, irreps2, irreps3_filter + ) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps2_cut) + poly = cue.split_polynomial_by_irreps(poly, 0, irreps1_cut) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(irreps1_cut, irreps2_cut), + output_irreps=(irreps3,), + ) + + +def _linear_core( + irreps_in: cue.Irreps, irreps_out: cue.Irreps +) -> cue.SegmentedPolynomial: d = cue.SegmentedTensorProduct.from_subscripts("uv_iu_iv") for mul, ir in irreps_in: d.add_segment(1, (ir.dim, mul)) @@ -336,12 +513,59 @@ def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPoly d.add_path((mul1, mul2), i1, i2, c=1.0) d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d) + +def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPolynomial: + """ + subscripts: ``weights[uv],input[iu],output[iv]`` + + Construct the descriptor of a linear equivariant transformation. + + Args: + irreps_in (Irreps): Irreps of the input. + irreps_out (Irreps): Irreps of the output. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. + """ + poly = _linear_core(irreps_in, irreps_out) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in.new_scalars(poly.inputs[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps_in, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, + ) + + +def linear_ir_dict( + irreps_in: cue.Irreps, irreps_out: cue.Irreps +) -> cue.IrDictPolynomial: + """ + subscripts: ``weights[uv],input[iu],output[iv]`` + + Construct a linear equivariant transformation as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`linear`. + + .. currentmodule:: cuequivariance + + Args: + irreps_in (Irreps): Irreps of the input. + irreps_out (Irreps): Irreps of the output. + + Returns: + :class:`cue.IrDictPolynomial `: The linear transformation + with ``input_irreps = (weight_irreps, irreps_in)`` and ``output_irreps = (irreps_out,)``. + """ + poly = _linear_core(irreps_in, irreps_out) + weight_irreps = irreps_in.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps_in), + output_irreps=(irreps_out,), ) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index de36ea8..ae577bb 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -20,6 +20,42 @@ from cuequivariance.etc.sympy_utils import sqrtQarray_to_sympy +def _spherical_harmonics_core( + ir_vec: cue.Irrep, ls: list[int] +) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: + if len(ls) != 1: + results = [_spherical_harmonics_core(ir_vec, [ell]) for ell in ls] + poly = cue.SegmentedPolynomial.stack( + [r[0] for r in results], [False, True] + ) + irreps_out = cue.Irreps( + type(ir_vec), sum([list(r[1]) for r in results], []) + ) + return poly, irreps_out + + [ell] = ls + ir, formula = sympy_spherical_harmonics(ir_vec, ell) + + assert ir_vec.dim == 3 + d = cue.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) + for i in range(ir.dim): + for degrees, coeff in ( + sympy.Poly(formula[i], sympy.symbols("x:3")).as_dict().items() + ): + indices = poly_degrees_to_path_indices(degrees) + d.add_path(*indices, i, c=coeff) + + d = d.symmetrize_operands(range(ell), force=True) + + poly = cue.SegmentedPolynomial( + [cue.SegmentedOperand([()] * 3)], + [cue.SegmentedOperand([()] * ir.dim)], + [(cue.Operation([0] * ell + [1]), d)], + ) + irreps_out = cue.Irreps(type(ir_vec), [(1, ir)]) + return poly, irreps_out + + def spherical_harmonics( ir_vec: cue.Irrep, ls: list[int], layout: cue.IrrepsLayout = cue.ir_mul ) -> cue.EquivariantPolynomial: @@ -42,33 +78,36 @@ def spherical_harmonics( │ []·a[]➜B[] ───── num_paths=3 ╰─ []·a[]·a[]➜B[] ─ num_paths=11 """ - if len(ls) != 1: - return cue.EquivariantPolynomial.stack( - [spherical_harmonics(ir_vec, [ell], layout) for ell in ls], [False, True] - ) + poly, irreps_out = _spherical_harmonics_core(ir_vec, ls) + return cue.EquivariantPolynomial( + [cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul)], + [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], + poly, + ) - [ell] = ls - ir, formula = sympy_spherical_harmonics(ir_vec, ell) - assert ir_vec.dim == 3 - d = cue.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) - for i in range(ir.dim): - for degrees, coeff in ( - sympy.Poly(formula[i], sympy.symbols("x:3")).as_dict().items() - ): - indices = poly_degrees_to_path_indices(degrees) - d.add_path(*indices, i, c=coeff) +def spherical_harmonics_ir_dict( + ir_vec: cue.Irrep, ls: list[int] +) -> cue.IrDictPolynomial: + """Polynomial descriptor for the spherical harmonics as an :class:`~cuequivariance.IrDictPolynomial`. - d = d.symmetrize_operands(range(ell), force=True) + This is the ``ir_dict`` variant of :func:`spherical_harmonics`. - return cue.EquivariantPolynomial( - [cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul)], - [cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul)], - cue.SegmentedPolynomial( - [cue.SegmentedOperand([()] * 3)], - [cue.SegmentedOperand([()] * ir.dim)], - [(cue.Operation([0] * ell + [1]), d)], - ), + Args: + ir_vec (Irrep): irrep of the input vector, for example ``cue.SO3(1)``. + ls (list of int): list of spherical harmonic degrees, for example ``[0, 1, 2]``. + + Returns: + :class:`cue.IrDictPolynomial `: The spherical harmonics + with ``input_irreps = (Irreps(ir_vec),)`` and ``output_irreps = (irreps_out,)``. + """ + poly, irreps_out = _spherical_harmonics_core(ir_vec, ls) + irreps_in = cue.Irreps(ir_vec) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(irreps_in,), + output_irreps=(irreps_out,), ) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index 2aa1df9..1575e68 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -17,57 +17,20 @@ import cuequivariance as cue -def symmetric_contraction( - irreps_in: cue.Irreps, - irreps_out: cue.Irreps, - degrees: tuple[int, ...], -) -> cue.EquivariantPolynomial: - """Construct the descriptor for a symmetric contraction. - - The symmetric contraction is a weighted sum of the input contracted with itself degree times. - - Subscripts: ``weights[u],input[u],output[u]`` - - Args: - irreps_in (Irreps): The input irreps, the multiplicity are treated in parallel. - irreps_out (Irreps): The output irreps. - degrees (tuple[int, ...]): List of degrees for the symmetric contractions. - - Returns: - EquivariantPolynomial: The descriptor of the symmetric contraction. - The operands are the weights, the input degree times and the output. - - Example: - >>> cue.descriptors.symmetric_contraction( - ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), - ... 16 * cue.Irreps("SO3", "0 + 1"), - ... (1, 2, 3) - ... ) - ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 - │ []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16 - │ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16 - ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16 - - Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). - """ - return symmetric_contraction_cached(irreps_in, irreps_out, tuple(degrees)) - - @cache -def symmetric_contraction_cached( +def _symmetric_contraction_core( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: tuple[int, ...], -) -> cue.EquivariantPolynomial: +) -> cue.SegmentedPolynomial: degrees = list(degrees) if len(degrees) != 1: - return cue.EquivariantPolynomial.stack( - [ - symmetric_contraction(irreps_in, irreps_out, (degree,)) - for degree in degrees - ], - [True, False, False], - ) + polys = [ + _symmetric_contraction_core(irreps_in, irreps_out, (degree,)) + for degree in degrees + ] + return cue.SegmentedPolynomial.stack(polys, [True, False, False]) + [degree] = degrees del degrees @@ -118,15 +81,84 @@ def symmetric_contraction_cached( for i in input_operands: assert d.operands[i] == input_operand + return cue.SegmentedPolynomial( + [d.operands[0], input_operand], + [d.operands[-1]], + [(cue.Operation([0] + [1] * degree + [2]), d)], + ) + + +def symmetric_contraction( + irreps_in: cue.Irreps, + irreps_out: cue.Irreps, + degrees: tuple[int, ...], +) -> cue.EquivariantPolynomial: + """Construct the descriptor for a symmetric contraction. + + The symmetric contraction is a weighted sum of the input contracted with itself degree times. + + Subscripts: ``weights[u],input[u],output[u]`` + + Args: + irreps_in (Irreps): The input irreps, the multiplicity are treated in parallel. + irreps_out (Irreps): The output irreps. + degrees (tuple[int, ...]): List of degrees for the symmetric contractions. + + Returns: + EquivariantPolynomial: The descriptor of the symmetric contraction. + The operands are the weights, the input degree times and the output. + + Example: + >>> cue.descriptors.symmetric_contraction( + ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), + ... 16 * cue.Irreps("SO3", "0 + 1"), + ... (1, 2, 3) + ... ) + ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 + │ []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16 + │ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16 + ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16 + + Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). + """ + poly = _symmetric_contraction_core(irreps_in, irreps_out, tuple(degrees)) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), - cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), + cue.IrrepsAndLayout(irreps_in.new_scalars(poly.inputs[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in, cue.ir_mul), ], - [cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul)], - cue.SegmentedPolynomial( - [d.operands[0], input_operand], - [d.operands[-1]], - [(cue.Operation([0] + [1] * degree + [2]), d)], - ), + [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], + poly, + ) + + +def symmetric_contraction_ir_dict( + irreps_in: cue.Irreps, + irreps_out: cue.Irreps, + degrees: tuple[int, ...], +) -> cue.IrDictPolynomial: + """Construct a symmetric contraction as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`symmetric_contraction`. + + .. currentmodule:: cuequivariance + + Args: + irreps_in (Irreps): The input irreps, the multiplicity are treated in parallel. + irreps_out (Irreps): The output irreps. + degrees (tuple[int, ...]): List of degrees for the symmetric contractions. + + Returns: + :class:`cue.IrDictPolynomial `: The symmetric contraction + with ``input_irreps = (weight_irreps, mul * irreps_in)`` and + ``output_irreps = (mul * irreps_out,)``. + """ + poly = _symmetric_contraction_core(irreps_in, irreps_out, tuple(degrees)) + weight_irreps = irreps_in.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps_in), + output_irreps=(irreps_out,), ) diff --git a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py index aa655ee..23321a8 100644 --- a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py @@ -19,6 +19,9 @@ import cuequivariance as cue from cuequivariance.etc.linalg import round_to_sqrt_rational, triu_array +from cuequivariance.group_theory.descriptors.symmetric_contractions import ( + _symmetric_contraction_core as _std_symmetric_contraction_core, +) def symmetric_contraction( @@ -43,21 +46,54 @@ def symmetric_contraction( x = cuex.randn(jax.random.key(1), e.inputs[1]) y = cuex.equivariant_polynomial(e, [w, x]) """ - return symmetric_contraction_cached(irreps_in, irreps_out, tuple(degrees)) + poly, projection = _symmetric_contraction_cached( + irreps_in, irreps_out, tuple(degrees) + ) + return cue.EquivariantPolynomial( + [ + cue.IrrepsAndLayout(irreps_in.new_scalars(poly.inputs[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in, cue.ir_mul), + ], + [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], + poly, + ), projection + + +def symmetric_contraction_ir_dict( + irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: tuple[int, ...] +) -> tuple[cue.IrDictPolynomial, np.ndarray]: + r"""``ir_dict`` variant of :func:`symmetric_contraction`. + + Returns: + tuple of (:class:`~cuequivariance.IrDictPolynomial`, np.ndarray): + The polynomial (with ``input_irreps = (weight_irreps, irreps_in)`` + and ``output_irreps = (irreps_out,)``) and the projection matrix. + """ + poly, projection = _symmetric_contraction_cached( + irreps_in, irreps_out, tuple(degrees) + ) + weight_irreps = irreps_in.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps_in), + output_irreps=(irreps_out,), + ), projection @cache -def symmetric_contraction_cached( +def _symmetric_contraction_cached( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: tuple[int, ...] -) -> tuple[cue.EquivariantPolynomial, np.ndarray]: +) -> tuple[cue.SegmentedPolynomial, np.ndarray]: assert min(degrees) > 0 # poly1 replicates the behavior of the original MACE implementation - poly1 = cue.EquivariantPolynomial.stack( + poly1 = cue.SegmentedPolynomial.stack( [ - cue.EquivariantPolynomial.stack( + cue.SegmentedPolynomial.stack( [ - _symmetric_contraction(irreps_in, irreps_out[i : i + 1], deg) + _symmetric_contraction_poly(irreps_in, irreps_out[i : i + 1], deg) for deg in reversed(degrees) ], [True, False, False], @@ -66,7 +102,7 @@ def symmetric_contraction_cached( ], [True, False, True], ) - poly2 = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, degrees) + poly2 = _std_symmetric_contraction_core(irreps_in, irreps_out, tuple(degrees)) a1, a2 = [ np.concatenate( [ @@ -75,7 +111,7 @@ def symmetric_contraction_cached( 1, None, ) - for _, d in pol.polynomial.operations + for _, d in pol.operations ], axis=1, ) @@ -120,9 +156,9 @@ def _stp_to_matrix( # This function is an adaptation of https://github.com/ACEsuit/mace/blob/bd412319b11c5f56c37cec6c4cfae74b2a49ff43/mace/modules/symmetric_contraction.py -def _symmetric_contraction( +def _symmetric_contraction_poly( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degree: int -) -> cue.EquivariantPolynomial: +) -> cue.SegmentedPolynomial: mul = irreps_in.muls[0] assert all(mul == m for m in irreps_in.muls) assert all(mul == m for m in irreps_out.muls) @@ -157,15 +193,8 @@ def _symmetric_contraction( assert d.num_operands >= 3 [w, x], y = d.operands[:2], d.operands[-1] - return cue.EquivariantPolynomial( - [ - cue.IrrepsAndLayout(irreps_in.new_scalars(w.size), cue.ir_mul), - cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), - ], - [cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul)], - cue.SegmentedPolynomial( - [w, x], [y], [(cue.Operation([0] + [1] * degree + [2]), d)] - ), + return cue.SegmentedPolynomial( + [w, x], [y], [(cue.Operation([0] + [1] * degree + [2]), d)] ) diff --git a/cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py b/cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py new file mode 100644 index 0000000..529736a --- /dev/null +++ b/cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +import itertools + +import cuequivariance as cue + + +def split_polynomial_by_irreps( + polynomial: cue.SegmentedPolynomial, + operand_id: int, + irreps: cue.Irreps, +) -> cue.SegmentedPolynomial: + """Split a polynomial operand according to irreps boundaries. + + Each ``(mul, ir)`` block in the irreps becomes a separate operand + in the resulting polynomial. + + Args: + polynomial: The polynomial to split. + operand_id: Index of the operand to split (negative indices supported). + irreps: Irreps describing the operand's structure. + + Returns: + A new :class:`~cuequivariance.SegmentedPolynomial` with the specified + operand split into one operand per ``(mul, ir)`` block. + """ + offsets = list( + itertools.accumulate((mul * ir.dim for mul, ir in irreps), initial=0) + ) + return polynomial.split_operand_by_size(operand_id, offsets) + + +@dataclasses.dataclass(init=False, frozen=True) +class IrDictPolynomial: + """A segmented polynomial with per-operand irreps metadata for the ``ir_dict`` workflow. + + This class pairs a :class:`~cuequivariance.SegmentedPolynomial` (already split + by irrep) with the :class:`~cuequivariance.Irreps` that describe each operand group. + + Each :class:`~cuequivariance.Irreps` in ``input_irreps`` and ``output_irreps`` + corresponds to a logical operand group (e.g. weights, node features, spherical + harmonics, output features). Within each group, every ``(mul, ir)`` block maps + to one polynomial operand. + + Contract: + - The polynomial is already split by irrep: each operand corresponds to + exactly one ``(mul, ir)`` block. + - The ``(mul, ir)`` blocks in ``input_irreps`` and ``output_irreps`` + are in the same order as the polynomial's input and output operands. + - For each ``(mul, ir)`` block, the corresponding polynomial operand + has size ``mul * ir.dim``. + + Args: + polynomial: The underlying polynomial, already split by irrep. + input_irreps: One :class:`~cuequivariance.Irreps` per input group. + output_irreps: One :class:`~cuequivariance.Irreps` per output group. + """ + + polynomial: cue.SegmentedPolynomial + input_irreps: tuple[cue.Irreps, ...] + output_irreps: tuple[cue.Irreps, ...] + + def __init__( + self, + polynomial: cue.SegmentedPolynomial, + input_irreps: list[cue.Irreps] | tuple[cue.Irreps, ...], + output_irreps: list[cue.Irreps] | tuple[cue.Irreps, ...], + ): + object.__setattr__(self, "polynomial", polynomial) + object.__setattr__(self, "input_irreps", tuple(input_irreps)) + object.__setattr__(self, "output_irreps", tuple(output_irreps)) + + expected_inputs = sum(len(irreps) for irreps in self.input_irreps) + if expected_inputs != polynomial.num_inputs: + raise ValueError( + f"input_irreps describe {expected_inputs} operands, " + f"but polynomial has {polynomial.num_inputs} inputs" + ) + + expected_outputs = sum(len(irreps) for irreps in self.output_irreps) + if expected_outputs != polynomial.num_outputs: + raise ValueError( + f"output_irreps describe {expected_outputs} operands, " + f"but polynomial has {polynomial.num_outputs} outputs" + ) + + operand_idx = 0 + for irreps in self.input_irreps: + for mul, ir in irreps: + actual_size = polynomial.inputs[operand_idx].size + expected_size = mul * ir.dim + if expected_size != actual_size: + raise ValueError( + f"Input operand {operand_idx} ({mul}x{ir}): " + f"expected size {expected_size}, " + f"got {actual_size}" + ) + operand_idx += 1 + + operand_idx = 0 + for irreps in self.output_irreps: + for mul, ir in irreps: + actual_size = polynomial.outputs[operand_idx].size + expected_size = mul * ir.dim + if expected_size != actual_size: + raise ValueError( + f"Output operand {operand_idx} ({mul}x{ir}): " + f"expected size {expected_size}, " + f"got {actual_size}" + ) + operand_idx += 1 + + def __repr__(self): + labels = [] + for irreps in self.input_irreps: + for mul, ir in irreps: + labels.append(f"{mul}x{ir}" if mul > 1 else f"{ir}") + for irreps in self.output_irreps: + for mul, ir in irreps: + labels.append(f"{mul}x{ir}" if mul > 1 else f"{ir}") + return self.polynomial.to_string(labels) diff --git a/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py b/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py new file mode 100644 index 0000000..c738c91 --- /dev/null +++ b/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py @@ -0,0 +1,337 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +import cuequivariance as cue + + +# -------------------------------------------------------------------------- +# split_polynomial_by_irreps +# -------------------------------------------------------------------------- + + +def test_split_polynomial_by_irreps_matches_split_operand_by_irrep(): + """The new standalone function should produce the same result as + EquivariantPolynomial.split_operand_by_irrep.""" + irreps_in = cue.Irreps(cue.O3, "64x0e + 32x1o") + irreps_sh = cue.Irreps(cue.O3, "0e + 1o") + irreps_out = cue.Irreps(cue.O3, "0e + 1o + 2e") + + e = cue.descriptors.channelwise_tensor_product(irreps_in, irreps_sh, irreps_out, True) + old = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + new = e.polynomial + new = cue.split_polynomial_by_irreps(new, 2, irreps_sh) + new = cue.split_polynomial_by_irreps(new, 1, irreps_in) + new = cue.split_polynomial_by_irreps(new, -1, e.outputs[0].irreps) + + assert old == new + + +# -------------------------------------------------------------------------- +# IrDictPolynomial validation +# -------------------------------------------------------------------------- + + +def test_ir_dict_polynomial_rejects_wrong_operand_count(): + irreps_in = cue.Irreps(cue.O3, "4x0e + 2x1o") + irreps_out = cue.Irreps(cue.O3, "3x0e") + + result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + + with pytest.raises(ValueError, match="input_irreps describe"): + cue.IrDictPolynomial( + polynomial=result.polynomial, + input_irreps=(irreps_in,), # wrong: should include weight group + output_irreps=result.output_irreps, + ) + + +def test_ir_dict_polynomial_rejects_wrong_operand_size(): + irreps_in = cue.Irreps(cue.O3, "4x0e + 2x1o") + irreps_out = cue.Irreps(cue.O3, "3x0e") + + result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + + with pytest.raises(ValueError, match="expected size"): + cue.IrDictPolynomial( + polynomial=result.polynomial, + input_irreps=( + result.input_irreps[0], + cue.Irreps(cue.O3, "3x0e + 2x1o"), # wrong mul for 0e + ), + output_irreps=result.output_irreps, + ) + + +# -------------------------------------------------------------------------- +# _ir_dict descriptor variants match the old EquivariantPolynomial path +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "irreps1, irreps2, irreps3_filter", + [ + ( + cue.Irreps(cue.O3, "64x0e + 32x1o"), + cue.Irreps(cue.O3, "0e + 1o"), + cue.Irreps(cue.O3, "0e + 1o + 2e"), + ), + ( + cue.Irreps(cue.O3, "16x0e + 8x1o + 4x2e"), + cue.Irreps(cue.O3, "0e + 1o"), + None, + ), + ( + cue.Irreps(cue.SO3, "8x0 + 4x1 + 2x2"), + cue.Irreps(cue.SO3, "0 + 1"), + None, + ), + ], +) +def test_channelwise_tensor_product_ir_dict(irreps1, irreps2, irreps3_filter): + # channelwise_tensor_product_ir_dict always simplifies output irreps + e = cue.descriptors.channelwise_tensor_product( + irreps1, irreps2, irreps3_filter, simplify_irreps3=True + ) + old_poly = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps1, irreps2, irreps3_filter + ) + + assert result.polynomial == old_poly + assert result.output_irreps[0] == e.outputs[0].irreps + assert result.input_irreps[1] == irreps1 + assert result.input_irreps[2] == irreps2 + + +@pytest.mark.parametrize( + "irreps1, irreps2, irreps3", + [ + ( + cue.Irreps(cue.O3, "4x0e + 2x1o"), + cue.Irreps(cue.O3, "0e + 1o"), + cue.Irreps(cue.O3, "4x0e + 2x1o"), + ), + ( + cue.Irreps(cue.SO3, "8x0 + 4x1"), + cue.Irreps(cue.SO3, "0 + 1 + 2"), + cue.Irreps(cue.SO3, "8x0 + 4x1 + 2x2"), + ), + ], +) +def test_fully_connected_tensor_product_ir_dict(irreps1, irreps2, irreps3): + e = cue.descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3) + old_poly = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.fully_connected_tensor_product_ir_dict( + irreps1, irreps2, irreps3 + ) + + assert result.polynomial == old_poly + assert result.output_irreps[0] == irreps3 + + +@pytest.mark.parametrize( + "irreps_in, irreps_out", + [ + ( + cue.Irreps(cue.O3, "4x0e + 2x1o"), + cue.Irreps(cue.O3, "3x0e + 5x1o"), + ), + ( + cue.Irreps(cue.SO3, "16x0 + 8x1 + 4x2"), + cue.Irreps(cue.SO3, "8x0 + 4x1"), + ), + ], +) +def test_linear_ir_dict(irreps_in, irreps_out): + e = cue.descriptors.linear(irreps_in, irreps_out) + old_poly = ( + e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + ) + + result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + + assert result.polynomial == old_poly + assert result.output_irreps[0] == irreps_out + assert result.input_irreps[1] == irreps_in + + +def test_full_tensor_product_ir_dict(): + irreps1 = cue.Irreps(cue.O3, "2x0e + 1x1o") + irreps2 = cue.Irreps(cue.O3, "0e + 1o") + + e = cue.descriptors.full_tensor_product(irreps1, irreps2) + old_poly = ( + e.split_operand_by_irrep(1) + .split_operand_by_irrep(0) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.full_tensor_product_ir_dict(irreps1, irreps2) + + assert result.polynomial == old_poly + assert result.input_irreps[0] == irreps1 + assert result.input_irreps[1] == irreps2 + + +def test_elementwise_tensor_product_ir_dict(): + irreps1 = cue.Irreps(cue.O3, "4x0e + 4x1o") + irreps2 = cue.Irreps(cue.O3, "4x0e + 4x1o") + + e = cue.descriptors.elementwise_tensor_product(irreps1, irreps2) + old_poly = ( + e.split_operand_by_irrep(1) + .split_operand_by_irrep(0) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.elementwise_tensor_product_ir_dict(irreps1, irreps2) + + assert result.polynomial == old_poly + + +def test_symmetric_contraction_ir_dict(): + irreps_in = 16 * cue.Irreps("SO3", "0 + 1 + 2") + irreps_out = 16 * cue.Irreps("SO3", "0 + 1") + + e = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, (1, 2, 3)) + old_poly = ( + e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + ) + + result = cue.descriptors.symmetric_contraction_ir_dict( + irreps_in, irreps_out, (1, 2, 3) + ) + + assert result.polynomial == old_poly + (output_irreps,) = result.output_irreps + assert output_irreps == irreps_out + + +def test_mace_symmetric_contraction_ir_dict(): + from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( + symmetric_contraction as mace_sc, + symmetric_contraction_ir_dict as mace_sc_ir_dict, + ) + + irreps_in = 4 * cue.Irreps("SO3", "0 + 1 + 2") + irreps_out = 4 * cue.Irreps("SO3", "0 + 1") + + e, projection_old = mace_sc(irreps_in, irreps_out, [1, 2, 3]) + old_poly = ( + e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + ) + + result, projection_new = mace_sc_ir_dict(irreps_in, irreps_out, [1, 2, 3]) + + assert result.polynomial == old_poly + np.testing.assert_array_equal(projection_old, projection_new) + (output_irreps,) = result.output_irreps + assert output_irreps == irreps_out + + +@pytest.mark.parametrize("max_degree", [1, 2, 3, 4]) +def test_spherical_harmonics_ir_dict(max_degree): + ir_vec = cue.O3(1, -1) + ls = list(range(max_degree + 1)) + + e = cue.descriptors.spherical_harmonics(ir_vec, ls) + old_poly = e.split_operand_by_irrep(-1).polynomial + + result = cue.descriptors.spherical_harmonics_ir_dict(ir_vec, ls) + + assert result.polynomial == old_poly + (output_irreps,) = result.output_irreps + assert output_irreps == e.outputs[0].irreps + + # Numpy evaluation: verify output matches unsplit + vec = np.array([0.3, -0.5, 0.8]) + [out_flat] = e.polynomial(vec) + + out_parts = result.polynomial(vec) + out_concat = np.concatenate(out_parts) + np.testing.assert_allclose(out_flat, out_concat, atol=1e-12) + + +# -------------------------------------------------------------------------- +# Numpy evaluation: ir_dict variant produces same results as original +# -------------------------------------------------------------------------- + + +def test_channelwise_numpy_evaluation(): + """Evaluate both paths with numpy and compare outputs.""" + irreps1 = cue.Irreps(cue.O3, "4x0e + 2x1o") + irreps_sh = cue.Irreps(cue.O3, "0e + 1o") + + e = cue.descriptors.channelwise_tensor_product(irreps1, irreps_sh, simplify_irreps3=True) + result = cue.descriptors.channelwise_tensor_product_ir_dict(irreps1, irreps_sh) + + # Generate random inputs matching the unsplit polynomial + np.random.seed(42) + inputs_orig = [np.random.randn(op.size) for op in e.polynomial.inputs] + [out_orig] = e.polynomial(*inputs_orig) + + # The split polynomial has more operands — reconstruct matching inputs + inputs_split = [np.random.randn(op.size) for op in result.polynomial.inputs] + + # Use the same flat data for both + # Unsplit: [weights, input1_flat, input2_flat] + # Split: [weights, input1_ir0, input1_ir1, input2_ir0, input2_ir1] + w = np.random.randn(result.polynomial.inputs[0].size) + x1 = np.random.randn(e.polynomial.inputs[1].size) + x2 = np.random.randn(e.polynomial.inputs[2].size) + + [out_orig] = e.polynomial(w, x1, x2) + + # Split x1 and x2 by irrep boundaries + x1_parts = [] + offset = 0 + for mul, ir in irreps1: + size = mul * ir.dim + x1_parts.append(x1[offset : offset + size]) + offset += size + + x2_parts = [] + offset = 0 + for mul, ir in irreps_sh: + size = mul * ir.dim + x2_parts.append(x2[offset : offset + size]) + offset += size + + out_split = result.polynomial(w, *x1_parts, *x2_parts) + out_split_concat = np.concatenate(out_split) + + np.testing.assert_allclose(out_orig, out_split_concat, atol=1e-12) diff --git a/cuequivariance_jax/cuequivariance_jax/nnx.py b/cuequivariance_jax/cuequivariance_jax/nnx.py index 1043592..78570b0 100644 --- a/cuequivariance_jax/cuequivariance_jax/nnx.py +++ b/cuequivariance_jax/cuequivariance_jax/nnx.py @@ -28,10 +28,8 @@ from . import ir_dict from .activation import normalize_function -from .rep_array.rep_array_ import RepArray from .segmented_polynomials.segmented_polynomial import segmented_polynomial from .segmented_polynomials.utils import Repeats -from .spherical_harmonics import spherical_harmonics try: from flax import nnx @@ -143,33 +141,31 @@ def __call__(self, x: dict[Irrep, Array]) -> dict[Irrep, Array]: class SphericalHarmonics(nnx.Module): def __init__(self, max_degree: int, eps: float = 0.0): self.eps = eps - self.max_degree = max_degree - self.irreps_in = cue.Irreps(cue.O3, "1o") - self.irreps_out = cue.Irreps( - cue.O3, [(1, cue.O3(L, (-1) ** L)) for L in range(max_degree + 1)] + desc = cue.descriptors.spherical_harmonics_ir_dict( + cue.O3(1, -1), list(range(max_degree + 1)) ) + self.poly = desc.polynomial + (self.irreps_out,) = desc.output_irreps def __call__(self, x: Array) -> dict[Irrep, Array]: assert x.shape[-1] == 3 shape = x.shape[:-1] - x = RepArray(self.irreps_in, x, cue.ir_mul) - x = jax.tree.map( - lambda v: v / _safe_norm(v, self.eps, axis=-1, keepdim=True), x + x = x / _safe_norm(x, self.eps, axis=-1, keepdim=True) + outputs = segmented_polynomial( + self.poly, + [x], + [ + jax.ShapeDtypeStruct(shape + (out.size,), x.dtype) + for out in self.poly.outputs + ], + method="naive", + name="spherical_harmonics", ) - y = spherical_harmonics(range(self.max_degree + 1), x, normalize=False) - - y = { - ir: rearrange(v, "... i m -> ... m i") - for (_, ir), v in zip(y.irreps, y.segments) - } - actual = jax.tree.map(lambda x: x.shape, y) - expected = { - cue.O3(L, (-1) ** L): shape + (1, 2 * L + 1) - for L in range(self.max_degree + 1) + return { + ir: y.reshape(shape + (1, ir.dim)) + for (_, ir), y in zip(self.irreps_out, outputs) } - assert actual == expected, f"y: {actual}, expected: {expected}" - return y class IrrepsNormalize(nnx.Module): @@ -245,12 +241,12 @@ def __init__( self.irreps_in = irreps_in self.irreps_out = irreps_out self.num_indices = num_indices - self.scale = scale / jnp.sqrt(num_indices) self.name = name - self.e = cue.descriptors.linear(irreps_in, irreps_out) * self.scale + scale = scale / jnp.sqrt(num_indices) + self.poly = cue.descriptors.linear(irreps_in, irreps_out).polynomial * scale self.w = nnx.Param( - jax.random.normal(rngs.params(), (num_indices, self.e.inputs[0].dim), dtype) + jax.random.normal(rngs.params(), (num_indices, self.poly.inputs[0].size), dtype) ) def __call__( @@ -263,11 +259,10 @@ def __call__( x_flat = ir_dict.dict_to_flat(self.irreps_in, x_ir_mul) num_elements = x_flat.shape[0] - p = self.e.polynomial [y_flat] = segmented_polynomial( - p, + self.poly, [self.w[...], x_flat], - [jax.ShapeDtypeStruct((num_elements, p.outputs[0].size), x_flat.dtype)], + [jax.ShapeDtypeStruct((num_elements, self.poly.outputs[0].size), x_flat.dtype)], [Repeats(num_index_counts), None, None], method="indexed_linear", name=self.name, diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py index e1439b4..72aa557 100644 --- a/cuequivariance_jax/examples/mace_nnx.py +++ b/cuequivariance_jax/examples/mace_nnx.py @@ -26,8 +26,8 @@ import jax import jax.numpy as jnp import numpy as np -from cuequivariance.group_theory.experimental.mace import ( - symmetric_contraction as mace_symmetric_contraction, +from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( + symmetric_contraction_ir_dict as mace_symmetric_contraction_ir_dict, ) from cuequivariance_jax.nnx import ( MLP, @@ -89,21 +89,12 @@ def __init__( rngs: nnx.Rngs, ): self.name = name - e = ( - cue.descriptors.channelwise_tensor_product( - irreps_in, irreps_sh, irreps_out, True - ) - * epsilon - ) - self.weight_numel = e.inputs[0].dim - self.irreps_out = e.outputs[0].irreps - - self.poly = ( - e.split_operand_by_irrep(2) - .split_operand_by_irrep(1) - .split_operand_by_irrep(-1) - .polynomial + desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) + (self.irreps_out,) = desc.output_irreps + self.poly = desc.polynomial * epsilon + self.weight_numel = self.poly.inputs[0].size def __call__( self, @@ -156,12 +147,12 @@ def __init__( self.irreps_out = irreps_out self.name = name - e, projection = mace_symmetric_contraction( + desc, projection = mace_symmetric_contraction_ir_dict( irreps_in, irreps_out, range(1, correlation + 1) ) self.projection = jnp.array(projection, dtype=dtype) - self.poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + self.poly = desc.polynomial self.w = nnx.Param( jax.random.normal( rngs.params(), diff --git a/cuequivariance_jax/examples/nequip_nnx.py b/cuequivariance_jax/examples/nequip_nnx.py index f1016da..cef7cb0 100644 --- a/cuequivariance_jax/examples/nequip_nnx.py +++ b/cuequivariance_jax/examples/nequip_nnx.py @@ -222,20 +222,12 @@ def __init__( rngs: nnx.Rngs, ): self.name = name - e = ( - cue.descriptors.channelwise_tensor_product( - irreps_in, irreps_sh, irreps_out, True - ) - * epsilon - ) - self.weight_numel = e.inputs[0].dim - self.irreps_out = e.outputs[0].irreps - self.poly = ( - e.split_operand_by_irrep(2) - .split_operand_by_irrep(1) - .split_operand_by_irrep(-1) - .polynomial + desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) + (self.irreps_out,) = desc.output_irreps + self.poly = desc.polynomial * epsilon + self.weight_numel = self.poly.inputs[0].size def __call__( self, From 47e95bb93127a18591a21881775f83bb3c919db4 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 10:49:52 +0200 Subject: [PATCH 08/14] format --- .gitignore | 1 + .../group_theory/descriptors/irreps_tp.py | 7 ++++- .../descriptors/spherical_harmonics_.py | 8 ++--- .../group_theory/ir_dict_polynomial_test.py | 31 +++++++------------ cuequivariance_jax/cuequivariance_jax/nnx.py | 10 ++++-- 5 files changed, 28 insertions(+), 29 deletions(-) diff --git a/.gitignore b/.gitignore index efa6755..5bdb1b0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__ docs/api/generated/ docs/public/ docs/jupyter_execute/ +docs/_build/ diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index a46e4d7..778130f 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -422,7 +422,12 @@ def _elementwise_tensor_product_core( irreps3 = cue.Irreps(G, irreps3) d = d.normalize_paths_for_operand(-1) - return cue.SegmentedPolynomial.eval_last_operand(d), irreps1_cut, irreps2_cut, irreps3 + return ( + cue.SegmentedPolynomial.eval_last_operand(d), + irreps1_cut, + irreps2_cut, + irreps3, + ) def elementwise_tensor_product( diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index ae577bb..cb9d973 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -25,12 +25,8 @@ def _spherical_harmonics_core( ) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: if len(ls) != 1: results = [_spherical_harmonics_core(ir_vec, [ell]) for ell in ls] - poly = cue.SegmentedPolynomial.stack( - [r[0] for r in results], [False, True] - ) - irreps_out = cue.Irreps( - type(ir_vec), sum([list(r[1]) for r in results], []) - ) + poly = cue.SegmentedPolynomial.stack([r[0] for r in results], [False, True]) + irreps_out = cue.Irreps(type(ir_vec), sum([list(r[1]) for r in results], [])) return poly, irreps_out [ell] = ls diff --git a/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py b/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py index c738c91..457bd98 100644 --- a/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py +++ b/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py @@ -17,7 +17,6 @@ import cuequivariance as cue - # -------------------------------------------------------------------------- # split_polynomial_by_irreps # -------------------------------------------------------------------------- @@ -30,7 +29,9 @@ def test_split_polynomial_by_irreps_matches_split_operand_by_irrep(): irreps_sh = cue.Irreps(cue.O3, "0e + 1o") irreps_out = cue.Irreps(cue.O3, "0e + 1o + 2e") - e = cue.descriptors.channelwise_tensor_product(irreps_in, irreps_sh, irreps_out, True) + e = cue.descriptors.channelwise_tensor_product( + irreps_in, irreps_sh, irreps_out, True + ) old = ( e.split_operand_by_irrep(2) .split_operand_by_irrep(1) @@ -176,9 +177,7 @@ def test_fully_connected_tensor_product_ir_dict(irreps1, irreps2, irreps3): ) def test_linear_ir_dict(irreps_in, irreps_out): e = cue.descriptors.linear(irreps_in, irreps_out) - old_poly = ( - e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial - ) + old_poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) @@ -228,9 +227,7 @@ def test_symmetric_contraction_ir_dict(): irreps_out = 16 * cue.Irreps("SO3", "0 + 1") e = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, (1, 2, 3)) - old_poly = ( - e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial - ) + old_poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial result = cue.descriptors.symmetric_contraction_ir_dict( irreps_in, irreps_out, (1, 2, 3) @@ -244,6 +241,8 @@ def test_symmetric_contraction_ir_dict(): def test_mace_symmetric_contraction_ir_dict(): from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( symmetric_contraction as mace_sc, + ) + from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( symmetric_contraction_ir_dict as mace_sc_ir_dict, ) @@ -251,9 +250,7 @@ def test_mace_symmetric_contraction_ir_dict(): irreps_out = 4 * cue.Irreps("SO3", "0 + 1") e, projection_old = mace_sc(irreps_in, irreps_out, [1, 2, 3]) - old_poly = ( - e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial - ) + old_poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial result, projection_new = mace_sc_ir_dict(irreps_in, irreps_out, [1, 2, 3]) @@ -296,17 +293,11 @@ def test_channelwise_numpy_evaluation(): irreps1 = cue.Irreps(cue.O3, "4x0e + 2x1o") irreps_sh = cue.Irreps(cue.O3, "0e + 1o") - e = cue.descriptors.channelwise_tensor_product(irreps1, irreps_sh, simplify_irreps3=True) + e = cue.descriptors.channelwise_tensor_product( + irreps1, irreps_sh, simplify_irreps3=True + ) result = cue.descriptors.channelwise_tensor_product_ir_dict(irreps1, irreps_sh) - # Generate random inputs matching the unsplit polynomial - np.random.seed(42) - inputs_orig = [np.random.randn(op.size) for op in e.polynomial.inputs] - [out_orig] = e.polynomial(*inputs_orig) - - # The split polynomial has more operands — reconstruct matching inputs - inputs_split = [np.random.randn(op.size) for op in result.polynomial.inputs] - # Use the same flat data for both # Unsplit: [weights, input1_flat, input2_flat] # Split: [weights, input1_ir0, input1_ir1, input2_ir0, input2_ir1] diff --git a/cuequivariance_jax/cuequivariance_jax/nnx.py b/cuequivariance_jax/cuequivariance_jax/nnx.py index 78570b0..a8956a9 100644 --- a/cuequivariance_jax/cuequivariance_jax/nnx.py +++ b/cuequivariance_jax/cuequivariance_jax/nnx.py @@ -246,7 +246,9 @@ def __init__( scale = scale / jnp.sqrt(num_indices) self.poly = cue.descriptors.linear(irreps_in, irreps_out).polynomial * scale self.w = nnx.Param( - jax.random.normal(rngs.params(), (num_indices, self.poly.inputs[0].size), dtype) + jax.random.normal( + rngs.params(), (num_indices, self.poly.inputs[0].size), dtype + ) ) def __call__( @@ -262,7 +264,11 @@ def __call__( [y_flat] = segmented_polynomial( self.poly, [self.w[...], x_flat], - [jax.ShapeDtypeStruct((num_elements, self.poly.outputs[0].size), x_flat.dtype)], + [ + jax.ShapeDtypeStruct( + (num_elements, self.poly.outputs[0].size), x_flat.dtype + ) + ], [Repeats(num_index_counts), None, None], method="indexed_linear", name=self.name, From 87c47418193a0dafc796964745dc731c47105771 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 10:52:40 +0200 Subject: [PATCH 09/14] Fix doctest: update symmetric_contraction weight repr The weight irreps changed from per-degree blocks (32x0+80x0+176x0) to a single scalar block (288x0) after the _core extraction. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../group_theory/descriptors/symmetric_contractions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index 1575e68..64ba2a6 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -114,7 +114,7 @@ def symmetric_contraction( ... 16 * cue.Irreps("SO3", "0 + 1"), ... (1, 2, 3) ... ) - ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 + ╭ a=288x0 b=16x0+16x1+16x2 -> C=16x0+16x1 │ []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16 │ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16 ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16 From 3962ce09b794a72974a0d119c2872f9645f8486d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 10:54:51 +0200 Subject: [PATCH 10/14] CI: add Python 3.14 to cuequivariance test matrix Skip numpy 1.26 downgrade test on 3.14 since numpy 1.26 does not support Python 3.14. Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/tests.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8aa7fee..a81eb73 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,28 +14,30 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.12"] + python-version: ["3.10", "3.12", "3.14"] steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: ${{ matrix.python-version }} - + - uses: ./.github/actions/setup-cuequivariance with: install-graphviz: 'true' - + - name: Test with pytest run: | pytest --doctest-modules -x -m "not slow" cuequivariance - + - name: Downgrade numpy + if: matrix.python-version != '3.14' run: | python -m uv pip install -U "numpy==1.26.*" - + - name: Test with pytest (numpy 1.26) + if: matrix.python-version != '3.14' run: | pytest --doctest-modules -x -m "not slow" cuequivariance From cd4d0568dd43803740495771f62346948de34ca2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 11:09:04 +0200 Subject: [PATCH 11/14] fix --- cuequivariance_jax/cuequivariance_jax/nnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cuequivariance_jax/cuequivariance_jax/nnx.py b/cuequivariance_jax/cuequivariance_jax/nnx.py index a8956a9..4716961 100644 --- a/cuequivariance_jax/cuequivariance_jax/nnx.py +++ b/cuequivariance_jax/cuequivariance_jax/nnx.py @@ -259,6 +259,7 @@ def __call__( # Convert dict (batch, mul, ir.dim) -> ir_mul flat order x_ir_mul = jax.tree.map(lambda v: rearrange(v, "... m i -> ... i m"), x) x_flat = ir_dict.dict_to_flat(self.irreps_in, x_ir_mul) + x_flat = x_flat.astype(self.w[...].dtype) num_elements = x_flat.shape[0] [y_flat] = segmented_polynomial( From e44a55778f8a7d4db976e8b5787fa60b48ecff17 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 12:05:21 +0200 Subject: [PATCH 12/14] 3.13 --- .github/workflows/tests.yml | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a81eb73..97c3dd6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.12", "3.14"] + python-version: ["3.10", "3.12", "3.13"] steps: - uses: actions/checkout@v4 @@ -32,12 +32,10 @@ jobs: pytest --doctest-modules -x -m "not slow" cuequivariance - name: Downgrade numpy - if: matrix.python-version != '3.14' run: | python -m uv pip install -U "numpy==1.26.*" - name: Test with pytest (numpy 1.26) - if: matrix.python-version != '3.14' run: | pytest --doctest-modules -x -m "not slow" cuequivariance @@ -51,13 +49,13 @@ jobs: steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: ${{ matrix.python-version }} - + - uses: ./.github/actions/setup-cuequivariance-jax - + - name: Test with pytest run: | XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax @@ -68,17 +66,17 @@ jobs: steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: "3.12" - + - name: Install without flax run: | python -m uv pip install -U jax python -m uv pip install -e ./cuequivariance python -m uv pip install -e ./cuequivariance_jax - + - name: Verify import without flax run: | python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" @@ -95,13 +93,13 @@ jobs: steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: ${{ matrix.python-version }} - + - uses: ./.github/actions/setup-cuequivariance-torch - + - name: Test with pytest run: | pytest --doctest-modules -x -m "not slow" cuequivariance_torch From afa9cac86c21e38b7468ed47a39b37bf18f46f09 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 13:54:47 +0200 Subject: [PATCH 13/14] update skills --- cuequivariance/cuequivariance/SKILL.md | 102 +++++++++-- .../cuequivariance_jax/SKILL.md | 169 +++++++----------- .../cuequivariance_torch/SKILL.md | 16 +- 3 files changed, 168 insertions(+), 119 deletions(-) diff --git a/cuequivariance/cuequivariance/SKILL.md b/cuequivariance/cuequivariance/SKILL.md index a6325f7..537ad65 100644 --- a/cuequivariance/cuequivariance/SKILL.md +++ b/cuequivariance/cuequivariance/SKILL.md @@ -1,6 +1,6 @@ --- name: cuequivariance -description: Define custom groups (Irrep subclasses), build segmented tensor products with CG coefficients, create equivariant polynomials, and use built-in descriptors (linear, tensor products, spherical harmonics). Use when working with cuequivariance group theory, irreps, or segmented polynomials. +description: Define custom groups (Irrep subclasses), build segmented tensor products with CG coefficients, create equivariant polynomials and IrDictPolynomials, and use built-in descriptors (linear, tensor products, spherical harmonics). Use when working with cuequivariance group theory, irreps, or segmented polynomials. --- # cuequivariance: Groups, Irreps, and Segmented Polynomials @@ -10,7 +10,9 @@ description: Define custom groups (Irrep subclasses), build segmented tensor pro `cuequivariance` (imported as `cue`) provides two core abstractions: 1. **Group theory**: `Irrep` subclasses define irreducible representations of Lie groups (SO3, O3, SU2, or custom). `Irreps` manages collections with multiplicities. -2. **Segmented polynomials**: `SegmentedTensorProduct` describes tensor contractions over segments of varying shape, linked by `Path` objects carrying Clebsch-Gordan coefficients. `SegmentedPolynomial` wraps multiple STPs into a polynomial with named inputs/outputs. `EquivariantPolynomial` attaches group representations to each operand. +2. **Segmented polynomials**: `SegmentedTensorProduct` describes tensor contractions over segments of varying shape, linked by `Path` objects carrying Clebsch-Gordan coefficients. `SegmentedPolynomial` wraps multiple STPs into a polynomial with named inputs/outputs. Two higher-level wrappers attach group representations: + - `EquivariantPolynomial` — dense operands with `IrrepsAndLayout` metadata + - `IrDictPolynomial` — operands already split by irrep, with per-group `Irreps` metadata for the `dict[Irrep, Array]` workflow ## Defining a custom group @@ -122,8 +124,8 @@ for mul, ir in irreps: `IrrepsLayout` controls memory order within each `(mul, ir)` block: -- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used by all descriptors** -- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **used by nnx dict[Irrep, Array]** +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` — **used by all descriptors and ir_dict internally** +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` — **used by nnx `dict[Irrep, Array]` and PyTorch** `IrrepsAndLayout` combines irreps with a layout into a `Rep`: @@ -181,9 +183,14 @@ for cg in cue.clebsch_gordan(ir1, ir2, ir3): d.add_path((mul1, mul2, mul3), seg_in1, seg_in2, seg_out, c=cg) ``` -## Using descriptors (high-level API) +## Descriptors -All descriptors return `cue.EquivariantPolynomial`: +All descriptors come in two variants: + +- **Original** — returns `EquivariantPolynomial` with dense operands +- **`_ir_dict`** — returns `IrDictPolynomial` with operands already split by irrep + +### EquivariantPolynomial descriptors ```python # Fully connected tensor product (all input-output irrep combinations) @@ -199,12 +206,17 @@ e = cue.descriptors.channelwise_tensor_product( cue.Irreps("SO3", "0 + 1"), simplify_irreps3=True, ) +# Full (weightless) tensor product +e = cue.descriptors.full_tensor_product( + cue.Irreps("SO3", "2x0 + 1x1"), cue.Irreps("SO3", "0 + 1"), +) + # Elementwise tensor product (paired channels) e = cue.descriptors.elementwise_tensor_product( cue.Irreps("SO3", "4x0 + 4x1"), cue.Irreps("SO3", "4x0 + 4x1"), ) -# Linear equivariant map (no second input, just weight x input) +# Linear equivariant map (weight x input) e = cue.descriptors.linear( cue.Irreps("SO3", "4x0 + 2x1"), cue.Irreps("SO3", "3x0 + 5x1"), @@ -217,10 +229,80 @@ e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) e = cue.descriptors.symmetric_contraction( 64 * cue.Irreps("SO3", "0 + 1 + 2"), 64 * cue.Irreps("SO3", "0 + 1"), - [0, 1, 2, 3], + (1, 2, 3), ) ``` +### IrDictPolynomial descriptors + +Each `_ir_dict` variant returns an `IrDictPolynomial` whose polynomial is already split by irrep. The `input_irreps` and `output_irreps` tuples describe the operand groups. + +```python +# Channelwise tensor product +desc = cue.descriptors.channelwise_tensor_product_ir_dict( + 64 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) +# desc.polynomial — SegmentedPolynomial, already split by irrep +# desc.input_irreps — (weight_irreps, irreps1, irreps2) +# desc.output_irreps — (irreps_out,) + +# Fully connected tensor product +desc = cue.descriptors.fully_connected_tensor_product_ir_dict(irreps1, irreps2, irreps3) + +# Full (weightless) tensor product +desc = cue.descriptors.full_tensor_product_ir_dict(irreps1, irreps2) + +# Elementwise tensor product +desc = cue.descriptors.elementwise_tensor_product_ir_dict(irreps1, irreps2) + +# Linear +desc = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + +# Spherical harmonics +desc = cue.descriptors.spherical_harmonics_ir_dict(cue.O3(1, -1), [0, 1, 2, 3]) + +# Symmetric contraction +desc = cue.descriptors.symmetric_contraction_ir_dict(irreps_in, irreps_out, (1, 2, 3)) +``` + +### IrDictPolynomial + +`IrDictPolynomial` pairs a `SegmentedPolynomial` (already split by irrep) with the `Irreps` that describe each operand group. + +```python +desc = cue.descriptors.channelwise_tensor_product_ir_dict( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) + +desc.polynomial # SegmentedPolynomial — each operand is one (mul, ir) block +desc.input_irreps # (weight_irreps, irreps1, irreps2) +desc.output_irreps # (irreps_out,) + +# Scale coefficients +scaled_poly = desc.polynomial * 0.5 + +# Access individual operand info +for i, op in enumerate(desc.polynomial.inputs): + print(f"Input {i}: size={op.size}, num_segments={op.num_segments}") +``` + +Contract: for each `(mul, ir)` block in `input_irreps` / `output_irreps`, the corresponding polynomial operand has size `mul * ir.dim`. + +### split_polynomial_by_irreps + +The low-level function underlying `_ir_dict` descriptors. Splits one polynomial operand at irrep boundaries: + +```python +poly = e.polynomial # from an EquivariantPolynomial +poly = cue.split_polynomial_by_irreps(poly, 2, irreps_sh) # split input 2 +poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) # split input 1 +poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) # split output +``` + ### EquivariantPolynomial key methods ```python @@ -318,7 +400,6 @@ y = np.random.randn(ep.inputs[2].dim) | Component | Path | |-----------|------| | `Irrep` base class | `cuequivariance/group_theory/representations/irrep.py` | -| `Rep` base class | `cuequivariance/group_theory/representations/rep.py` | | `SO3` | `cuequivariance/group_theory/representations/irrep_so3.py` | | `O3` | `cuequivariance/group_theory/representations/irrep_o3.py` | | `SU2` | `cuequivariance/group_theory/representations/irrep_su2.py` | @@ -328,7 +409,8 @@ y = np.random.randn(ep.inputs[2].dim) | `SegmentedTensorProduct` | `cuequivariance/segmented_polynomials/segmented_tensor_product.py` | | `SegmentedPolynomial` | `cuequivariance/segmented_polynomials/segmented_polynomial.py` | | `EquivariantPolynomial` | `cuequivariance/group_theory/equivariant_polynomial.py` | +| `IrDictPolynomial` | `cuequivariance/group_theory/ir_dict_polynomial.py` | | Descriptors | `cuequivariance/group_theory/descriptors/` | -| `fully_connected_tensor_product` etc. | `cuequivariance/group_theory/descriptors/irreps_tp.py` | +| Tensor product descriptors | `cuequivariance/group_theory/descriptors/irreps_tp.py` | | `spherical_harmonics` | `cuequivariance/group_theory/descriptors/spherical_harmonics_.py` | | `symmetric_contraction` | `cuequivariance/group_theory/descriptors/symmetric_contractions.py` | diff --git a/cuequivariance_jax/cuequivariance_jax/SKILL.md b/cuequivariance_jax/cuequivariance_jax/SKILL.md index d61d2d7..d69fa1a 100644 --- a/cuequivariance_jax/cuequivariance_jax/SKILL.md +++ b/cuequivariance_jax/cuequivariance_jax/SKILL.md @@ -1,6 +1,6 @@ --- name: cuequivariance-jax -description: Execute equivariant polynomials in JAX using segmented_polynomial (naive/uniform_1d), equivariant_polynomial with RepArray, ir_dict with dict[Irrep, Array], and Flax NNX layers (IrrepsLinear, SphericalHarmonics). Use when writing JAX code with cuequivariance. +description: Execute equivariant polynomials in JAX using segmented_polynomial (naive/uniform_1d), the ir_dict workflow with IrDictPolynomial and dict[Irrep, Array], and Flax NNX layers (IrrepsLinear, SphericalHarmonics, IrrepsIndexedLinear). Use when writing JAX code with cuequivariance. --- # cuequivariance_jax: Executing Equivariant Polynomials in JAX @@ -9,11 +9,11 @@ description: Execute equivariant polynomials in JAX using segmented_polynomial ( `cuequivariance_jax` (imported as `cuex`) executes `cuequivariance` polynomials on GPU via JAX. It provides: -1. **Core primitive**: `cuex.segmented_polynomial()` -- JAX primitive with full AD/vmap/JIT support +1. **Core primitive**: `cuex.segmented_polynomial()` — JAX primitive with full AD/vmap/JIT support 2. **Two data representations** (both built on `segmented_polynomial`): - - `cuex.equivariant_polynomial()` + `RepArray` -- the original interface, a single contiguous array with representation metadata - - `cuex.ir_dict` module -- `dict[Irrep, Array]` interface, conceptually simpler, works naturally with `jax.tree` operations -3. **NNX layers**: `cuex.nnx` module -- Flax NNX `Module` wrappers using `dict[Irrep, Array]` + - `cuex.equivariant_polynomial()` + `RepArray` — the original interface, a single contiguous array with representation metadata + - `cuex.ir_dict` module — `dict[Irrep, Array]` interface, uses `IrDictPolynomial` descriptors, works naturally with `jax.tree` +3. **NNX layers**: `cuex.nnx` module — Flax NNX `Module` wrappers using `dict[Irrep, Array]` ## Execution methods @@ -65,20 +65,8 @@ y = jax.random.normal(key, (batch, poly.inputs[2].size)) # batched input 2 Inputs can have any number of batch axes (everything before the last axis). Standard NumPy broadcasting applies: each batch axis is either size-1 or a common size. Inputs with fewer batch dimensions are implicitly prepended with size-1 axes: ```python -# 2 batch axes with size-1 broadcasting -w = jnp.ones((1, 10, poly.inputs[0].size)) # shared across axis 0 -x = jnp.ones((5, 10, poly.inputs[1].size)) # 5 along axis 0 -y = jnp.ones((5, 1, poly.inputs[2].size)) # shared across axis 1 - -[out] = cuex.segmented_polynomial( - poly, [w, x, y], - [jax.ShapeDtypeStruct((5, 10, poly.outputs[0].size), jnp.float32)], - method="uniform_1d", -) -# out.shape == (5, 10, ...) - # Fewer batch dims: weights with no batch axis broadcast across all -w = jnp.ones((poly.inputs[0].size,)) # 0 batch axes -> prepended as (1, 1, ...) +w = jnp.ones((poly.inputs[0].size,)) # 0 batch axes -> broadcasts x = jnp.ones((5, 10, poly.inputs[1].size)) y = jnp.ones((5, 10, poly.inputs[2].size)) @@ -91,13 +79,9 @@ y = jnp.ones((5, 10, poly.inputs[2].size)) ### Indexing (gather/scatter) -Index arrays provide gather (for inputs) and scatter (for outputs). One index per operand (inputs + outputs), `None` means no indexing. Index arrays decouple input/output batch shapes -- the output shape is determined by the index ranges, not by the input shapes: +Index arrays provide gather (for inputs) and scatter (for outputs). One index per operand (inputs + outputs), `None` means no indexing: ```python -a = jnp.ones((1, 50, poly.inputs[0].size)) -b = jnp.ones((10, 50, poly.inputs[1].size)) -c = jnp.ones((100, 1, poly.inputs[2].size)) - i = jax.random.randint(key, (100, 50), 0, 10) # gather b along axis 0 j1 = jax.random.randint(key, (100, 50), 0, 11) # scatter output axis 0 j2 = jax.random.randint(key, (100, 1), 0, 12) # scatter output axis 1 @@ -108,12 +92,11 @@ j2 = jax.random.randint(key, (100, 1), 0, 12) # scatter output axis 1 indices=[None, np.s_[i, :], None, np.s_[j1, j2]], method="uniform_1d", ) -# out.shape == (11, 12, ...) -- determined by index ranges, not input shapes ``` ### Gradients -Fully differentiable -- supports `jax.grad`, `jax.jacobian`, `jax.jvp`, `jax.vmap`: +Fully differentiable — supports `jax.grad`, `jax.jacobian`, `jax.jvp`, `jax.vmap`: ```python def loss(w, x, y): @@ -127,55 +110,28 @@ def loss(w, x, y): grad_w = jax.grad(loss, 0)(w, x, y) ``` -## RepArray interface: equivariant_polynomial - -The original interface. Wraps `segmented_polynomial` with `RepArray` -- a single contiguous array with representation metadata: - -```python -e = cue.descriptors.fully_connected_tensor_product( - 4 * cue.Irreps("SO3", "0 + 1"), - cue.Irreps("SO3", "0 + 1"), - 4 * cue.Irreps("SO3", "0 + 1"), -) - -inputs = [ - cuex.randn(jax.random.key(i), rep, (batch,), jnp.float32) - for i, rep in enumerate(e.inputs) -] - -# Returns a RepArray with representation metadata -out = cuex.equivariant_polynomial(e, inputs, method="naive") -out.array # the raw jax.Array -out.reps # dict mapping axes to Rep objects -``` - ## ir_dict interface -An alternative to `RepArray`. Uses `dict[Irrep, Array]` where each value has shape `(..., multiplicity, irrep_dim)`. Conceptually simpler: works naturally with `jax.tree` operations and is the standard representation for NNX layers. +Uses `dict[Irrep, Array]` where each value has shape `(..., multiplicity, irrep_dim)`. This is the standard representation for NNX layers and works naturally with `jax.tree` operations. -### Preparing a polynomial for ir_dict +### Getting an ir_dict-ready polynomial -Descriptors produce `EquivariantPolynomial` with dense operands. To use `ir_dict`, split operands by irrep: +Use `_ir_dict` descriptor variants, which return `IrDictPolynomial` with the polynomial already split by irrep: ```python -e = cue.descriptors.channelwise_tensor_product( +desc = cue.descriptors.channelwise_tensor_product_ir_dict( 32 * cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), - simplify_irreps3=True, ) -# Split irreps-typed operands into per-irrep pieces -# Order matters: split inner operands first to preserve operand indices -poly = ( - e.split_operand_by_irrep(2) # split input 2 - .split_operand_by_irrep(1) # split input 1 - .split_operand_by_irrep(-1) # split output - .polynomial -) -# After split: each operand has a single irrep type, mapping naturally to dict[Irrep, Array] +poly = desc.polynomial # SegmentedPolynomial, already split by irrep +weight_irreps, irreps1, irreps2 = desc.input_irreps +(irreps_out,) = desc.output_irreps # tuple unpacking to get the single output group ``` +Each polynomial operand corresponds to exactly one `(mul, ir)` block. The `input_irreps` and `output_irreps` tuples describe how operands group into logical operand groups (weights, node features, spherical harmonics, output). + ### Executing with segmented_polynomial_uniform_1d ```python @@ -194,7 +150,7 @@ node_feats = { } x1 = jax.tree.map(lambda v: rearrange(v, "n m i -> n i m"), node_feats) -# Spherical harmonics: (edges, ir.dim) -- no multiplicity dimension +# Spherical harmonics: (edges, ir.dim) — no multiplicity dimension sph = { cue.SO3(0): jnp.ones((num_edges, 1)), cue.SO3(1): jnp.ones((num_edges, 3)), @@ -203,7 +159,6 @@ sph = { # Build output template senders = jax.random.randint(key, (num_edges,), 0, num_nodes) receivers = jax.random.randint(key, (num_edges,), 0, num_nodes) -irreps_out = e.outputs[0].irreps out_template = { ir: jax.ShapeDtypeStruct( (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype @@ -242,6 +197,28 @@ z = cuex.ir_dict.irreps_zeros_like(x) template = cuex.ir_dict.mul_ir_dict(irreps, jax.ShapeDtypeStruct(shape, dtype)) ``` +## RepArray interface: equivariant_polynomial + +The original interface. Wraps `segmented_polynomial` with `RepArray` — a single contiguous array with representation metadata: + +```python +e = cue.descriptors.fully_connected_tensor_product( + 4 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + 4 * cue.Irreps("SO3", "0 + 1"), +) + +inputs = [ + cuex.randn(jax.random.key(i), rep, (batch,), jnp.float32) + for i, rep in enumerate(e.inputs) +] + +# Returns a RepArray with representation metadata +out = cuex.equivariant_polynomial(e, inputs, method="naive") +out.array # the raw jax.Array +out.reps # dict mapping axes to Rep objects +``` + ## NNX layers ### IrrepsLinear @@ -273,6 +250,8 @@ Implementation uses `jnp.einsum("uv,...ui->...vi", w, x[ir])` per irrep with `1/ ### SphericalHarmonics +Uses `spherical_harmonics_ir_dict` internally for the `dict[Irrep, Array]` output: + ```python sh = cuex.nnx.SphericalHarmonics(max_degree=3, eps=0.0) @@ -338,43 +317,25 @@ For `equivariant_polynomial()` (RepArray interface): ```python e = cue.descriptors.channelwise_tensor_product(...) e = e.squeeze_modes().flatten_coefficient_modes() -# If still >1 mode: e = e.flatten_modes(["u", "w"]) out = cuex.equivariant_polynomial(e, inputs, method="uniform_1d") ``` -For `ir_dict` (dict[Irrep, Array] interface): +For `ir_dict` (dict[Irrep, Array] interface), use `_ir_dict` descriptors directly: ```python -e = cue.descriptors.channelwise_tensor_product(..., simplify_irreps3=True) -poly = ( - e.split_operand_by_irrep(2) # split input 2 - .split_operand_by_irrep(1) # split input 1 - .split_operand_by_irrep(-1) # split output - .polynomial +desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) +poly = desc.polynomial # Each operand has a single irrep type -> maps naturally to dict[Irrep, Array] ``` -### Why split_operand_by_irrep matters +### Why splitting by irrep matters -Without splitting, a dense operand like `32x0+32x1` requires all irreps packed into a single contiguous buffer. After `split_operand_by_irrep`, each irrep gets its own separate buffer passed to the CUDA kernel via FFI. The buffers no longer need to be contiguous with each other. +Without splitting, a dense operand like `32x0+32x1` requires all irreps packed into a single contiguous buffer. After splitting, each irrep gets its own separate buffer passed to the CUDA kernel via FFI. The buffers no longer need to be contiguous with each other. This is especially useful when the polynomial is preceded or followed by per-irrep linear layers (like `IrrepsLinear`). With split operands, no transpose or copy is needed between the linear layers and the polynomial — the `dict[Irrep, Array]` flows directly through the pipeline. -## RepArray - -Representation-aware JAX array: - -```python -rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) -x = cuex.RepArray(rep, jnp.ones((batch, rep.dim))) -x = cuex.randn(jax.random.key(0), rep, (batch,), jnp.float32) - -x.array # raw jax.Array -x.reps # {axis: Rep} -x.irreps # Irreps (if last axis is IrrepsAndLayout) -``` - ## Complete GNN message-passing example This pattern is used in NequIP, MACE, and similar equivariant GNN models: @@ -382,20 +343,13 @@ This pattern is used in NequIP, MACE, and similar equivariant GNN models: ```python class MessagePassing(nnx.Module): def __init__(self, irreps_in, irreps_sh, irreps_out, epsilon, *, name, dtype, rngs): - e = ( - cue.descriptors.channelwise_tensor_product( - irreps_in, irreps_sh, irreps_out, True - ) - * epsilon - ) - self.weight_numel = e.inputs[0].dim - self.irreps_out = e.outputs[0].irreps - self.poly = ( - e.split_operand_by_irrep(2) - .split_operand_by_irrep(1) - .split_operand_by_irrep(-1) - .polynomial + self.name = name + desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) + (self.irreps_out,) = desc.output_irreps + self.poly = desc.polynomial * epsilon + self.weight_numel = self.poly.inputs[0].size def __call__(self, weights, node_feats, sph, senders, receivers, num_nodes): # weights: (num_edges, weight_numel) @@ -425,6 +379,20 @@ class MessagePassing(nnx.Module): } ``` +## RepArray + +Representation-aware JAX array: + +```python +rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) +x = cuex.RepArray(rep, jnp.ones((batch, rep.dim))) +x = cuex.randn(jax.random.key(0), rep, (batch,), jnp.float32) + +x.array # raw jax.Array +x.reps # {axis: Rep} +x.irreps # Irreps (if last axis is IrrepsAndLayout) +``` + ## Key file locations | Component | Path | @@ -437,6 +405,5 @@ class MessagePassing(nnx.Module): | `ir_dict` module | `cuequivariance_jax/ir_dict.py` | | `nnx` module | `cuequivariance_jax/nnx.py` | | `RepArray` | `cuequivariance_jax/rep_array/rep_array_.py` | -| `Repeats` / utilities | `cuequivariance_jax/segmented_polynomials/utils.py` | | NequIP example | `cuequivariance_jax/examples/nequip_nnx.py` | | MACE example | `cuequivariance_jax/examples/mace_nnx.py` | diff --git a/cuequivariance_torch/cuequivariance_torch/SKILL.md b/cuequivariance_torch/cuequivariance_torch/SKILL.md index 002765a..e2a92e4 100644 --- a/cuequivariance_torch/cuequivariance_torch/SKILL.md +++ b/cuequivariance_torch/cuequivariance_torch/SKILL.md @@ -9,7 +9,7 @@ description: Execute equivariant tensor products in PyTorch using SegmentedPolyn `cuequivariance_torch` (imported as `cuet`) executes `cuequivariance` polynomials on GPU via PyTorch. It provides: -1. **Core primitive**: `cuet.SegmentedPolynomial` -- `torch.nn.Module` with multiple CUDA backends +1. **Core primitive**: `cuet.SegmentedPolynomial` — `torch.nn.Module` with multiple CUDA backends 2. **High-level operations** (`torch.nn.Module`): `ChannelWiseTensorProduct`, `FullyConnectedTensorProduct`, `Linear`, `SymmetricContraction`, `SphericalHarmonics`, `Rotation`, `Inversion` 3. **Layers**: `cuet.layers.BatchNorm`, `cuet.layers.FullyConnectedTensorProductConv` (message passing) 4. **Utilities**: `triangle_attention`, `triangle_multiplicative_update`, `attention_pair_bias` (AlphaFold2-style) @@ -96,8 +96,8 @@ All operations are `torch.nn.Module` subclasses. They wrap `SegmentedPolynomial` `IrrepsLayout` controls memory order within each `(mul, ir)` block: -- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **default, compatible with e3nn** -- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used internally by descriptors** +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` — **default, compatible with e3nn** +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` — **used internally by descriptors** Operations accept `layout` (applies to all), or per-operand `layout_in1`, `layout_in2`, `layout_out`. @@ -150,7 +150,7 @@ tp = cuet.FullyConnectedTensorProduct( cue.Irreps("O3", "0e + 1o"), # irreps_in2 cue.Irreps("O3", "4x0e + 4x1o"), # irreps_out layout=cue.mul_ir, - internal_weights=True, # store weights as parameters + internal_weights=True, device="cuda", ) @@ -195,7 +195,7 @@ MACE-style symmetric contraction with element-indexed weights. sc = cuet.SymmetricContraction( cue.Irreps("O3", "32x0e + 32x1o"), # irreps_in (uniform mul required) cue.Irreps("O3", "32x0e"), # irreps_out (uniform mul required) - contraction_degree=3, # polynomial degree + contraction_degree=3, num_elements=95, # number of chemical elements layout=cue.ir_mul, dtype=torch.float32, @@ -213,8 +213,8 @@ Default method: `"uniform_1d"` if segments are uniform, else `"naive"`. ```python sh = cuet.SphericalHarmonics( - ls=[0, 1, 2, 3], # degrees - normalize=True, # normalize input vectors + ls=[0, 1, 2, 3], + normalize=True, device="cuda", ) @@ -283,7 +283,7 @@ conv = cuet.layers.FullyConnectedTensorProductConv( in_irreps=cue.Irreps("O3", "4x0e + 4x1o"), sh_irreps=cue.Irreps("O3", "0e + 1o"), out_irreps=cue.Irreps("O3", "4x0e + 4x1o"), - mlp_channels=[16, 32, 32], # MLP for path weights + mlp_channels=[16, 32, 32], mlp_activation=torch.nn.ReLU(), batch_norm=True, layout=cue.ir_mul, From 31e2b6ccd93d9e45cc2ee1d82237034455cd0c4d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 14 Apr 2026 14:21:56 +0200 Subject: [PATCH 14/14] comment --- cuequivariance/cuequivariance/etc/linalg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cuequivariance/cuequivariance/etc/linalg.py b/cuequivariance/cuequivariance/etc/linalg.py index 3fdb33b..54bf357 100644 --- a/cuequivariance/cuequivariance/etc/linalg.py +++ b/cuequivariance/cuequivariance/etc/linalg.py @@ -107,6 +107,10 @@ def limit_denominator(n, d, max_denominator: int): n1, d1 = p0 + k * p1, q0 + k * q1 n2, d2 = p1, q1 with np.errstate(over="ignore"): + # The intermediate products (n2*d0, n0*d2) overflow int64 (~2^102), but the + # overflow is benign: their difference is bounded by d0 < 2^62 (fits in int64), + # and two's complement subtraction recovers it exactly. The final product + # d1*(difference) is also bounded by d0 < 2^63. mask = np.abs(d1 * (n2 * d0 - n0 * d2)) <= np.abs(d2 * (n1 * d0 - n0 * d1)) return np.where( d0 <= max_denominator,