diff --git a/cuequivariance/cuequivariance/SKILL.md b/cuequivariance/cuequivariance/SKILL.md new file mode 100644 index 00000000..a6325f78 --- /dev/null +++ b/cuequivariance/cuequivariance/SKILL.md @@ -0,0 +1,334 @@ +--- +name: cuequivariance +description: Define custom groups (Irrep subclasses), build segmented tensor products with CG coefficients, create equivariant polynomials, and use built-in descriptors (linear, tensor products, spherical harmonics). Use when working with cuequivariance group theory, irreps, or segmented polynomials. +--- + +# cuequivariance: Groups, Irreps, and Segmented Polynomials + +## Overview + +`cuequivariance` (imported as `cue`) provides two core abstractions: + +1. **Group theory**: `Irrep` subclasses define irreducible representations of Lie groups (SO3, O3, SU2, or custom). `Irreps` manages collections with multiplicities. +2. **Segmented polynomials**: `SegmentedTensorProduct` describes tensor contractions over segments of varying shape, linked by `Path` objects carrying Clebsch-Gordan coefficients. `SegmentedPolynomial` wraps multiple STPs into a polynomial with named inputs/outputs. `EquivariantPolynomial` attaches group representations to each operand. + +## Defining a custom group + +Subclass `cue.Irrep` (a frozen dataclass) and implement: + +```python +from __future__ import annotations +import dataclasses +import re +from typing import Iterator +import numpy as np +import cuequivariance as cue + +@dataclasses.dataclass(frozen=True) +class Z2(cue.Irrep): + odd: bool # dataclass field -- required for correct __eq__ and __hash__ + + # No __init__ needed -- @dataclass(frozen=True) generates it: Z2(odd=True) + + @classmethod + def regexp_pattern(cls) -> re.Pattern: + # Pattern whose first group is passed to from_string + return re.compile(r"(odd|even)") + + @classmethod + def from_string(cls, string: str) -> Z2: + return cls(odd=string == "odd") + + def __repr__(rep: Z2) -> str: + return "odd" if rep.odd else "even" + + def __mul__(rep1: Z2, rep2: Z2) -> Iterator[Z2]: + # Selection rule: which irreps appear in the tensor product rep1 x rep2 + return [Z2(odd=rep1.odd ^ rep2.odd)] + + @classmethod + def clebsch_gordan(cls, rep1: Z2, rep2: Z2, rep3: Z2) -> np.ndarray: + # Shape: (num_paths, rep1.dim, rep2.dim, rep3.dim) + if rep3 in rep1 * rep2: + return np.array([[[[1]]]]) + else: + return np.zeros((0, 1, 1, 1)) + + @property + def dim(rep: Z2) -> int: + return 1 + + def __lt__(rep1: Z2, rep2: Z2) -> bool: + # Ordering for sorting; dimension is compared first by the base class + return rep1.odd < rep2.odd + + @classmethod + def iterator(cls) -> Iterator[Z2]: + # Must yield trivial irrep first + for odd in [False, True]: + yield Z2(odd=odd) + + def discrete_generators(rep: Z2) -> np.ndarray: + # Shape: (num_generators, dim, dim) + if rep.odd: + return -np.ones((1, 1, 1)) + else: + return np.ones((1, 1, 1)) + + def continuous_generators(rep: Z2) -> np.ndarray: + # Shape: (lie_dim, dim, dim) -- Z2 is discrete, so lie_dim=0 + return np.zeros((0, rep.dim, rep.dim)) + + def algebra(self) -> np.ndarray: + # Shape: (lie_dim, lie_dim, lie_dim) -- structure constants [X_i, X_j] = A_ijk X_k + return np.zeros((0, 0, 0)) + + +# Usage: +irreps = cue.Irreps(Z2, "3x odd + 2x even") # dim=5 +``` + +### Required methods summary + +| Method | Returns | Purpose | +|--------|---------|---------| +| `regexp_pattern()` | `re.Pattern` | Parse string like `"1"`, `"0e"`, `"odd"` | +| `from_string(s)` | `Irrep` | Construct from matched string | +| `__repr__` | `str` | Canonical string form | +| `__mul__(a, b)` | `Iterator[Irrep]` | Selection rule for tensor product | +| `clebsch_gordan(a, b, c)` | `ndarray (n, d1, d2, d3)` | CG coefficients | +| `dim` (property) | `int` | Dimension of representation | +| `__lt__(a, b)` | `bool` | Ordering (dimension first, then custom) | +| `iterator()` | `Iterator[Irrep]` | All irreps, trivial first | +| `continuous_generators()` | `ndarray (lie_dim, dim, dim)` | Lie algebra generators | +| `discrete_generators()` | `ndarray (n, dim, dim)` | Finite symmetry generators | +| `algebra()` | `ndarray (lie_dim, lie_dim, lie_dim)` | Structure constants | + +### Built-in groups + +- **`cue.SO3(l)`**: 3D rotations. `l` is a non-negative integer. `dim = 2l+1`. String: `"0"`, `"1"`, `"2"`. +- **`cue.O3(l, p)`**: 3D rotations + parity. `p=+1` (even) or `p=-1` (odd). String: `"0e"`, `"1o"`, `"2e"`. +- **`cue.SU2(j)`**: Spin group. `j` is a non-negative half-integer. String: `"0"`, `"1/2"`, `"1"`. + +## Irreps and layout + +```python +irreps = cue.Irreps("SO3", "16x0 + 4x1 + 2x2") # 16 scalars, 4 vectors, 2 rank-2 +irreps.dim # 16*1 + 4*3 + 2*5 = 38 + +for mul, ir in irreps: + print(mul, ir, ir.dim) # 16 0 1, then 4 1 3, then 2 2 5 +``` + +`IrrepsLayout` controls memory order within each `(mul, ir)` block: + +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used by all descriptors** +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **used by nnx dict[Irrep, Array]** + +`IrrepsAndLayout` combines irreps with a layout into a `Rep`: + +```python +rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) +rep.dim # 10 +``` + +## Building a SegmentedTensorProduct from scratch + +The subscripts string uses Einstein notation. Operands are comma-separated, coefficient modes follow `+`. + +```python +# Matrix-vector multiply: y_i = sum_j M_ij * x_j +d = cue.SegmentedTensorProduct.from_subscripts("ij,j,i") +d.add_segment(0, (3, 4)) # operand 0: matrix segment of shape (3, 4) +d.add_segment(1, (4,)) # operand 1: vector of size 4 +d.add_segment(2, (3,)) # operand 2: output vector of size 3 +d.add_path(0, 0, 0, c=1.0) # link segments 0,0,0 with coefficient=1.0 + +poly = cue.SegmentedPolynomial.eval_last_operand(d) # last operand becomes output +[y] = poly(M_flat, x) # numpy evaluation +``` + +### Multi-segment STP (how descriptors work internally) + +Descriptors build STPs with multiple segments per operand. Each segment corresponds to an irrep block: + +```python +# Linear equivariant map: output[iv] = sum_u weight[uv] * input[iu] +d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,iv") + +# Segment for l=1: ir_dim=3, mul_in=2, mul_out=5 +s_in_0 = d.add_segment(1, (3, 2)) # input block +s_out_0 = d.add_segment(2, (3, 5)) # output block +d.add_path((2, 5), s_in_0, s_out_0, c=1.0) + +# Segment for l=0: ir_dim=1, mul_in=4, mul_out=3 +s_in_1 = d.add_segment(1, (1, 4)) +s_out_1 = d.add_segment(2, (1, 3)) +d.add_path((4, 3), s_in_1, s_out_1, c=1.0) +``` + +### Weights operand + +For weighted tensor products (subscript starting with `uvw` or `uv`), the first operand is always weights. The weight segment shape is `(mul_1, mul_2, ...)` matching the multiplicity modes. The weights operand gets `new_scalars()` irreps since weights are invariant. + +### CG coefficients as path coefficients + +```python +d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") +# For each pair of input irreps and each output irrep in the selection rule: +for cg in cue.clebsch_gordan(ir1, ir2, ir3): + # cg has shape (ir1.dim, ir2.dim, ir3.dim) + d.add_path((mul1, mul2, mul3), seg_in1, seg_in2, seg_out, c=cg) +``` + +## Using descriptors (high-level API) + +All descriptors return `cue.EquivariantPolynomial`: + +```python +# Fully connected tensor product (all input-output irrep combinations) +e = cue.descriptors.fully_connected_tensor_product( + 16 * cue.Irreps("SO3", "0 + 1 + 2"), + 16 * cue.Irreps("SO3", "0 + 1 + 2"), + 16 * cue.Irreps("SO3", "0 + 1 + 2"), +) + +# Channelwise tensor product (same-channel only, sparse) +e = cue.descriptors.channelwise_tensor_product( + 64 * cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), simplify_irreps3=True, +) + +# Elementwise tensor product (paired channels) +e = cue.descriptors.elementwise_tensor_product( + cue.Irreps("SO3", "4x0 + 4x1"), cue.Irreps("SO3", "4x0 + 4x1"), +) + +# Linear equivariant map (no second input, just weight x input) +e = cue.descriptors.linear( + cue.Irreps("SO3", "4x0 + 2x1"), + cue.Irreps("SO3", "3x0 + 5x1"), +) + +# Spherical harmonics +e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) + +# Symmetric contraction (MACE-style) +e = cue.descriptors.symmetric_contraction( + 64 * cue.Irreps("SO3", "0 + 1 + 2"), + 64 * cue.Irreps("SO3", "0 + 1"), + [0, 1, 2, 3], +) +``` + +### EquivariantPolynomial key methods + +```python +e.inputs # tuple of Rep (group representations for each input) +e.outputs # tuple of Rep +e.polynomial # the underlying SegmentedPolynomial + +# Numpy evaluation +[out] = e(weights, input1, input2) + +# Preparing for uniform_1d execution (see cuequivariance_jax SKILL.md) +e_ready = e.squeeze_modes().flatten_coefficient_modes() + +# Split an operand into per-irrep pieces (for ir_dict interface) +e_split = e.split_operand_by_irrep(1).split_operand_by_irrep(-1) + +# Scale all coefficients +e_scaled = e * 0.5 + +# Fuse compatible STPs +e_fused = e.fuse_stps() +``` + +### normalize_paths_for_operand + +Called internally by descriptors. Normalizes path coefficients so that a random input produces unit-variance output for the specified operand. Critical for numerical stability. + +## SegmentedPolynomial structure + +```python +poly = e.polynomial +poly.num_inputs # number of input operands +poly.num_outputs # number of output operands +poly.inputs # tuple of SegmentedOperand +poly.outputs # tuple of SegmentedOperand +poly.operations # tuple of (Operation, SegmentedTensorProduct) + +# Each operation maps buffers to STP operands +for op, stp in poly.operations: + print(op.buffers) # e.g., (0, 1, 2) means inputs[0], inputs[1] -> outputs[0] + print(stp.subscripts) +``` + +### SegmentedOperand + +```python +operand = poly.inputs[0] +operand.num_segments # how many segments +operand.segments # tuple of shape tuples, e.g., ((3, 4), (1, 2)) +operand.size # total flattened size (sum of products of segment shapes) +operand.ndim # number of dimensions per segment +operand.all_same_segment_shape() # True if all segments have identical shape +operand.segment_shape # the common shape (only if all_same_segment_shape) +``` + +## Custom equivariant polynomial from scratch + +```python +import numpy as np +import cuequivariance as cue + +# Build a fully-connected SO3(1)xSO3(1)->SO3(0) tensor product manually +cg = cue.clebsch_gordan(cue.SO3(1), cue.SO3(1), cue.SO3(0)) # shape (1, 3, 3, 1) + +d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") +d.add_segment(1, (3, 4)) # input1: 4x SO3(1), shape=(ir_dim, mul) +d.add_segment(2, (3, 4)) # input2: 4x SO3(1) +d.add_segment(3, (1, 16)) # output: 16x SO3(0) (4*4 fully connected) + +for c in cg: + d.add_path((4, 4, 16), 0, 0, 0, c=c) + +d = d.normalize_paths_for_operand(-1) + +poly = cue.SegmentedPolynomial.eval_last_operand(d) +ep = cue.EquivariantPolynomial( + [ + cue.IrrepsAndLayout(cue.Irreps("SO3", "4x1").new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(cue.Irreps("SO3", "4x1"), cue.ir_mul), + cue.IrrepsAndLayout(cue.Irreps("SO3", "4x1"), cue.ir_mul), + ], + [cue.IrrepsAndLayout(cue.Irreps("SO3", "16x0"), cue.ir_mul)], + poly, +) + +# Numpy evaluation +w = np.random.randn(ep.inputs[0].dim) +x = np.random.randn(ep.inputs[1].dim) +y = np.random.randn(ep.inputs[2].dim) +[out] = ep(w, x, y) +``` + +## Key file locations + +| Component | Path | +|-----------|------| +| `Irrep` base class | `cuequivariance/group_theory/representations/irrep.py` | +| `Rep` base class | `cuequivariance/group_theory/representations/rep.py` | +| `SO3` | `cuequivariance/group_theory/representations/irrep_so3.py` | +| `O3` | `cuequivariance/group_theory/representations/irrep_o3.py` | +| `SU2` | `cuequivariance/group_theory/representations/irrep_su2.py` | +| `Irreps` | `cuequivariance/group_theory/irreps_array/irreps.py` | +| `IrrepsLayout` | `cuequivariance/group_theory/irreps_array/irreps_layout.py` | +| `IrrepsAndLayout` | `cuequivariance/group_theory/irreps_array/irreps_and_layout.py` | +| `SegmentedTensorProduct` | `cuequivariance/segmented_polynomials/segmented_tensor_product.py` | +| `SegmentedPolynomial` | `cuequivariance/segmented_polynomials/segmented_polynomial.py` | +| `EquivariantPolynomial` | `cuequivariance/group_theory/equivariant_polynomial.py` | +| Descriptors | `cuequivariance/group_theory/descriptors/` | +| `fully_connected_tensor_product` etc. | `cuequivariance/group_theory/descriptors/irreps_tp.py` | +| `spherical_harmonics` | `cuequivariance/group_theory/descriptors/spherical_harmonics_.py` | +| `symmetric_contraction` | `cuequivariance/group_theory/descriptors/symmetric_contractions.py` | diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index dd92a588..19e217ef 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""cuEquivariance - GPU-accelerated equivariant operations for 3D neural networks. + +For AI coding assistants: run `python -m cuequivariance skill` for detailed usage guidance. +""" import importlib.resources __version__ = ( diff --git a/cuequivariance/cuequivariance/__main__.py b/cuequivariance/cuequivariance/__main__.py new file mode 100644 index 00000000..83e8da60 --- /dev/null +++ b/cuequivariance/cuequivariance/__main__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from pathlib import Path + + +def main(): + if len(sys.argv) >= 2 and sys.argv[1] == "skill": + skill_path = Path(__file__).parent / "SKILL.md" + print(skill_path.read_text()) + else: + print("Usage: python -m cuequivariance skill") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/cuequivariance_jax/cuequivariance_jax/SKILL.md b/cuequivariance_jax/cuequivariance_jax/SKILL.md new file mode 100644 index 00000000..d61d2d73 --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/SKILL.md @@ -0,0 +1,442 @@ +--- +name: cuequivariance-jax +description: Execute equivariant polynomials in JAX using segmented_polynomial (naive/uniform_1d), equivariant_polynomial with RepArray, ir_dict with dict[Irrep, Array], and Flax NNX layers (IrrepsLinear, SphericalHarmonics). Use when writing JAX code with cuequivariance. +--- + +# cuequivariance_jax: Executing Equivariant Polynomials in JAX + +## Overview + +`cuequivariance_jax` (imported as `cuex`) executes `cuequivariance` polynomials on GPU via JAX. It provides: + +1. **Core primitive**: `cuex.segmented_polynomial()` -- JAX primitive with full AD/vmap/JIT support +2. **Two data representations** (both built on `segmented_polynomial`): + - `cuex.equivariant_polynomial()` + `RepArray` -- the original interface, a single contiguous array with representation metadata + - `cuex.ir_dict` module -- `dict[Irrep, Array]` interface, conceptually simpler, works naturally with `jax.tree` operations +3. **NNX layers**: `cuex.nnx` module -- Flax NNX `Module` wrappers using `dict[Irrep, Array]` + +## Execution methods + +| Method | Backend | Requirements | +|--------|---------|-------------| +| `"naive"` | Pure JAX | Always works, any platform | +| `"uniform_1d"` | CUDA kernel | GPU, all segments uniform shape within each operand, single mode | +| `"indexed_linear"` | CUDA kernel | GPU, linear operations with `cuex.Repeats` indexing | + +## Core primitive: segmented_polynomial + +```python +import jax +import jax.numpy as jnp +import cuequivariance as cue +import cuequivariance_jax as cuex + +# Build a descriptor +e = cue.descriptors.channelwise_tensor_product( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) +poly = e.polynomial + +batch = 64 +w = jnp.ones((poly.inputs[0].size,)) # weights (shared across batch) +x = jax.random.normal(key, (batch, poly.inputs[1].size)) # batched input 1 +y = jax.random.normal(key, (batch, poly.inputs[2].size)) # batched input 2 + +# Execute with naive method +[out] = cuex.segmented_polynomial( + poly, + [w, x, y], # inputs + [jax.ShapeDtypeStruct((batch, poly.outputs[0].size), jnp.float32)], # output spec + method="naive", +) + +# Execute with uniform_1d (GPU, requires uniform segments) +[out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((batch, poly.outputs[0].size), jnp.float32)], + method="uniform_1d", +) +``` + +### Multiple batch axes with broadcasting + +Inputs can have any number of batch axes (everything before the last axis). Standard NumPy broadcasting applies: each batch axis is either size-1 or a common size. Inputs with fewer batch dimensions are implicitly prepended with size-1 axes: + +```python +# 2 batch axes with size-1 broadcasting +w = jnp.ones((1, 10, poly.inputs[0].size)) # shared across axis 0 +x = jnp.ones((5, 10, poly.inputs[1].size)) # 5 along axis 0 +y = jnp.ones((5, 1, poly.inputs[2].size)) # shared across axis 1 + +[out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((5, 10, poly.outputs[0].size), jnp.float32)], + method="uniform_1d", +) +# out.shape == (5, 10, ...) + +# Fewer batch dims: weights with no batch axis broadcast across all +w = jnp.ones((poly.inputs[0].size,)) # 0 batch axes -> prepended as (1, 1, ...) +x = jnp.ones((5, 10, poly.inputs[1].size)) +y = jnp.ones((5, 10, poly.inputs[2].size)) + +[out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((5, 10, poly.outputs[0].size), jnp.float32)], + method="uniform_1d", +) +``` + +### Indexing (gather/scatter) + +Index arrays provide gather (for inputs) and scatter (for outputs). One index per operand (inputs + outputs), `None` means no indexing. Index arrays decouple input/output batch shapes -- the output shape is determined by the index ranges, not by the input shapes: + +```python +a = jnp.ones((1, 50, poly.inputs[0].size)) +b = jnp.ones((10, 50, poly.inputs[1].size)) +c = jnp.ones((100, 1, poly.inputs[2].size)) + +i = jax.random.randint(key, (100, 50), 0, 10) # gather b along axis 0 +j1 = jax.random.randint(key, (100, 50), 0, 11) # scatter output axis 0 +j2 = jax.random.randint(key, (100, 1), 0, 12) # scatter output axis 1 + +[out] = cuex.segmented_polynomial( + poly, [a, b, c], + [jax.ShapeDtypeStruct((11, 12, poly.outputs[0].size), jnp.float32)], + indices=[None, np.s_[i, :], None, np.s_[j1, j2]], + method="uniform_1d", +) +# out.shape == (11, 12, ...) -- determined by index ranges, not input shapes +``` + +### Gradients + +Fully differentiable -- supports `jax.grad`, `jax.jacobian`, `jax.jvp`, `jax.vmap`: + +```python +def loss(w, x, y): + [out] = cuex.segmented_polynomial( + poly, [w, x, y], + [jax.ShapeDtypeStruct((batch, poly.outputs[0].size), jnp.float32)], + method="naive", + ) + return jnp.sum(out ** 2) + +grad_w = jax.grad(loss, 0)(w, x, y) +``` + +## RepArray interface: equivariant_polynomial + +The original interface. Wraps `segmented_polynomial` with `RepArray` -- a single contiguous array with representation metadata: + +```python +e = cue.descriptors.fully_connected_tensor_product( + 4 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + 4 * cue.Irreps("SO3", "0 + 1"), +) + +inputs = [ + cuex.randn(jax.random.key(i), rep, (batch,), jnp.float32) + for i, rep in enumerate(e.inputs) +] + +# Returns a RepArray with representation metadata +out = cuex.equivariant_polynomial(e, inputs, method="naive") +out.array # the raw jax.Array +out.reps # dict mapping axes to Rep objects +``` + +## ir_dict interface + +An alternative to `RepArray`. Uses `dict[Irrep, Array]` where each value has shape `(..., multiplicity, irrep_dim)`. Conceptually simpler: works naturally with `jax.tree` operations and is the standard representation for NNX layers. + +### Preparing a polynomial for ir_dict + +Descriptors produce `EquivariantPolynomial` with dense operands. To use `ir_dict`, split operands by irrep: + +```python +e = cue.descriptors.channelwise_tensor_product( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + simplify_irreps3=True, +) + +# Split irreps-typed operands into per-irrep pieces +# Order matters: split inner operands first to preserve operand indices +poly = ( + e.split_operand_by_irrep(2) # split input 2 + .split_operand_by_irrep(1) # split input 1 + .split_operand_by_irrep(-1) # split output + .polynomial +) +# After split: each operand has a single irrep type, mapping naturally to dict[Irrep, Array] +``` + +### Executing with segmented_polynomial_uniform_1d + +```python +from einops import rearrange + +num_edges, num_nodes = 100, 30 + +# Weights: reshape to (batch, num_segments, segment_size) +w_flat = jax.random.normal(key, (num_edges, poly.inputs[0].size)) +w = rearrange(w_flat, "e (s m) -> e s m", s=poly.inputs[0].num_segments) + +# Node features: dict[Irrep, Array] reshaped to (nodes, ir.dim, mul) for ir_mul layout +node_feats = { + cue.SO3(0): jnp.ones((num_nodes, 32, 1)), # 32x scalar + cue.SO3(1): jnp.ones((num_nodes, 32, 3)), # 32x vector +} +x1 = jax.tree.map(lambda v: rearrange(v, "n m i -> n i m"), node_feats) + +# Spherical harmonics: (edges, ir.dim) -- no multiplicity dimension +sph = { + cue.SO3(0): jnp.ones((num_edges, 1)), + cue.SO3(1): jnp.ones((num_edges, 3)), +} + +# Build output template +senders = jax.random.randint(key, (num_edges,), 0, num_nodes) +receivers = jax.random.randint(key, (num_edges,), 0, num_nodes) +irreps_out = e.outputs[0].irreps +out_template = { + ir: jax.ShapeDtypeStruct( + (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(irreps_out, poly.outputs) +} + +# Execute with gather (senders) and scatter (receivers) +y = cuex.ir_dict.segmented_polynomial_uniform_1d( + poly, + [w, x1, sph], + out_template, + input_indices=[None, senders, None], + output_indices=receivers, + name="tensor_product", +) +# y is dict[Irrep, Array] with accumulated results at receiver nodes +``` + +### ir_dict utility functions + +```python +# Validate dict matches irreps +cuex.ir_dict.assert_mul_ir_dict(irreps, x) # asserts shape (..., mul, ir.dim) + +# Convert flat array <-> dict +d = cuex.ir_dict.flat_to_dict(irreps, flat_array) # layout="mul_ir" default +d = cuex.ir_dict.flat_to_dict(irreps, flat_array, layout="ir_mul") +flat = cuex.ir_dict.dict_to_flat(irreps, d) + +# Arithmetic +z = cuex.ir_dict.irreps_add(x, y) +z = cuex.ir_dict.irreps_zeros_like(x) + +# Create template dict +template = cuex.ir_dict.mul_ir_dict(irreps, jax.ShapeDtypeStruct(shape, dtype)) +``` + +## NNX layers + +### IrrepsLinear + +Equivariant linear layer using `dict[Irrep, Array]`: + +```python +from flax import nnx + +linear = cuex.nnx.IrrepsLinear( + irreps_in=cue.Irreps(cue.SO3, "4x0 + 2x1").regroup(), # must be regrouped + irreps_out=cue.Irreps(cue.SO3, "3x0 + 5x1").regroup(), + scale=1.0, + dtype=jnp.float32, + rngs=nnx.Rngs(0), +) + +# Input/output: dict[Irrep, Array] with shape (batch, mul, ir.dim) +x = { + cue.SO3(0): jnp.ones((batch, 4, 1)), + cue.SO3(1): jnp.ones((batch, 2, 3)), +} +y = linear(x) +# y[cue.SO3(0)].shape == (batch, 3, 1) +# y[cue.SO3(1)].shape == (batch, 5, 3) +``` + +Implementation uses `jnp.einsum("uv,...ui->...vi", w, x[ir])` per irrep with `1/sqrt(mul_in)` normalization. + +### SphericalHarmonics + +```python +sh = cuex.nnx.SphericalHarmonics(max_degree=3, eps=0.0) + +vectors = jax.random.normal(key, (batch, 3)) # 3D vectors +y = sh(vectors) +# y[cue.O3(0, 1)].shape == (batch, 1, 1) # L=0 +# y[cue.O3(1, -1)].shape == (batch, 1, 3) # L=1 +# y[cue.O3(2, 1)].shape == (batch, 1, 5) # L=2 +# y[cue.O3(3, -1)].shape == (batch, 1, 7) # L=3 +``` + +### IrrepsNormalize + +```python +norm = cuex.nnx.IrrepsNormalize(eps=1e-6, scale=1.0, skip_scalars=True) +y = norm(x) # normalizes non-scalar irreps by RMS over ir.dim, averaged over mul +``` + +### MLP (scalar only) + +```python +mlp = cuex.nnx.MLP( + layer_sizes=[64, 128, 64], + activation=jax.nn.silu, + output_activation=False, + dtype=jnp.float32, + rngs=nnx.Rngs(0), +) +y = mlp(x_scalar) # standard dense MLP with 1/sqrt(fan_in) normalization +``` + +### IrrepsIndexedLinear + +For species-indexed linear layers (different weights per atom type): + +```python +indexed_linear = cuex.nnx.IrrepsIndexedLinear( + irreps_in=cue.Irreps(cue.O3, "8x0e").regroup(), + irreps_out=cue.Irreps(cue.O3, "16x0e").regroup(), + num_indices=50, # number of species + scale=1.0, + dtype=jnp.float32, + rngs=nnx.Rngs(0), +) + +# num_index_counts: how many atoms of each species +species_counts = jnp.array([3, 4, 3, ...]) # sum = batch_size +y = indexed_linear(x, species_counts) +``` + +Uses `method="indexed_linear"` internally with `cuex.Repeats`. + +## Preparing polynomials for uniform_1d + +The `uniform_1d` CUDA kernel requires: +1. All segments within each operand have **the same shape** +2. A **single mode** in the subscripts (after preprocessing) + +### From EquivariantPolynomial to uniform_1d-ready + +For `equivariant_polynomial()` (RepArray interface): + +```python +e = cue.descriptors.channelwise_tensor_product(...) +e = e.squeeze_modes().flatten_coefficient_modes() +# If still >1 mode: e = e.flatten_modes(["u", "w"]) +out = cuex.equivariant_polynomial(e, inputs, method="uniform_1d") +``` + +For `ir_dict` (dict[Irrep, Array] interface): + +```python +e = cue.descriptors.channelwise_tensor_product(..., simplify_irreps3=True) +poly = ( + e.split_operand_by_irrep(2) # split input 2 + .split_operand_by_irrep(1) # split input 1 + .split_operand_by_irrep(-1) # split output + .polynomial +) +# Each operand has a single irrep type -> maps naturally to dict[Irrep, Array] +``` + +### Why split_operand_by_irrep matters + +Without splitting, a dense operand like `32x0+32x1` requires all irreps packed into a single contiguous buffer. After `split_operand_by_irrep`, each irrep gets its own separate buffer passed to the CUDA kernel via FFI. The buffers no longer need to be contiguous with each other. + +This is especially useful when the polynomial is preceded or followed by per-irrep linear layers (like `IrrepsLinear`). With split operands, no transpose or copy is needed between the linear layers and the polynomial — the `dict[Irrep, Array]` flows directly through the pipeline. + +## RepArray + +Representation-aware JAX array: + +```python +rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) +x = cuex.RepArray(rep, jnp.ones((batch, rep.dim))) +x = cuex.randn(jax.random.key(0), rep, (batch,), jnp.float32) + +x.array # raw jax.Array +x.reps # {axis: Rep} +x.irreps # Irreps (if last axis is IrrepsAndLayout) +``` + +## Complete GNN message-passing example + +This pattern is used in NequIP, MACE, and similar equivariant GNN models: + +```python +class MessagePassing(nnx.Module): + def __init__(self, irreps_in, irreps_sh, irreps_out, epsilon, *, name, dtype, rngs): + e = ( + cue.descriptors.channelwise_tensor_product( + irreps_in, irreps_sh, irreps_out, True + ) + * epsilon + ) + self.weight_numel = e.inputs[0].dim + self.irreps_out = e.outputs[0].irreps + self.poly = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + def __call__(self, weights, node_feats, sph, senders, receivers, num_nodes): + # weights: (num_edges, weight_numel) + w = rearrange(weights, "e (s m) -> e s m", s=self.poly.inputs[0].num_segments) + # node_feats: dict[Irrep, Array] with (nodes, mul, ir.dim) + x1 = jax.tree.map(lambda v: rearrange(v, "n m i -> n i m"), node_feats) + # sph: dict[Irrep, Array] with (edges, 1, ir.dim) or (edges, ir.dim) + x2 = jax.tree.map(lambda v: rearrange(v, "e 1 i -> e i"), sph) + + out_template = { + ir: jax.ShapeDtypeStruct( + (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(self.irreps_out, self.poly.outputs) + } + + y = cuex.ir_dict.segmented_polynomial_uniform_1d( + self.poly, [w, x1, x2], out_template, + input_indices=[None, senders, None], + output_indices=receivers, + name="tensor_product", + ) + # Rearrange output back to (nodes, mul, ir.dim) for downstream layers + return { + ir: rearrange(v, "n (i s) m -> n (s m) i", i=ir.dim) + for ir, v in y.items() + } +``` + +## Key file locations + +| Component | Path | +|-----------|------| +| `segmented_polynomial` primitive | `cuequivariance_jax/segmented_polynomials/segmented_polynomial.py` | +| `uniform_1d` backend | `cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py` | +| `naive` backend | `cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py` | +| `indexed_linear` backend | `cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py` | +| `equivariant_polynomial` | `cuequivariance_jax/equivariant_polynomial.py` | +| `ir_dict` module | `cuequivariance_jax/ir_dict.py` | +| `nnx` module | `cuequivariance_jax/nnx.py` | +| `RepArray` | `cuequivariance_jax/rep_array/rep_array_.py` | +| `Repeats` / utilities | `cuequivariance_jax/segmented_polynomials/utils.py` | +| NequIP example | `cuequivariance_jax/examples/nequip_nnx.py` | +| MACE example | `cuequivariance_jax/examples/mace_nnx.py` | diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index 913e2730..7d9ae9df 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""cuEquivariance JAX - JAX backend for cuEquivariance equivariant operations. + +For AI coding assistants: run `python -m cuequivariance_jax skill` for detailed usage guidance. +""" import importlib.resources __version__ = ( diff --git a/cuequivariance_jax/cuequivariance_jax/__main__.py b/cuequivariance_jax/cuequivariance_jax/__main__.py new file mode 100644 index 00000000..b4585791 --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/__main__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from pathlib import Path + + +def main(): + if len(sys.argv) >= 2 and sys.argv[1] == "skill": + skill_path = Path(__file__).parent / "SKILL.md" + print(skill_path.read_text()) + else: + print("Usage: python -m cuequivariance_jax skill") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/cuequivariance_torch/cuequivariance_torch/SKILL.md b/cuequivariance_torch/cuequivariance_torch/SKILL.md new file mode 100644 index 00000000..002765aa --- /dev/null +++ b/cuequivariance_torch/cuequivariance_torch/SKILL.md @@ -0,0 +1,388 @@ +--- +name: cuequivariance-torch +description: Execute equivariant tensor products in PyTorch using SegmentedPolynomial (naive/uniform_1d/fused_tp/indexed_linear), high-level operations (ChannelWiseTensorProduct, FullyConnectedTensorProduct, Linear, SymmetricContraction, SphericalHarmonics, Rotation), and layers (BatchNorm, FullyConnectedTensorProductConv). Use when writing PyTorch code with cuequivariance. +--- + +# cuequivariance_torch: Executing Equivariant Polynomials in PyTorch + +## Overview + +`cuequivariance_torch` (imported as `cuet`) executes `cuequivariance` polynomials on GPU via PyTorch. It provides: + +1. **Core primitive**: `cuet.SegmentedPolynomial` -- `torch.nn.Module` with multiple CUDA backends +2. **High-level operations** (`torch.nn.Module`): `ChannelWiseTensorProduct`, `FullyConnectedTensorProduct`, `Linear`, `SymmetricContraction`, `SphericalHarmonics`, `Rotation`, `Inversion` +3. **Layers**: `cuet.layers.BatchNorm`, `cuet.layers.FullyConnectedTensorProductConv` (message passing) +4. **Utilities**: `triangle_attention`, `triangle_multiplicative_update`, `attention_pair_bias` (AlphaFold2-style) +5. **Export support**: `onnx_custom_translation_table()`, `register_tensorrt_plugins()` + +## Execution methods + +| Method | Backend | Requirements | +|--------|---------|-------------| +| `"naive"` | Pure PyTorch (einsum) | Always works, any platform | +| `"uniform_1d"` | CUDA kernel | GPU, all segments uniform shape within each operand, single mode | +| `"fused_tp"` | CUDA kernel | GPU, 3- or 4-operand contractions, float32/float64 | +| `"indexed_linear"` | CUDA kernel | GPU, linear with indexed weights, sorted indices | + +## Core primitive: SegmentedPolynomial + +```python +import torch +import cuequivariance as cue +import cuequivariance_torch as cuet + +# Build a descriptor +e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) +poly = e.polynomial + +# Create the module +sp = cuet.SegmentedPolynomial(poly, method="uniform_1d") + +# Forward pass +x = torch.randn(batch, 3, device="cuda") +[output] = sp([x]) +# output.shape == (batch, 9) -- 1 + 3 + 5 +``` + +### Inputs, indexing, and scatter + +```python +e = cue.descriptors.channelwise_tensor_product( + 16 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) +poly = e.polynomial + +sp = cuet.SegmentedPolynomial(poly, method="uniform_1d") + +w = torch.randn(1, poly.inputs[0].size, device="cuda") # shared weights +x1 = torch.randn(batch, poly.inputs[1].size, device="cuda") # batched input 1 +x2 = torch.randn(batch, poly.inputs[2].size, device="cuda") # batched input 2 + +# Basic forward +[out] = sp([w, x1, x2]) + +# With input gathering (e.g., gather x1 by node index) +senders = torch.randint(0, num_nodes, (num_edges,), device="cuda") +[out] = sp([w, x1, x2], input_indices={1: senders}) + +# With output scattering (accumulate into target nodes) +receivers = torch.randint(0, num_nodes, (num_edges,), device="cuda") +[out] = sp( + [w, x1, x2], + input_indices={1: senders}, + output_indices={0: receivers}, + output_shapes={0: torch.empty(num_nodes, 1, device="cuda")}, +) +``` + +### Math dtype control + +```python +# Compute in float32 regardless of input dtype +sp = cuet.SegmentedPolynomial(poly, method="fused_tp", math_dtype=torch.float32) + +# For fused_tp, math_dtype must be float32 or float64 +# For naive, any torch.dtype works +# For uniform_1d, float32 or float64 (auto-selects float32 if input is e.g. float16) +``` + +## High-level operations + +All operations are `torch.nn.Module` subclasses. They wrap `SegmentedPolynomial` and handle layout transposition automatically. + +### Memory layout + +`IrrepsLayout` controls memory order within each `(mul, ir)` block: + +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **default, compatible with e3nn** +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used internally by descriptors** + +Operations accept `layout` (applies to all), or per-operand `layout_in1`, `layout_in2`, `layout_out`. + +### ChannelWiseTensorProduct + +Channel-wise tensor product: pairs channels of `x1` with channels of `x2`. + +```python +# With internal weights (default: shared_weights=True, internal_weights=True) +tp = cuet.ChannelWiseTensorProduct( + cue.Irreps("SO3", "32x0 + 32x1"), # irreps_in1 + cue.Irreps("SO3", "0 + 1"), # irreps_in2 + layout=cue.mul_ir, + device="cuda", + dtype=torch.float32, +) +# tp.weight_numel -- number of weight parameters +# tp.irreps_out -- output irreps (auto-computed) + +x1 = torch.randn(batch, tp.irreps_in1.dim, device="cuda") +x2 = torch.randn(batch, tp.irreps_in2.dim, device="cuda") + +out = tp(x1, x2) # uses internal weight parameter +# out.shape == (batch, tp.irreps_out.dim) + +# With external weights (shared_weights=False) +tp = cuet.ChannelWiseTensorProduct( + cue.Irreps("SO3", "32x0 + 32x1"), + cue.Irreps("SO3", "0 + 1"), + layout=cue.mul_ir, + shared_weights=False, + device="cuda", +) +w = torch.randn(batch, tp.weight_numel, device="cuda") +out = tp(x1, x2, weight=w) + +# With gather/scatter for graph neural networks +out = tp(x1, x2, weight=w, indices_1=senders, indices_out=receivers, size_out=num_nodes) +``` + +Default method: `"uniform_1d"` if segments are uniform, else `"naive"`. + +### FullyConnectedTensorProduct + +All input irrep pairs contribute to all output irreps (dense contraction). + +```python +tp = cuet.FullyConnectedTensorProduct( + cue.Irreps("O3", "4x0e + 4x1o"), # irreps_in1 + cue.Irreps("O3", "0e + 1o"), # irreps_in2 + cue.Irreps("O3", "4x0e + 4x1o"), # irreps_out + layout=cue.mul_ir, + internal_weights=True, # store weights as parameters + device="cuda", +) + +out = tp(x1, x2) # uses internal weights +# or: out = tp(x1, x2, weight=w) # external weights +``` + +Default method: `"fused_tp"`. + +### Linear + +Equivariant linear layer (weight-only, no second input). + +```python +linear = cuet.Linear( + cue.Irreps("SO3", "4x0 + 2x1"), # irreps_in + cue.Irreps("SO3", "3x0 + 5x1"), # irreps_out + layout=cue.mul_ir, + internal_weights=True, + device="cuda", +) + +out = linear(x) + +# Species-indexed weights (different weights per atom type) +linear = cuet.Linear( + irreps_in, irreps_out, + weight_classes=50, # 50 different weight sets + internal_weights=True, + device="cuda", +) +out = linear(x, weight_indices=species_indices) # species_indices: (batch,) int tensor +``` + +Default method: `"naive"`. Use `method="fused_tp"` for CUDA acceleration. + +### SymmetricContraction + +MACE-style symmetric contraction with element-indexed weights. + +```python +sc = cuet.SymmetricContraction( + cue.Irreps("O3", "32x0e + 32x1o"), # irreps_in (uniform mul required) + cue.Irreps("O3", "32x0e"), # irreps_out (uniform mul required) + contraction_degree=3, # polynomial degree + num_elements=95, # number of chemical elements + layout=cue.ir_mul, + dtype=torch.float32, + device="cuda", +) + +# indices: (batch,) int tensor selecting which element weights to use +out = sc(x, indices) +# out.shape == (batch, irreps_out.dim) +``` + +Default method: `"uniform_1d"` if segments are uniform, else `"naive"`. + +### SphericalHarmonics + +```python +sh = cuet.SphericalHarmonics( + ls=[0, 1, 2, 3], # degrees + normalize=True, # normalize input vectors + device="cuda", +) + +vectors = torch.randn(batch, 3, device="cuda") +out = sh(vectors) +# out.shape == (batch, 1 + 3 + 5 + 7) -- sum of 2l+1 +``` + +Default method: `"uniform_1d"`. + +### Rotation and Inversion + +```python +# Rotation (SO3 or O3 irreps) +rot = cuet.Rotation( + cue.Irreps("SO3", "4x0 + 2x1 + 1x2"), + layout=cue.ir_mul, + device="cuda", +) + +# Euler angles (YXY convention) +gamma = torch.tensor([0.1], device="cuda") +beta = torch.tensor([0.2], device="cuda") +alpha = torch.tensor([0.3], device="cuda") +out = rot(gamma, beta, alpha, x) + +# Helper: encode angle for rotation +encoded = cuet.encode_rotation_angle(angle, ell=3) # cos/sin encoding + +# Helper: 3D vector to Euler angles +beta, alpha = cuet.vector_to_euler_angles(vector) + +# Inversion (O3 irreps only) +inv = cuet.Inversion( + cue.Irreps("O3", "4x0e + 2x1o"), + layout=cue.ir_mul, + device="cuda", +) +out = inv(x) +``` + +## Layers + +### BatchNorm + +Batch normalization for equivariant representations (adapted from e3nn). + +```python +bn = cuet.layers.BatchNorm( + cue.Irreps("O3", "4x0e + 4x1o"), + layout=cue.mul_ir, + eps=1e-5, + momentum=0.1, + affine=True, +) + +out = bn(x) # x.shape == (batch, irreps.dim) +``` + +### FullyConnectedTensorProductConv + +Message passing layer for equivariant GNNs (DiffDock-style). + +```python +conv = cuet.layers.FullyConnectedTensorProductConv( + in_irreps=cue.Irreps("O3", "4x0e + 4x1o"), + sh_irreps=cue.Irreps("O3", "0e + 1o"), + out_irreps=cue.Irreps("O3", "4x0e + 4x1o"), + mlp_channels=[16, 32, 32], # MLP for path weights + mlp_activation=torch.nn.ReLU(), + batch_norm=True, + layout=cue.ir_mul, +) + +# graph = ((src, dst), (num_src_nodes, num_dst_nodes)) +graph = ((src, dst), (num_src_nodes, num_dst_nodes)) + +out = conv(src_features, edge_sh, edge_emb, graph, reduce="mean") +# out.shape == (num_dst_nodes, out_irreps.dim) + +# Optional: separate scalar features for efficient first-layer GEMM +out = conv(src_features, edge_sh, edge_emb, graph, + src_scalars=src_scalars, dst_scalars=dst_scalars) +``` + +## Triangle operations (AlphaFold2-style) + +Require `cuequivariance_ops_torch`. + +```python +# Triangle attention with pair bias +out = cuet.triangle_attention(q, k, v, bias, mask=mask, scale=scale) +# q, k, v: (B, N, H, Q/K, D), bias: (B, 1, H, Q, K) + +# Triangle multiplicative update +out = cuet.triangle_multiplicative_update( + x, # (B, I, J, C) + mask=mask, # (B, I, J) + precision=cuet.TriMulPrecision.DEFAULT, +) + +# Attention with pair bias (diffusion models) +out = cuet.attention_pair_bias(q, k, v, bias, mask=mask) +``` + +## ONNX and TensorRT export + +```python +# ONNX export +table = cuet.onnx_custom_translation_table() +onnx_program = torch.onnx.export(model, inputs, custom_translation_table=table) + +# TensorRT plugin registration +cuet.register_tensorrt_plugins() +``` + +## Complete GNN example + +```python +import torch +import cuequivariance as cue +import cuequivariance_torch as cuet + +class SimpleGNN(torch.nn.Module): + def __init__(self, irreps_in, irreps_sh, irreps_out): + super().__init__() + self.tp = cuet.ChannelWiseTensorProduct( + irreps_in, irreps_sh, layout=cue.mul_ir, + shared_weights=False, device="cuda", + ) + self.linear = cuet.Linear( + self.tp.irreps_out, irreps_out, + layout=cue.mul_ir, internal_weights=True, device="cuda", + ) + self.sh = cuet.SphericalHarmonics( + ls=[ir.l for _, ir in irreps_sh], normalize=True, device="cuda", + ) + + def forward(self, node_feats, edge_vec, edge_index, num_nodes): + src, dst = edge_index + edge_sh = self.sh(edge_vec) + w = torch.randn(1, self.tp.weight_numel, device=node_feats.device) + + # Message: tensor product on edges with scatter to destination nodes + messages = self.tp( + node_feats, edge_sh, weight=w, + indices_1=src, indices_2=None, + indices_out=dst, size_out=num_nodes, + ) + return self.linear(messages) +``` + +## Key file locations + +| Component | Path | +|-----------|------| +| `SegmentedPolynomial` | `cuequivariance_torch/primitives/segmented_polynomial.py` | +| `uniform_1d` backend | `cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py` | +| `naive` backend | `cuequivariance_torch/primitives/segmented_polynomial_naive.py` | +| `fused_tp` backend | `cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py` | +| `indexed_linear` backend | `cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py` | +| `ChannelWiseTensorProduct` | `cuequivariance_torch/operations/tp_channel_wise.py` | +| `FullyConnectedTensorProduct` | `cuequivariance_torch/operations/tp_fully_connected.py` | +| `Linear` | `cuequivariance_torch/operations/linear.py` | +| `SymmetricContraction` | `cuequivariance_torch/operations/symmetric_contraction.py` | +| `SphericalHarmonics` | `cuequivariance_torch/operations/spherical_harmonics.py` | +| `Rotation` / `Inversion` | `cuequivariance_torch/operations/rotation.py` | +| `BatchNorm` | `cuequivariance_torch/layers/batchnorm.py` | +| `FullyConnectedTensorProductConv` | `cuequivariance_torch/layers/tp_conv_fully_connected.py` | +| Triangle operations | `cuequivariance_torch/primitives/triangle.py` | +| Layout transposition | `cuequivariance_torch/primitives/transpose.py` | diff --git a/cuequivariance_torch/cuequivariance_torch/__init__.py b/cuequivariance_torch/cuequivariance_torch/__init__.py index 23c929e4..44fcd853 100644 --- a/cuequivariance_torch/cuequivariance_torch/__init__.py +++ b/cuequivariance_torch/cuequivariance_torch/__init__.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""cuEquivariance PyTorch - PyTorch backend for cuEquivariance equivariant operations. + +For AI coding assistants: run `python -m cuequivariance_torch skill` for detailed usage guidance. +""" import importlib.resources __version__ = ( diff --git a/cuequivariance_torch/cuequivariance_torch/__main__.py b/cuequivariance_torch/cuequivariance_torch/__main__.py new file mode 100644 index 00000000..094f969c --- /dev/null +++ b/cuequivariance_torch/cuequivariance_torch/__main__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from pathlib import Path + + +def main(): + if len(sys.argv) >= 2 and sys.argv[1] == "skill": + skill_path = Path(__file__).parent / "SKILL.md" + print(skill_path.read_text()) + else: + print("Usage: python -m cuequivariance_torch skill") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/docs/api/cuequivariance_jax.rst b/docs/api/cuequivariance_jax.rst index 567da6fc..103589bd 100644 --- a/docs/api/cuequivariance_jax.rst +++ b/docs/api/cuequivariance_jax.rst @@ -47,6 +47,38 @@ Tensor Products equivariant_polynomial segmented_polynomial +ir_dict +------- + +Utilities for working with ``dict[Irrep, Array]`` representation, an alternative to ``RepArray``. + +.. autosummary:: + :toctree: generated/ + :template: function_template.rst + + ir_dict.segmented_polynomial_uniform_1d + ir_dict.assert_mul_ir_dict + ir_dict.mul_ir_dict + ir_dict.flat_to_dict + ir_dict.dict_to_flat + ir_dict.irreps_add + ir_dict.irreps_zeros_like + +NNX Layers +---------- + +Flax NNX modules using ``dict[Irrep, Array]`` representation. + +.. autosummary:: + :toctree: generated/ + :template: class_template.rst + + nnx.IrrepsLinear + nnx.SphericalHarmonics + nnx.IrrepsNormalize + nnx.MLP + nnx.IrrepsIndexedLinear + Extra Modules ------------- @@ -62,6 +94,15 @@ Extra Modules spherical_harmonics +Utilities +--------- + +.. autosummary:: + :toctree: generated/ + :template: class_template.rst + + Repeats + Triangle -------- diff --git a/docs/tutorials/jax/index.rst b/docs/tutorials/jax/index.rst index 2c3c2bdf..aa0cd74f 100644 --- a/docs/tutorials/jax/index.rst +++ b/docs/tutorials/jax/index.rst @@ -20,4 +20,5 @@ JAX Examples :maxdepth: 1 poly + ir_dict diff --git a/docs/tutorials/jax/ir_dict.rst b/docs/tutorials/jax/ir_dict.rst new file mode 100644 index 00000000..0f1c1316 --- /dev/null +++ b/docs/tutorials/jax/ir_dict.rst @@ -0,0 +1,191 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +The ``ir_dict`` Interface +========================= + +The :mod:`cuequivariance_jax.ir_dict` module provides an alternative to :class:`~cuequivariance_jax.RepArray` for working with equivariant data. Instead of a single contiguous array, features are stored as ``dict[Irrep, Array]`` where each value has shape ``(..., multiplicity, irrep_dim)``. + +This representation works naturally with ``jax.tree`` operations and is used by the :mod:`cuequivariance_jax.nnx` layers. + +From Descriptor to ``ir_dict`` +------------------------------ + +Descriptors produce :class:`~cuequivariance.EquivariantPolynomial` objects with dense operands (e.g., ``32x0+32x1``). A dense operand requires all irreps to be packed into a single contiguous buffer. By splitting each operand by irrep with :meth:`~cuequivariance.EquivariantPolynomial.split_operand_by_irrep`, each irrep gets its own separate buffer. This relaxes the memory layout constraint: the buffers for different irreps no longer need to be contiguous with each other. + +This is especially useful when the polynomial is preceded or followed by linear layers that act independently on each irrep (like :class:`~cuequivariance_jax.nnx.IrrepsLinear`). With split operands, there is no need for any transpose or copy between the linear layers and the polynomial — the ``dict[Irrep, Array]`` flows directly through the pipeline. + +.. jupyter-execute:: + + import jax + import jax.numpy as jnp + from einops import rearrange + import cuequivariance as cue + import cuequivariance_jax as cuex + + # Build a channelwise tensor product + e = cue.descriptors.channelwise_tensor_product( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + simplify_irreps3=True, + ) + print("Before split:") + print(e) + +.. jupyter-execute:: + + # Split operands by irrep + # Order: split inner operands first to preserve indices + e_split = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + ) + poly = e_split.polynomial + + print("After split:") + print(e_split) + print() + for i, op in enumerate(poly.inputs): + print(f" Input {i}: num_segments={op.num_segments}, uniform={op.all_same_segment_shape()}") + for i, op in enumerate(poly.outputs): + print(f" Output {i}: num_segments={op.num_segments}, uniform={op.all_same_segment_shape()}") + + +Executing with ``segmented_polynomial_uniform_1d`` +-------------------------------------------------- + +The :func:`~cuequivariance_jax.ir_dict.segmented_polynomial_uniform_1d` function handles the flattening/unflattening between the ``dict[Irrep, Array]`` pytree structure and the flat arrays that the kernel expects. + +Each input array has shape ``(..., num_segments, *segment_shape)``. For the weight operand, we reshape the flat weights into this form. For ``dict[Irrep, Array]`` operands, each value is one leaf of the pytree. + +.. jupyter-execute:: + + batch = 16 + + # Weights: reshape flat -> (batch, num_segments, segment_size) + w_flat = jax.random.normal(jax.random.key(0), (batch, poly.inputs[0].size)) + w = rearrange(w_flat, "b (s m) -> b s m", s=poly.inputs[0].num_segments) + print(f"Weights: {w.shape} (batch, num_segments, segment_size)") + + # Inputs as dict[Irrep, Array] + # Shape convention: (batch, ir.dim, mul) for ir_mul layout + node_feats = { + cue.SO3(0): jax.random.normal(jax.random.key(1), (batch, 32, 1)), + cue.SO3(1): jax.random.normal(jax.random.key(2), (batch, 32, 3)), + } + # Rearrange from (batch, mul, ir.dim) to (batch, ir.dim, mul) for ir_mul layout + x = jax.tree.map(lambda v: rearrange(v, "b m i -> b i m"), node_feats) + print(f"Input l=0: {x[cue.SO3(0)].shape} (batch, ir.dim, mul)") + print(f"Input l=1: {x[cue.SO3(1)].shape} (batch, ir.dim, mul)") + + # Second input (e.g. spherical harmonics): (batch, ir.dim) + sph = { + cue.SO3(0): jax.random.normal(jax.random.key(3), (batch, 1)), + cue.SO3(1): jax.random.normal(jax.random.key(4), (batch, 3)), + } + + # Build output template: one entry per split output + irreps_out = e.outputs[0].irreps + out_template = { + ir: jax.ShapeDtypeStruct( + (batch, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(irreps_out, poly.outputs) + } + print(f"Output template: { {str(k): v.shape for k, v in out_template.items()} }") + +.. jupyter-execute:: + + # Execute + y = cuex.ir_dict.segmented_polynomial_uniform_1d( + poly, [w, x, sph], out_template, + ) + + for ir, v in y.items(): + print(f" Output {ir}: {v.shape}") + + +Indexing (Gather/Scatter) +------------------------- + +In graph neural networks, features live on nodes and edges with different batch sizes. Index arrays handle the gather (for inputs) and scatter (for outputs): + +.. jupyter-execute:: + + num_edges, num_nodes = 100, 30 + + w = jax.random.normal(jax.random.key(0), (num_edges, poly.inputs[0].size)) + w = rearrange(w, "e (s m) -> e s m", s=poly.inputs[0].num_segments) + + node_feats = { + cue.SO3(0): jax.random.normal(jax.random.key(1), (num_nodes, 1, 32)), + cue.SO3(1): jax.random.normal(jax.random.key(2), (num_nodes, 3, 32)), + } + + sph = { + cue.SO3(0): jax.random.normal(jax.random.key(3), (num_edges, 1)), + cue.SO3(1): jax.random.normal(jax.random.key(4), (num_edges, 3)), + } + + senders = jax.random.randint(jax.random.key(5), (num_edges,), 0, num_nodes) + receivers = jax.random.randint(jax.random.key(6), (num_edges,), 0, num_nodes) + + out_template = { + ir: jax.ShapeDtypeStruct( + (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype + ) + for (_, ir), desc in zip(irreps_out, poly.outputs) + } + + # Gather node features at senders, scatter results to receivers + y = cuex.ir_dict.segmented_polynomial_uniform_1d( + poly, + [w, node_feats, sph], + out_template, + input_indices=[None, senders, None], + output_indices=receivers, + ) + + for ir, v in y.items(): + print(f" Output {ir}: {v.shape}") + + +Utility Functions +----------------- + +The ``ir_dict`` module provides helpers for converting between flat arrays and ``dict[Irrep, Array]``: + +.. jupyter-execute:: + + irreps = cue.Irreps(cue.SO3, "4x0 + 2x1") + + # Flat array -> dict + flat = jnp.ones((8, irreps.dim)) + d = cuex.ir_dict.flat_to_dict(irreps, flat) + print(f"flat_to_dict: l=0 {d[cue.SO3(0)].shape}, l=1 {d[cue.SO3(1)].shape}") + + # Dict -> flat array + flat_back = cuex.ir_dict.dict_to_flat(irreps, d) + print(f"dict_to_flat: {flat_back.shape}") + + # Arithmetic + z = cuex.ir_dict.irreps_add(d, d) + print(f"irreps_add: l=0 sum={float(z[cue.SO3(0)].sum())}") + + # Validation + cuex.ir_dict.assert_mul_ir_dict(irreps, d) + print("assert_mul_ir_dict: passed") diff --git a/docs/tutorials/jax/poly.rst b/docs/tutorials/jax/poly.rst index 68043c20..19c040ef 100644 --- a/docs/tutorials/jax/poly.rst +++ b/docs/tutorials/jax/poly.rst @@ -44,7 +44,8 @@ Basic Usage # 2. Define output shapes/dtypes # We need to tell JAX what the output will look like - output_structs = [jax.ShapeDtypeStruct((3, 3), jnp.float32)] + # Use -1 for the last dimension to infer from the polynomial descriptor + output_structs = [jax.ShapeDtypeStruct((-1,), jnp.float32)] # 3. Execute input_arr = jnp.ones((3,))