diff --git a/cuequivariance/cuequivariance/SKILL.md b/cuequivariance/cuequivariance/SKILL.md index a6325f78..f14cbd8b 100644 --- a/cuequivariance/cuequivariance/SKILL.md +++ b/cuequivariance/cuequivariance/SKILL.md @@ -1,6 +1,6 @@ --- name: cuequivariance -description: Define custom groups (Irrep subclasses), build segmented tensor products with CG coefficients, create equivariant polynomials, and use built-in descriptors (linear, tensor products, spherical harmonics). Use when working with cuequivariance group theory, irreps, or segmented polynomials. +description: Define custom groups (Irrep subclasses), build segmented tensor products with CG coefficients, create equivariant polynomials and IrDictPolynomials, and use built-in descriptors (linear, tensor products, spherical harmonics). Use when working with cuequivariance group theory, irreps, or segmented polynomials. --- # cuequivariance: Groups, Irreps, and Segmented Polynomials @@ -10,7 +10,9 @@ description: Define custom groups (Irrep subclasses), build segmented tensor pro `cuequivariance` (imported as `cue`) provides two core abstractions: 1. **Group theory**: `Irrep` subclasses define irreducible representations of Lie groups (SO3, O3, SU2, or custom). `Irreps` manages collections with multiplicities. -2. **Segmented polynomials**: `SegmentedTensorProduct` describes tensor contractions over segments of varying shape, linked by `Path` objects carrying Clebsch-Gordan coefficients. `SegmentedPolynomial` wraps multiple STPs into a polynomial with named inputs/outputs. `EquivariantPolynomial` attaches group representations to each operand. +2. **Segmented polynomials**: `SegmentedTensorProduct` describes tensor contractions over segments of varying shape, linked by `Path` objects carrying Clebsch-Gordan coefficients. `SegmentedPolynomial` wraps multiple STPs into a polynomial with named inputs/outputs. 
Two higher-level wrappers attach group representations: + - `EquivariantPolynomial` — dense operands with `IrrepsAndLayout` metadata + - `IrDictPolynomial` — operands already split by irrep, with per-group `Irreps` metadata for the `dict[Irrep, Array]` workflow ## Defining a custom group @@ -122,8 +124,8 @@ for mul, ir in irreps: `IrrepsLayout` controls memory order within each `(mul, ir)` block: -- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used by all descriptors** -- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **used by nnx dict[Irrep, Array]** +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` — **used by all descriptors and ir_dict internally** +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` — **used by nnx `dict[Irrep, Array]` and PyTorch** `IrrepsAndLayout` combines irreps with a layout into a `Rep`: @@ -181,9 +183,14 @@ for cg in cue.clebsch_gordan(ir1, ir2, ir3): d.add_path((mul1, mul2, mul3), seg_in1, seg_in2, seg_out, c=cg) ``` -## Using descriptors (high-level API) +## Descriptors -All descriptors return `cue.EquivariantPolynomial`: +All descriptors come in two variants: + +- **Original** — returns `EquivariantPolynomial` with dense operands +- **`_ir_dict`** — returns `IrDictPolynomial` with operands already split by irrep + +### EquivariantPolynomial descriptors ```python # Fully connected tensor product (all input-output irrep combinations) @@ -199,12 +206,17 @@ e = cue.descriptors.channelwise_tensor_product( cue.Irreps("SO3", "0 + 1"), simplify_irreps3=True, ) +# Full (weightless) tensor product +e = cue.descriptors.full_tensor_product( + cue.Irreps("SO3", "2x0 + 1x1"), cue.Irreps("SO3", "0 + 1"), +) + # Elementwise tensor product (paired channels) e = cue.descriptors.elementwise_tensor_product( cue.Irreps("SO3", "4x0 + 4x1"), cue.Irreps("SO3", "4x0 + 4x1"), ) -# Linear equivariant map (no second input, just weight x input) +# Linear equivariant map (weight x input) e = cue.descriptors.linear( cue.Irreps("SO3", "4x0 + 2x1"), 
cue.Irreps("SO3", "3x0 + 5x1"), @@ -217,10 +229,80 @@ e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) e = cue.descriptors.symmetric_contraction( 64 * cue.Irreps("SO3", "0 + 1 + 2"), 64 * cue.Irreps("SO3", "0 + 1"), - [0, 1, 2, 3], + (1, 2, 3), ) ``` +### IrDictPolynomial descriptors + +Each `_ir_dict` variant returns an `IrDictPolynomial` whose polynomial is already split by irrep. The `input_irreps` and `output_irreps` tuples describe the operand groups. + +```python +# Channelwise tensor product +desc = cue.descriptors.channelwise_tensor_product_ir_dict( + 64 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) +# desc.polynomial — SegmentedPolynomial, already split by irrep +# desc.input_irreps — (weight_irreps, irreps1, irreps2) +# desc.output_irreps — (irreps_out,) + +# Fully connected tensor product +desc = cue.descriptors.fully_connected_tensor_product_ir_dict(irreps1, irreps2, irreps3) + +# Full (weightless) tensor product +desc = cue.descriptors.full_tensor_product_ir_dict(irreps1, irreps2) + +# Elementwise tensor product +desc = cue.descriptors.elementwise_tensor_product_ir_dict(irreps1, irreps2) + +# Linear +desc = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + +# Spherical harmonics +desc = cue.descriptors.spherical_harmonics_ir_dict(cue.O3(1, -1), [0, 1, 2, 3]) + +# Symmetric contraction +desc = cue.descriptors.symmetric_contraction_ir_dict(irreps_in, irreps_out, (1, 2, 3)) +``` + +### IrDictPolynomial + +`IrDictPolynomial` pairs a `SegmentedPolynomial` (already split by irrep) with the `Irreps` that describe each operand group. 
+ +```python +desc = cue.descriptors.channelwise_tensor_product_ir_dict( + 32 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), +) + +desc.polynomial # SegmentedPolynomial — each operand is one (mul, ir) block +desc.input_irreps # (weight_irreps, irreps1, irreps2) +desc.output_irreps # (irreps_out,) + +# Scale coefficients +scaled_poly = desc.polynomial * 0.5 + +# Access individual operand info +for i, op in enumerate(desc.polynomial.inputs): + print(f"Input {i}: size={op.size}, num_segments={op.num_segments}") +``` + +Contract: for each `(mul, ir)` block in `input_irreps` / `output_irreps`, the corresponding polynomial operand has size `mul * ir.dim`. + +### split_polynomial_by_irreps + +The low-level function underlying `_ir_dict` descriptors. Splits one polynomial operand at irrep boundaries: + +```python +poly = e.polynomial # from an EquivariantPolynomial +poly = cue.split_polynomial_by_irreps(poly, 2, irreps_sh) # split input 2 +poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) # split input 1 +poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) # split output +``` + ### EquivariantPolynomial key methods ```python @@ -328,7 +410,8 @@ y = np.random.randn(ep.inputs[2].dim) | `SegmentedTensorProduct` | `cuequivariance/segmented_polynomials/segmented_tensor_product.py` | | `SegmentedPolynomial` | `cuequivariance/segmented_polynomials/segmented_polynomial.py` | | `EquivariantPolynomial` | `cuequivariance/group_theory/equivariant_polynomial.py` | +| `IrDictPolynomial` | `cuequivariance/group_theory/ir_dict_polynomial.py` | | Descriptors | `cuequivariance/group_theory/descriptors/` | -| `fully_connected_tensor_product` etc. 
| `cuequivariance/group_theory/descriptors/irreps_tp.py` | +| Tensor product descriptors | `cuequivariance/group_theory/descriptors/irreps_tp.py` | | `spherical_harmonics` | `cuequivariance/group_theory/descriptors/spherical_harmonics_.py` | | `symmetric_contraction` | `cuequivariance/group_theory/descriptors/symmetric_contractions.py` | diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index 19e217ef..92210f48 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -55,6 +55,8 @@ reduced_antisymmetric_tensor_product_basis, EquivariantPolynomial, EquivariantTensorProduct, # deprecated + IrDictPolynomial, + split_polynomial_by_irreps, ) from cuequivariance import segmented_polynomials as segmented_polynomials @@ -93,6 +95,8 @@ "reduced_antisymmetric_tensor_product_basis", "EquivariantPolynomial", "EquivariantTensorProduct", + "IrDictPolynomial", + "split_polynomial_by_irreps", "segmented_polynomials", "group_theory", "descriptors", diff --git a/cuequivariance/cuequivariance/etc/linalg.py b/cuequivariance/cuequivariance/etc/linalg.py index 3fdb33b6..54bf3579 100644 --- a/cuequivariance/cuequivariance/etc/linalg.py +++ b/cuequivariance/cuequivariance/etc/linalg.py @@ -107,6 +107,10 @@ def limit_denominator(n, d, max_denominator: int): n1, d1 = p0 + k * p1, q0 + k * q1 n2, d2 = p1, q1 with np.errstate(over="ignore"): + # The intermediate products (n2*d0, n0*d2) overflow int64 (~2^102), but the + # overflow is benign: their difference is bounded by d0 < 2^62 (fits in int64), + # and two's complement subtraction recovers it exactly. The final product + # d1*(difference) is also bounded by d0 < 2^63. 
mask = np.abs(d1 * (n2 * d0 - n0 * d2)) <= np.abs(d2 * (n1 * d0 - n0 * d1)) return np.where( d0 <= max_denominator, diff --git a/cuequivariance/cuequivariance/group_theory/__init__.py b/cuequivariance/cuequivariance/group_theory/__init__.py index 4e221b98..f00be4b7 100644 --- a/cuequivariance/cuequivariance/group_theory/__init__.py +++ b/cuequivariance/cuequivariance/group_theory/__init__.py @@ -44,6 +44,7 @@ from .equivariant_polynomial import EquivariantPolynomial from .equivariant_tensor_product import EquivariantTensorProduct +from .ir_dict_polynomial import IrDictPolynomial, split_polynomial_by_irreps __all__ = [ @@ -72,4 +73,6 @@ "reduced_antisymmetric_tensor_product_basis", "EquivariantPolynomial", "EquivariantTensorProduct", + "IrDictPolynomial", + "split_polynomial_by_irreps", ] diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py b/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py index f18c438a..2de8b1c8 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py @@ -15,12 +15,17 @@ from .transposition import transpose from .irreps_tp import ( full_tensor_product, + full_tensor_product_ir_dict, fully_connected_tensor_product, + fully_connected_tensor_product_ir_dict, channelwise_tensor_product, + channelwise_tensor_product_ir_dict, elementwise_tensor_product, + elementwise_tensor_product_ir_dict, linear, + linear_ir_dict, ) -from .symmetric_contractions import symmetric_contraction +from .symmetric_contractions import symmetric_contraction, symmetric_contraction_ir_dict from .rotations import ( fixed_axis_angle_rotation, y_rotation, @@ -30,16 +35,26 @@ yxy_rotation, inversion, ) -from .spherical_harmonics_ import sympy_spherical_harmonics, spherical_harmonics +from .spherical_harmonics_ import ( + sympy_spherical_harmonics, + spherical_harmonics, + spherical_harmonics_ir_dict, +) __all__ = [ "transpose", 
"full_tensor_product", + "full_tensor_product_ir_dict", "fully_connected_tensor_product", + "fully_connected_tensor_product_ir_dict", "channelwise_tensor_product", + "channelwise_tensor_product_ir_dict", "elementwise_tensor_product", + "elementwise_tensor_product_ir_dict", "linear", + "linear_ir_dict", "symmetric_contraction", + "symmetric_contraction_ir_dict", "fixed_axis_angle_rotation", "y_rotation", "x_rotation", @@ -49,4 +64,5 @@ "inversion", "sympy_spherical_harmonics", "spherical_harmonics", + "spherical_harmonics_ir_dict", ] diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index 992d51f5..778130fd 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -19,6 +19,34 @@ from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep +def _fully_connected_tensor_product_core( + irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps +) -> cue.SegmentedPolynomial: + G = irreps1.irrep_class + + d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") + + for mul, ir in irreps1: + d.add_segment(1, (ir.dim, mul)) + for mul, ir in irreps2: + d.add_segment(2, (ir.dim, mul)) + for mul, ir in irreps3: + d.add_segment(3, (ir.dim, mul)) + + for (i1, (mul1, ir1)), (i2, (mul2, ir2)), (i3, (mul3, ir3)) in itertools.product( + enumerate(irreps1), enumerate(irreps2), enumerate(irreps3) + ): + if ir3 not in ir1 * ir2: + continue + + # for loop over the different solutions of the Clebsch-Gordan decomposition + for cg in G.clebsch_gordan(ir1, ir2, ir3): + d.add_path((mul1, mul2, mul3), i1, i2, i3, c=cg) + + d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d) + + def fully_connected_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps ) -> cue.EquivariantPolynomial: @@ -51,59 
+79,56 @@ def fully_connected_tensor_product( Where ``61440x0`` are the 61440 weights needed to mix all the inputs with all the outputs. """ - G = irreps1.irrep_class - - d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") - - for mul, ir in irreps1: - d.add_segment(1, (ir.dim, mul)) - for mul, ir in irreps2: - d.add_segment(2, (ir.dim, mul)) - for mul, ir in irreps3: - d.add_segment(3, (ir.dim, mul)) - - for (i1, (mul1, ir1)), (i2, (mul2, ir2)), (i3, (mul3, ir3)) in itertools.product( - enumerate(irreps1), enumerate(irreps2), enumerate(irreps3) - ): - if ir3 not in ir1 * ir2: - continue - - # for loop over the different solutions of the Clebsch-Gordan decomposition - for cg in G.clebsch_gordan(ir1, ir2, ir3): - d.add_path((mul1, mul2, mul3), i1, i2, i3, c=cg) - - d = d.normalize_paths_for_operand(-1) + poly = _fully_connected_tensor_product_core(irreps1, irreps2, irreps3) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps1.new_scalars(poly.inputs[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, ) -def full_tensor_product( - irreps1: cue.Irreps, - irreps2: cue.Irreps, - irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantPolynomial: +def fully_connected_tensor_product_ir_dict( + irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps +) -> cue.IrDictPolynomial: """ - subscripts: ``lhs[iu],rhs[jv],output[kuv]`` + subscripts: ``weights[uvw],lhs[iu],rhs[jv],output[kw]`` - Construct a weightless channelwise tensor product descriptor. + Construct a fully connected tensor product as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`fully_connected_tensor_product`. .. 
currentmodule:: cuequivariance Args: irreps1 (Irreps): Irreps of the first operand. irreps2 (Irreps): Irreps of the second operand. - irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + irreps3 (Irreps): Irreps of the output. Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the full tensor product. + :class:`cue.IrDictPolynomial `: The fully connected tensor product + with ``input_irreps = (weight_irreps, irreps1, irreps2)`` and ``output_irreps = (irreps3,)``. """ + poly = _fully_connected_tensor_product_core(irreps1, irreps2, irreps3) + weight_irreps = irreps1.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 2, irreps2) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps1) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps1, irreps2), + output_irreps=(irreps3,), + ) + + +def _full_tensor_product_core( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]], +) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: G = irreps1.irrep_class if irreps3_filter is not None: @@ -136,28 +161,51 @@ def full_tensor_product( d = d.permute_segments(2, inv) d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d), irreps3 + + +def full_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.EquivariantPolynomial: + """ + subscripts: ``lhs[iu],rhs[jv],output[kuv]`` + + Construct a weightless channelwise tensor product descriptor. + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the full tensor product. 
+ """ + poly, irreps3 = _full_tensor_product_core(irreps1, irreps2, irreps3_filter) return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, ) -def channelwise_tensor_product( +def full_tensor_product_ir_dict( irreps1: cue.Irreps, irreps2: cue.Irreps, - irreps3_filter=None, - simplify_irreps3: bool = False, -) -> cue.EquivariantPolynomial: + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.IrDictPolynomial: """ - subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` + subscripts: ``lhs[iu],rhs[jv],output[kuv]`` - Construct a channelwise tensor product descriptor. + Construct a weightless channelwise tensor product as an :class:`~cuequivariance.IrDictPolynomial`. - This operation is computationally sparser than the fully connected tensor product. + This is the ``ir_dict`` variant of :func:`full_tensor_product`. .. currentmodule:: cuequivariance @@ -165,11 +213,28 @@ def channelwise_tensor_product( irreps1 (Irreps): Irreps of the first operand. irreps2 (Irreps): Irreps of the second operand. irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. - simplify_irreps3 (bool, optional): If True, the irreps of the output are simplified. Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the channelwise tensor product. + :class:`cue.IrDictPolynomial `: The full tensor product + with ``input_irreps = (irreps1, irreps2)`` and ``output_irreps = (irreps3,)``. 
""" + poly, irreps3 = _full_tensor_product_core(irreps1, irreps2, irreps3_filter) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps2) + poly = cue.split_polynomial_by_irreps(poly, 0, irreps1) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(irreps1, irreps2), + output_irreps=(irreps3,), + ) + + +def _channelwise_tensor_product_core( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter, + simplify_irreps3: bool, +) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: G = irreps1.irrep_class if irreps3_filter is not None: @@ -215,14 +280,84 @@ def channelwise_tensor_product( d = d.permute_segments(3, [sid for _, _, sid in segments]) irreps3 = irreps3.simplify() + return cue.SegmentedPolynomial.eval_last_operand(d), irreps3 + + +def channelwise_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter=None, + simplify_irreps3: bool = False, +) -> cue.EquivariantPolynomial: + """ + subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` + + Construct a channelwise tensor product descriptor. + + This operation is computationally sparser than the fully connected tensor product. + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + simplify_irreps3 (bool, optional): If True, the irreps of the output are simplified. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the channelwise tensor product. 
+ """ + poly, irreps3 = _channelwise_tensor_product_core( + irreps1, irreps2, irreps3_filter, simplify_irreps3 + ) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps1.new_scalars(poly.inputs[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, + ) + + +def channelwise_tensor_product_ir_dict( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter=None, +) -> cue.IrDictPolynomial: + """ + subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` + + Construct a channelwise tensor product as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`channelwise_tensor_product`. + The returned polynomial is already split by irrep and ready for use with + :func:`cuequivariance_jax.ir_dict.segmented_polynomial_uniform_1d`. + The output irreps are always simplified (each irrep appears at most once). + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + :class:`cue.IrDictPolynomial `: The channelwise tensor product + with ``input_irreps = (weight_irreps, irreps1, irreps2)`` and ``output_irreps = (irreps3,)``. 
+ """ + poly, irreps3 = _channelwise_tensor_product_core( + irreps1, irreps2, irreps3_filter, simplify_irreps3=True + ) + weight_irreps = irreps1.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 2, irreps2) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps1) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps1, irreps2), + output_irreps=(irreps3,), ) @@ -255,24 +390,11 @@ def _align_two_irreps( return cue.Irreps(irreps1.irrep_class, l1), cue.Irreps(irreps2.irrep_class, l2) -def elementwise_tensor_product( +def _elementwise_tensor_product_core( irreps1: cue.Irreps, irreps2: cue.Irreps, - irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantPolynomial: - """ - subscripts: ``lhs[iu],rhs[ju],output[ku]`` - - Construct an elementwise tensor product descriptor. - - Args: - irreps1 (Irreps): Irreps of the first operand. - irreps2 (Irreps): Irreps of the second operand. - irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. - - Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the elementwise tensor product. - """ + irreps3_filter: Optional[Sequence[cue.Irrep]], +) -> tuple[cue.SegmentedPolynomial, cue.Irreps, cue.Irreps, cue.Irreps]: G = irreps1.irrep_class if irreps1.num_irreps != irreps2.num_irreps: @@ -300,29 +422,89 @@ def elementwise_tensor_product( irreps3 = cue.Irreps(G, irreps3) d = d.normalize_paths_for_operand(-1) + return ( + cue.SegmentedPolynomial.eval_last_operand(d), + irreps1_cut, + irreps2_cut, + irreps3, + ) + + +def elementwise_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.EquivariantPolynomial: + """ + subscripts: ``lhs[iu],rhs[ju],output[ku]`` + + Construct an elementwise tensor product descriptor. + + Args: + irreps1 (Irreps): Irreps of the first operand. 
+ irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the elementwise tensor product. + """ + poly, _, _, irreps3 = _elementwise_tensor_product_core( + irreps1, irreps2, irreps3_filter + ) return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps3, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, ) -def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPolynomial: +def elementwise_tensor_product_ir_dict( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.IrDictPolynomial: """ - subscripts: ``weights[uv],input[iu],output[iv]`` + subscripts: ``lhs[iu],rhs[ju],output[ku]`` - Construct the descriptor of a linear equivariant transformation. + Construct an elementwise tensor product as an :class:`~cuequivariance.IrDictPolynomial`. + + This is the ``ir_dict`` variant of :func:`elementwise_tensor_product`. + + Note: + The input irreps may be refined (split into smaller blocks) to align + multiplicities. The actual irreps used are available in the returned + ``input_irreps``. + + .. currentmodule:: cuequivariance Args: - irreps_in (Irreps): Irreps of the input. - irreps_out (Irreps): Irreps of the output. + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. + :class:`cue.IrDictPolynomial `: The elementwise tensor product + with ``input_irreps = (irreps1_aligned, irreps2_aligned)`` and ``output_irreps = (irreps3,)``. 
""" + poly, irreps1_cut, irreps2_cut, irreps3 = _elementwise_tensor_product_core( + irreps1, irreps2, irreps3_filter + ) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps2_cut) + poly = cue.split_polynomial_by_irreps(poly, 0, irreps1_cut) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps3) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(irreps1_cut, irreps2_cut), + output_irreps=(irreps3,), + ) + + +def _linear_core( + irreps_in: cue.Irreps, irreps_out: cue.Irreps +) -> cue.SegmentedPolynomial: d = cue.SegmentedTensorProduct.from_subscripts("uv_iu_iv") for mul, ir in irreps_in: d.add_segment(1, (ir.dim, mul)) @@ -336,12 +518,59 @@ def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPoly d.add_path((mul1, mul2), i1, i2, c=1.0) d = d.normalize_paths_for_operand(-1) + return cue.SegmentedPolynomial.eval_last_operand(d) + +def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPolynomial: + """ + subscripts: ``weights[uv],input[iu],output[iv]`` + + Construct the descriptor of a linear equivariant transformation. + + Args: + irreps_in (Irreps): Irreps of the input. + irreps_out (Irreps): Irreps of the output. + + Returns: + :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. + """ + poly = _linear_core(irreps_in, irreps_out) return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in.new_scalars(poly.inputs[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps_in, cue.ir_mul), ], [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], - cue.SegmentedPolynomial.eval_last_operand(d), + poly, + ) + + +def linear_ir_dict( + irreps_in: cue.Irreps, irreps_out: cue.Irreps +) -> cue.IrDictPolynomial: + """ + subscripts: ``weights[uv],input[iu],output[iv]`` + + Construct a linear equivariant transformation as an :class:`~cuequivariance.IrDictPolynomial`. 
+ + This is the ``ir_dict`` variant of :func:`linear`. + + .. currentmodule:: cuequivariance + + Args: + irreps_in (Irreps): Irreps of the input. + irreps_out (Irreps): Irreps of the output. + + Returns: + :class:`cue.IrDictPolynomial `: The linear transformation + with ``input_irreps = (weight_irreps, irreps_in)`` and ``output_irreps = (irreps_out,)``. + """ + poly = _linear_core(irreps_in, irreps_out) + weight_irreps = irreps_in.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps_in), + output_irreps=(irreps_out,), ) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index de36ea86..cb9d9738 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -20,6 +20,38 @@ from cuequivariance.etc.sympy_utils import sqrtQarray_to_sympy +def _spherical_harmonics_core( + ir_vec: cue.Irrep, ls: list[int] +) -> tuple[cue.SegmentedPolynomial, cue.Irreps]: + if len(ls) != 1: + results = [_spherical_harmonics_core(ir_vec, [ell]) for ell in ls] + poly = cue.SegmentedPolynomial.stack([r[0] for r in results], [False, True]) + irreps_out = cue.Irreps(type(ir_vec), sum([list(r[1]) for r in results], [])) + return poly, irreps_out + + [ell] = ls + ir, formula = sympy_spherical_harmonics(ir_vec, ell) + + assert ir_vec.dim == 3 + d = cue.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) + for i in range(ir.dim): + for degrees, coeff in ( + sympy.Poly(formula[i], sympy.symbols("x:3")).as_dict().items() + ): + indices = poly_degrees_to_path_indices(degrees) + d.add_path(*indices, i, c=coeff) + + d = d.symmetrize_operands(range(ell), force=True) + + poly 
= cue.SegmentedPolynomial( + [cue.SegmentedOperand([()] * 3)], + [cue.SegmentedOperand([()] * ir.dim)], + [(cue.Operation([0] * ell + [1]), d)], + ) + irreps_out = cue.Irreps(type(ir_vec), [(1, ir)]) + return poly, irreps_out + + def spherical_harmonics( ir_vec: cue.Irrep, ls: list[int], layout: cue.IrrepsLayout = cue.ir_mul ) -> cue.EquivariantPolynomial: @@ -42,33 +74,36 @@ def spherical_harmonics( │ []·a[]➜B[] ───── num_paths=3 ╰─ []·a[]·a[]➜B[] ─ num_paths=11 """ - if len(ls) != 1: - return cue.EquivariantPolynomial.stack( - [spherical_harmonics(ir_vec, [ell], layout) for ell in ls], [False, True] - ) + poly, irreps_out = _spherical_harmonics_core(ir_vec, ls) + return cue.EquivariantPolynomial( + [cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul)], + [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], + poly, + ) - [ell] = ls - ir, formula = sympy_spherical_harmonics(ir_vec, ell) - assert ir_vec.dim == 3 - d = cue.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) - for i in range(ir.dim): - for degrees, coeff in ( - sympy.Poly(formula[i], sympy.symbols("x:3")).as_dict().items() - ): - indices = poly_degrees_to_path_indices(degrees) - d.add_path(*indices, i, c=coeff) +def spherical_harmonics_ir_dict( + ir_vec: cue.Irrep, ls: list[int] +) -> cue.IrDictPolynomial: + """Polynomial descriptor for the spherical harmonics as an :class:`~cuequivariance.IrDictPolynomial`. - d = d.symmetrize_operands(range(ell), force=True) + This is the ``ir_dict`` variant of :func:`spherical_harmonics`. - return cue.EquivariantPolynomial( - [cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul)], - [cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul)], - cue.SegmentedPolynomial( - [cue.SegmentedOperand([()] * 3)], - [cue.SegmentedOperand([()] * ir.dim)], - [(cue.Operation([0] * ell + [1]), d)], - ), + Args: + ir_vec (Irrep): irrep of the input vector, for example ``cue.SO3(1)``. + ls (list of int): list of spherical harmonic degrees, for example ``[0, 1, 2]``. 
+ + Returns: + :class:`cue.IrDictPolynomial `: The spherical harmonics + with ``input_irreps = (Irreps(ir_vec),)`` and ``output_irreps = (irreps_out,)``. + """ + poly, irreps_out = _spherical_harmonics_core(ir_vec, ls) + irreps_in = cue.Irreps(ir_vec) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(irreps_in,), + output_irreps=(irreps_out,), ) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index 2aa1df92..64ba2a63 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -17,57 +17,20 @@ import cuequivariance as cue -def symmetric_contraction( - irreps_in: cue.Irreps, - irreps_out: cue.Irreps, - degrees: tuple[int, ...], -) -> cue.EquivariantPolynomial: - """Construct the descriptor for a symmetric contraction. - - The symmetric contraction is a weighted sum of the input contracted with itself degree times. - - Subscripts: ``weights[u],input[u],output[u]`` - - Args: - irreps_in (Irreps): The input irreps, the multiplicity are treated in parallel. - irreps_out (Irreps): The output irreps. - degrees (tuple[int, ...]): List of degrees for the symmetric contractions. - - Returns: - EquivariantPolynomial: The descriptor of the symmetric contraction. - The operands are the weights, the input degree times and the output. - - Example: - >>> cue.descriptors.symmetric_contraction( - ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), - ... 16 * cue.Irreps("SO3", "0 + 1"), - ... (1, 2, 3) - ... 
) - ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 - │ []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16 - │ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16 - ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16 - - Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). - """ - return symmetric_contraction_cached(irreps_in, irreps_out, tuple(degrees)) - - @cache -def symmetric_contraction_cached( +def _symmetric_contraction_core( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: tuple[int, ...], -) -> cue.EquivariantPolynomial: +) -> cue.SegmentedPolynomial: degrees = list(degrees) if len(degrees) != 1: - return cue.EquivariantPolynomial.stack( - [ - symmetric_contraction(irreps_in, irreps_out, (degree,)) - for degree in degrees - ], - [True, False, False], - ) + polys = [ + _symmetric_contraction_core(irreps_in, irreps_out, (degree,)) + for degree in degrees + ] + return cue.SegmentedPolynomial.stack(polys, [True, False, False]) + [degree] = degrees del degrees @@ -118,15 +81,84 @@ def symmetric_contraction_cached( for i in input_operands: assert d.operands[i] == input_operand + return cue.SegmentedPolynomial( + [d.operands[0], input_operand], + [d.operands[-1]], + [(cue.Operation([0] + [1] * degree + [2]), d)], + ) + + +def symmetric_contraction( + irreps_in: cue.Irreps, + irreps_out: cue.Irreps, + degrees: tuple[int, ...], +) -> cue.EquivariantPolynomial: + """Construct the descriptor for a symmetric contraction. + + The symmetric contraction is a weighted sum of the input contracted with itself degree times. + + Subscripts: ``weights[u],input[u],output[u]`` + + Args: + irreps_in (Irreps): The input irreps, the multiplicity are treated in parallel. + irreps_out (Irreps): The output irreps. + degrees (tuple[int, ...]): List of degrees for the symmetric contractions. + + Returns: + EquivariantPolynomial: The descriptor of the symmetric contraction. 
+        The operands are the weights, the input degree times and the output.
+
+    Example:
+        >>> cue.descriptors.symmetric_contraction(
+        ...     16 * cue.Irreps("SO3", "0 + 1 + 2"),
+        ...     16 * cue.Irreps("SO3", "0 + 1"),
+        ...     (1, 2, 3)
+        ... )
+        ╭ a=288x0 b=16x0+16x1+16x2 -> C=16x0+16x1
+        │  []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16
+        │  []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16
+        ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16
+
+        Where ``288x0`` are the concatenated weights for all degrees (32 for degree 1, 80 for degree 2, 176 for degree 3).
+    """
+    poly = _symmetric_contraction_core(irreps_in, irreps_out, tuple(degrees))
     return cue.EquivariantPolynomial(
         [
-            cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
-            cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps_in.new_scalars(poly.inputs[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps_in, cue.ir_mul),
         ],
-        [cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul)],
-        cue.SegmentedPolynomial(
-            [d.operands[0], input_operand],
-            [d.operands[-1]],
-            [(cue.Operation([0] + [1] * degree + [2]), d)],
-        ),
+        [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)],
+        poly,
+    )
+
+
+def symmetric_contraction_ir_dict(
+    irreps_in: cue.Irreps,
+    irreps_out: cue.Irreps,
+    degrees: tuple[int, ...],
+) -> cue.IrDictPolynomial:
+    """Construct a symmetric contraction as an :class:`~cuequivariance.IrDictPolynomial`.
+
+    This is the ``ir_dict`` variant of :func:`symmetric_contraction`.
+
+    .. currentmodule:: cuequivariance
+
+    Args:
+        irreps_in (Irreps): The input irreps, the multiplicities are treated in parallel.
+        irreps_out (Irreps): The output irreps.
+        degrees (tuple[int, ...]): List of degrees for the symmetric contractions.
+
+    Returns:
+        :class:`~cuequivariance.IrDictPolynomial`: The symmetric contraction
+        with ``input_irreps = (weight_irreps, irreps_in)`` and
+        ``output_irreps = (irreps_out,)``.
+ """ + poly = _symmetric_contraction_core(irreps_in, irreps_out, tuple(degrees)) + weight_irreps = irreps_in.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps_in), + output_irreps=(irreps_out,), ) diff --git a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py index aa655ee3..23321a83 100644 --- a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py @@ -19,6 +19,9 @@ import cuequivariance as cue from cuequivariance.etc.linalg import round_to_sqrt_rational, triu_array +from cuequivariance.group_theory.descriptors.symmetric_contractions import ( + _symmetric_contraction_core as _std_symmetric_contraction_core, +) def symmetric_contraction( @@ -43,21 +46,54 @@ def symmetric_contraction( x = cuex.randn(jax.random.key(1), e.inputs[1]) y = cuex.equivariant_polynomial(e, [w, x]) """ - return symmetric_contraction_cached(irreps_in, irreps_out, tuple(degrees)) + poly, projection = _symmetric_contraction_cached( + irreps_in, irreps_out, tuple(degrees) + ) + return cue.EquivariantPolynomial( + [ + cue.IrrepsAndLayout(irreps_in.new_scalars(poly.inputs[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in, cue.ir_mul), + ], + [cue.IrrepsAndLayout(irreps_out, cue.ir_mul)], + poly, + ), projection + + +def symmetric_contraction_ir_dict( + irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: tuple[int, ...] +) -> tuple[cue.IrDictPolynomial, np.ndarray]: + r"""``ir_dict`` variant of :func:`symmetric_contraction`. 
+ + Returns: + tuple of (:class:`~cuequivariance.IrDictPolynomial`, np.ndarray): + The polynomial (with ``input_irreps = (weight_irreps, irreps_in)`` + and ``output_irreps = (irreps_out,)``) and the projection matrix. + """ + poly, projection = _symmetric_contraction_cached( + irreps_in, irreps_out, tuple(degrees) + ) + weight_irreps = irreps_in.new_scalars(poly.inputs[0].size) + poly = cue.split_polynomial_by_irreps(poly, 1, irreps_in) + poly = cue.split_polynomial_by_irreps(poly, -1, irreps_out) + return cue.IrDictPolynomial( + polynomial=poly, + input_irreps=(weight_irreps, irreps_in), + output_irreps=(irreps_out,), + ), projection @cache -def symmetric_contraction_cached( +def _symmetric_contraction_cached( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: tuple[int, ...] -) -> tuple[cue.EquivariantPolynomial, np.ndarray]: +) -> tuple[cue.SegmentedPolynomial, np.ndarray]: assert min(degrees) > 0 # poly1 replicates the behavior of the original MACE implementation - poly1 = cue.EquivariantPolynomial.stack( + poly1 = cue.SegmentedPolynomial.stack( [ - cue.EquivariantPolynomial.stack( + cue.SegmentedPolynomial.stack( [ - _symmetric_contraction(irreps_in, irreps_out[i : i + 1], deg) + _symmetric_contraction_poly(irreps_in, irreps_out[i : i + 1], deg) for deg in reversed(degrees) ], [True, False, False], @@ -66,7 +102,7 @@ def symmetric_contraction_cached( ], [True, False, True], ) - poly2 = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, degrees) + poly2 = _std_symmetric_contraction_core(irreps_in, irreps_out, tuple(degrees)) a1, a2 = [ np.concatenate( [ @@ -75,7 +111,7 @@ def symmetric_contraction_cached( 1, None, ) - for _, d in pol.polynomial.operations + for _, d in pol.operations ], axis=1, ) @@ -120,9 +156,9 @@ def _stp_to_matrix( # This function is an adaptation of https://github.com/ACEsuit/mace/blob/bd412319b11c5f56c37cec6c4cfae74b2a49ff43/mace/modules/symmetric_contraction.py -def _symmetric_contraction( +def _symmetric_contraction_poly( 
irreps_in: cue.Irreps, irreps_out: cue.Irreps, degree: int -) -> cue.EquivariantPolynomial: +) -> cue.SegmentedPolynomial: mul = irreps_in.muls[0] assert all(mul == m for m in irreps_in.muls) assert all(mul == m for m in irreps_out.muls) @@ -157,15 +193,8 @@ def _symmetric_contraction( assert d.num_operands >= 3 [w, x], y = d.operands[:2], d.operands[-1] - return cue.EquivariantPolynomial( - [ - cue.IrrepsAndLayout(irreps_in.new_scalars(w.size), cue.ir_mul), - cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), - ], - [cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul)], - cue.SegmentedPolynomial( - [w, x], [y], [(cue.Operation([0] + [1] * degree + [2]), d)] - ), + return cue.SegmentedPolynomial( + [w, x], [y], [(cue.Operation([0] + [1] * degree + [2]), d)] ) diff --git a/cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py b/cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py new file mode 100644 index 00000000..529736a9 --- /dev/null +++ b/cuequivariance/cuequivariance/group_theory/ir_dict_polynomial.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import dataclasses +import itertools + +import cuequivariance as cue + + +def split_polynomial_by_irreps( + polynomial: cue.SegmentedPolynomial, + operand_id: int, + irreps: cue.Irreps, +) -> cue.SegmentedPolynomial: + """Split a polynomial operand according to irreps boundaries. + + Each ``(mul, ir)`` block in the irreps becomes a separate operand + in the resulting polynomial. + + Args: + polynomial: The polynomial to split. + operand_id: Index of the operand to split (negative indices supported). + irreps: Irreps describing the operand's structure. + + Returns: + A new :class:`~cuequivariance.SegmentedPolynomial` with the specified + operand split into one operand per ``(mul, ir)`` block. + """ + offsets = list( + itertools.accumulate((mul * ir.dim for mul, ir in irreps), initial=0) + ) + return polynomial.split_operand_by_size(operand_id, offsets) + + +@dataclasses.dataclass(init=False, frozen=True) +class IrDictPolynomial: + """A segmented polynomial with per-operand irreps metadata for the ``ir_dict`` workflow. + + This class pairs a :class:`~cuequivariance.SegmentedPolynomial` (already split + by irrep) with the :class:`~cuequivariance.Irreps` that describe each operand group. + + Each :class:`~cuequivariance.Irreps` in ``input_irreps`` and ``output_irreps`` + corresponds to a logical operand group (e.g. weights, node features, spherical + harmonics, output features). Within each group, every ``(mul, ir)`` block maps + to one polynomial operand. + + Contract: + - The polynomial is already split by irrep: each operand corresponds to + exactly one ``(mul, ir)`` block. + - The ``(mul, ir)`` blocks in ``input_irreps`` and ``output_irreps`` + are in the same order as the polynomial's input and output operands. + - For each ``(mul, ir)`` block, the corresponding polynomial operand + has size ``mul * ir.dim``. + + Args: + polynomial: The underlying polynomial, already split by irrep. 
+ input_irreps: One :class:`~cuequivariance.Irreps` per input group. + output_irreps: One :class:`~cuequivariance.Irreps` per output group. + """ + + polynomial: cue.SegmentedPolynomial + input_irreps: tuple[cue.Irreps, ...] + output_irreps: tuple[cue.Irreps, ...] + + def __init__( + self, + polynomial: cue.SegmentedPolynomial, + input_irreps: list[cue.Irreps] | tuple[cue.Irreps, ...], + output_irreps: list[cue.Irreps] | tuple[cue.Irreps, ...], + ): + object.__setattr__(self, "polynomial", polynomial) + object.__setattr__(self, "input_irreps", tuple(input_irreps)) + object.__setattr__(self, "output_irreps", tuple(output_irreps)) + + expected_inputs = sum(len(irreps) for irreps in self.input_irreps) + if expected_inputs != polynomial.num_inputs: + raise ValueError( + f"input_irreps describe {expected_inputs} operands, " + f"but polynomial has {polynomial.num_inputs} inputs" + ) + + expected_outputs = sum(len(irreps) for irreps in self.output_irreps) + if expected_outputs != polynomial.num_outputs: + raise ValueError( + f"output_irreps describe {expected_outputs} operands, " + f"but polynomial has {polynomial.num_outputs} outputs" + ) + + operand_idx = 0 + for irreps in self.input_irreps: + for mul, ir in irreps: + actual_size = polynomial.inputs[operand_idx].size + expected_size = mul * ir.dim + if expected_size != actual_size: + raise ValueError( + f"Input operand {operand_idx} ({mul}x{ir}): " + f"expected size {expected_size}, " + f"got {actual_size}" + ) + operand_idx += 1 + + operand_idx = 0 + for irreps in self.output_irreps: + for mul, ir in irreps: + actual_size = polynomial.outputs[operand_idx].size + expected_size = mul * ir.dim + if expected_size != actual_size: + raise ValueError( + f"Output operand {operand_idx} ({mul}x{ir}): " + f"expected size {expected_size}, " + f"got {actual_size}" + ) + operand_idx += 1 + + def __repr__(self): + labels = [] + for irreps in self.input_irreps: + for mul, ir in irreps: + labels.append(f"{mul}x{ir}" if mul > 1 else 
f"{ir}") + for irreps in self.output_irreps: + for mul, ir in irreps: + labels.append(f"{mul}x{ir}" if mul > 1 else f"{ir}") + return self.polynomial.to_string(labels) diff --git a/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py b/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py new file mode 100644 index 00000000..457bd982 --- /dev/null +++ b/cuequivariance/tests/group_theory/ir_dict_polynomial_test.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytest + +import cuequivariance as cue + +# -------------------------------------------------------------------------- +# split_polynomial_by_irreps +# -------------------------------------------------------------------------- + + +def test_split_polynomial_by_irreps_matches_split_operand_by_irrep(): + """The new standalone function should produce the same result as + EquivariantPolynomial.split_operand_by_irrep.""" + irreps_in = cue.Irreps(cue.O3, "64x0e + 32x1o") + irreps_sh = cue.Irreps(cue.O3, "0e + 1o") + irreps_out = cue.Irreps(cue.O3, "0e + 1o + 2e") + + e = cue.descriptors.channelwise_tensor_product( + irreps_in, irreps_sh, irreps_out, True + ) + old = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + new = e.polynomial + new = cue.split_polynomial_by_irreps(new, 2, irreps_sh) + new = cue.split_polynomial_by_irreps(new, 1, irreps_in) + new = cue.split_polynomial_by_irreps(new, -1, e.outputs[0].irreps) + + assert old == new + + +# -------------------------------------------------------------------------- +# IrDictPolynomial validation +# -------------------------------------------------------------------------- + + +def test_ir_dict_polynomial_rejects_wrong_operand_count(): + irreps_in = cue.Irreps(cue.O3, "4x0e + 2x1o") + irreps_out = cue.Irreps(cue.O3, "3x0e") + + result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + + with pytest.raises(ValueError, match="input_irreps describe"): + cue.IrDictPolynomial( + polynomial=result.polynomial, + input_irreps=(irreps_in,), # wrong: should include weight group + output_irreps=result.output_irreps, + ) + + +def test_ir_dict_polynomial_rejects_wrong_operand_size(): + irreps_in = cue.Irreps(cue.O3, "4x0e + 2x1o") + irreps_out = cue.Irreps(cue.O3, "3x0e") + + result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + + with pytest.raises(ValueError, match="expected size"): + cue.IrDictPolynomial( + 
polynomial=result.polynomial, + input_irreps=( + result.input_irreps[0], + cue.Irreps(cue.O3, "3x0e + 2x1o"), # wrong mul for 0e + ), + output_irreps=result.output_irreps, + ) + + +# -------------------------------------------------------------------------- +# _ir_dict descriptor variants match the old EquivariantPolynomial path +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "irreps1, irreps2, irreps3_filter", + [ + ( + cue.Irreps(cue.O3, "64x0e + 32x1o"), + cue.Irreps(cue.O3, "0e + 1o"), + cue.Irreps(cue.O3, "0e + 1o + 2e"), + ), + ( + cue.Irreps(cue.O3, "16x0e + 8x1o + 4x2e"), + cue.Irreps(cue.O3, "0e + 1o"), + None, + ), + ( + cue.Irreps(cue.SO3, "8x0 + 4x1 + 2x2"), + cue.Irreps(cue.SO3, "0 + 1"), + None, + ), + ], +) +def test_channelwise_tensor_product_ir_dict(irreps1, irreps2, irreps3_filter): + # channelwise_tensor_product_ir_dict always simplifies output irreps + e = cue.descriptors.channelwise_tensor_product( + irreps1, irreps2, irreps3_filter, simplify_irreps3=True + ) + old_poly = ( + e.split_operand_by_irrep(2) + .split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps1, irreps2, irreps3_filter + ) + + assert result.polynomial == old_poly + assert result.output_irreps[0] == e.outputs[0].irreps + assert result.input_irreps[1] == irreps1 + assert result.input_irreps[2] == irreps2 + + +@pytest.mark.parametrize( + "irreps1, irreps2, irreps3", + [ + ( + cue.Irreps(cue.O3, "4x0e + 2x1o"), + cue.Irreps(cue.O3, "0e + 1o"), + cue.Irreps(cue.O3, "4x0e + 2x1o"), + ), + ( + cue.Irreps(cue.SO3, "8x0 + 4x1"), + cue.Irreps(cue.SO3, "0 + 1 + 2"), + cue.Irreps(cue.SO3, "8x0 + 4x1 + 2x2"), + ), + ], +) +def test_fully_connected_tensor_product_ir_dict(irreps1, irreps2, irreps3): + e = cue.descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3) + old_poly = ( + e.split_operand_by_irrep(2) + 
.split_operand_by_irrep(1) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.fully_connected_tensor_product_ir_dict( + irreps1, irreps2, irreps3 + ) + + assert result.polynomial == old_poly + assert result.output_irreps[0] == irreps3 + + +@pytest.mark.parametrize( + "irreps_in, irreps_out", + [ + ( + cue.Irreps(cue.O3, "4x0e + 2x1o"), + cue.Irreps(cue.O3, "3x0e + 5x1o"), + ), + ( + cue.Irreps(cue.SO3, "16x0 + 8x1 + 4x2"), + cue.Irreps(cue.SO3, "8x0 + 4x1"), + ), + ], +) +def test_linear_ir_dict(irreps_in, irreps_out): + e = cue.descriptors.linear(irreps_in, irreps_out) + old_poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + + result = cue.descriptors.linear_ir_dict(irreps_in, irreps_out) + + assert result.polynomial == old_poly + assert result.output_irreps[0] == irreps_out + assert result.input_irreps[1] == irreps_in + + +def test_full_tensor_product_ir_dict(): + irreps1 = cue.Irreps(cue.O3, "2x0e + 1x1o") + irreps2 = cue.Irreps(cue.O3, "0e + 1o") + + e = cue.descriptors.full_tensor_product(irreps1, irreps2) + old_poly = ( + e.split_operand_by_irrep(1) + .split_operand_by_irrep(0) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.full_tensor_product_ir_dict(irreps1, irreps2) + + assert result.polynomial == old_poly + assert result.input_irreps[0] == irreps1 + assert result.input_irreps[1] == irreps2 + + +def test_elementwise_tensor_product_ir_dict(): + irreps1 = cue.Irreps(cue.O3, "4x0e + 4x1o") + irreps2 = cue.Irreps(cue.O3, "4x0e + 4x1o") + + e = cue.descriptors.elementwise_tensor_product(irreps1, irreps2) + old_poly = ( + e.split_operand_by_irrep(1) + .split_operand_by_irrep(0) + .split_operand_by_irrep(-1) + .polynomial + ) + + result = cue.descriptors.elementwise_tensor_product_ir_dict(irreps1, irreps2) + + assert result.polynomial == old_poly + + +def test_symmetric_contraction_ir_dict(): + irreps_in = 16 * cue.Irreps("SO3", "0 + 1 + 2") + irreps_out = 16 * cue.Irreps("SO3", "0 + 
1") + + e = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, (1, 2, 3)) + old_poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + + result = cue.descriptors.symmetric_contraction_ir_dict( + irreps_in, irreps_out, (1, 2, 3) + ) + + assert result.polynomial == old_poly + (output_irreps,) = result.output_irreps + assert output_irreps == irreps_out + + +def test_mace_symmetric_contraction_ir_dict(): + from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( + symmetric_contraction as mace_sc, + ) + from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( + symmetric_contraction_ir_dict as mace_sc_ir_dict, + ) + + irreps_in = 4 * cue.Irreps("SO3", "0 + 1 + 2") + irreps_out = 4 * cue.Irreps("SO3", "0 + 1") + + e, projection_old = mace_sc(irreps_in, irreps_out, [1, 2, 3]) + old_poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + + result, projection_new = mace_sc_ir_dict(irreps_in, irreps_out, [1, 2, 3]) + + assert result.polynomial == old_poly + np.testing.assert_array_equal(projection_old, projection_new) + (output_irreps,) = result.output_irreps + assert output_irreps == irreps_out + + +@pytest.mark.parametrize("max_degree", [1, 2, 3, 4]) +def test_spherical_harmonics_ir_dict(max_degree): + ir_vec = cue.O3(1, -1) + ls = list(range(max_degree + 1)) + + e = cue.descriptors.spherical_harmonics(ir_vec, ls) + old_poly = e.split_operand_by_irrep(-1).polynomial + + result = cue.descriptors.spherical_harmonics_ir_dict(ir_vec, ls) + + assert result.polynomial == old_poly + (output_irreps,) = result.output_irreps + assert output_irreps == e.outputs[0].irreps + + # Numpy evaluation: verify output matches unsplit + vec = np.array([0.3, -0.5, 0.8]) + [out_flat] = e.polynomial(vec) + + out_parts = result.polynomial(vec) + out_concat = np.concatenate(out_parts) + np.testing.assert_allclose(out_flat, out_concat, atol=1e-12) + + +# 
-------------------------------------------------------------------------- +# Numpy evaluation: ir_dict variant produces same results as original +# -------------------------------------------------------------------------- + + +def test_channelwise_numpy_evaluation(): + """Evaluate both paths with numpy and compare outputs.""" + irreps1 = cue.Irreps(cue.O3, "4x0e + 2x1o") + irreps_sh = cue.Irreps(cue.O3, "0e + 1o") + + e = cue.descriptors.channelwise_tensor_product( + irreps1, irreps_sh, simplify_irreps3=True + ) + result = cue.descriptors.channelwise_tensor_product_ir_dict(irreps1, irreps_sh) + + # Use the same flat data for both + # Unsplit: [weights, input1_flat, input2_flat] + # Split: [weights, input1_ir0, input1_ir1, input2_ir0, input2_ir1] + w = np.random.randn(result.polynomial.inputs[0].size) + x1 = np.random.randn(e.polynomial.inputs[1].size) + x2 = np.random.randn(e.polynomial.inputs[2].size) + + [out_orig] = e.polynomial(w, x1, x2) + + # Split x1 and x2 by irrep boundaries + x1_parts = [] + offset = 0 + for mul, ir in irreps1: + size = mul * ir.dim + x1_parts.append(x1[offset : offset + size]) + offset += size + + x2_parts = [] + offset = 0 + for mul, ir in irreps_sh: + size = mul * ir.dim + x2_parts.append(x2[offset : offset + size]) + offset += size + + out_split = result.polynomial(w, *x1_parts, *x2_parts) + out_split_concat = np.concatenate(out_split) + + np.testing.assert_allclose(out_orig, out_split_concat, atol=1e-12) diff --git a/cuequivariance_jax/cuequivariance_jax/SKILL.md b/cuequivariance_jax/cuequivariance_jax/SKILL.md index d61d2d73..1da88a63 100644 --- a/cuequivariance_jax/cuequivariance_jax/SKILL.md +++ b/cuequivariance_jax/cuequivariance_jax/SKILL.md @@ -1,6 +1,6 @@ --- name: cuequivariance-jax -description: Execute equivariant polynomials in JAX using segmented_polynomial (naive/uniform_1d), equivariant_polynomial with RepArray, ir_dict with dict[Irrep, Array], and Flax NNX layers (IrrepsLinear, SphericalHarmonics). 
Use when writing JAX code with cuequivariance. +description: Execute equivariant polynomials in JAX using segmented_polynomial (naive/uniform_1d), the ir_dict workflow with IrDictPolynomial and dict[Irrep, Array], and Flax NNX layers (IrrepsLinear, SphericalHarmonics, IrrepsIndexedLinear). Use when writing JAX code with cuequivariance. --- # cuequivariance_jax: Executing Equivariant Polynomials in JAX @@ -9,11 +9,11 @@ description: Execute equivariant polynomials in JAX using segmented_polynomial ( `cuequivariance_jax` (imported as `cuex`) executes `cuequivariance` polynomials on GPU via JAX. It provides: -1. **Core primitive**: `cuex.segmented_polynomial()` -- JAX primitive with full AD/vmap/JIT support +1. **Core primitive**: `cuex.segmented_polynomial()` — JAX primitive with full AD/vmap/JIT support 2. **Two data representations** (both built on `segmented_polynomial`): - - `cuex.equivariant_polynomial()` + `RepArray` -- the original interface, a single contiguous array with representation metadata - - `cuex.ir_dict` module -- `dict[Irrep, Array]` interface, conceptually simpler, works naturally with `jax.tree` operations -3. **NNX layers**: `cuex.nnx` module -- Flax NNX `Module` wrappers using `dict[Irrep, Array]` + - `cuex.equivariant_polynomial()` + `RepArray` — the original interface, a single contiguous array with representation metadata + - `cuex.ir_dict` module — `dict[Irrep, Array]` interface, uses `IrDictPolynomial` descriptors, works naturally with `jax.tree` +3. **NNX layers**: `cuex.nnx` module — Flax NNX `Module` wrappers using `dict[Irrep, Array]` ## Execution methods @@ -65,20 +65,8 @@ y = jax.random.normal(key, (batch, poly.inputs[2].size)) # batched input 2 Inputs can have any number of batch axes (everything before the last axis). Standard NumPy broadcasting applies: each batch axis is either size-1 or a common size. 
Inputs with fewer batch dimensions are implicitly prepended with size-1 axes: ```python -# 2 batch axes with size-1 broadcasting -w = jnp.ones((1, 10, poly.inputs[0].size)) # shared across axis 0 -x = jnp.ones((5, 10, poly.inputs[1].size)) # 5 along axis 0 -y = jnp.ones((5, 1, poly.inputs[2].size)) # shared across axis 1 - -[out] = cuex.segmented_polynomial( - poly, [w, x, y], - [jax.ShapeDtypeStruct((5, 10, poly.outputs[0].size), jnp.float32)], - method="uniform_1d", -) -# out.shape == (5, 10, ...) - # Fewer batch dims: weights with no batch axis broadcast across all -w = jnp.ones((poly.inputs[0].size,)) # 0 batch axes -> prepended as (1, 1, ...) +w = jnp.ones((poly.inputs[0].size,)) # 0 batch axes -> broadcasts x = jnp.ones((5, 10, poly.inputs[1].size)) y = jnp.ones((5, 10, poly.inputs[2].size)) @@ -91,13 +79,9 @@ y = jnp.ones((5, 10, poly.inputs[2].size)) ### Indexing (gather/scatter) -Index arrays provide gather (for inputs) and scatter (for outputs). One index per operand (inputs + outputs), `None` means no indexing. Index arrays decouple input/output batch shapes -- the output shape is determined by the index ranges, not by the input shapes: +Index arrays provide gather (for inputs) and scatter (for outputs). One index per operand (inputs + outputs), `None` means no indexing: ```python -a = jnp.ones((1, 50, poly.inputs[0].size)) -b = jnp.ones((10, 50, poly.inputs[1].size)) -c = jnp.ones((100, 1, poly.inputs[2].size)) - i = jax.random.randint(key, (100, 50), 0, 10) # gather b along axis 0 j1 = jax.random.randint(key, (100, 50), 0, 11) # scatter output axis 0 j2 = jax.random.randint(key, (100, 1), 0, 12) # scatter output axis 1 @@ -108,12 +92,11 @@ j2 = jax.random.randint(key, (100, 1), 0, 12) # scatter output axis 1 indices=[None, np.s_[i, :], None, np.s_[j1, j2]], method="uniform_1d", ) -# out.shape == (11, 12, ...) 
-- determined by index ranges, not input shapes ``` ### Gradients -Fully differentiable -- supports `jax.grad`, `jax.jacobian`, `jax.jvp`, `jax.vmap`: +Fully differentiable — supports `jax.grad`, `jax.jacobian`, `jax.jvp`, `jax.vmap`: ```python def loss(w, x, y): @@ -127,55 +110,28 @@ def loss(w, x, y): grad_w = jax.grad(loss, 0)(w, x, y) ``` -## RepArray interface: equivariant_polynomial - -The original interface. Wraps `segmented_polynomial` with `RepArray` -- a single contiguous array with representation metadata: - -```python -e = cue.descriptors.fully_connected_tensor_product( - 4 * cue.Irreps("SO3", "0 + 1"), - cue.Irreps("SO3", "0 + 1"), - 4 * cue.Irreps("SO3", "0 + 1"), -) - -inputs = [ - cuex.randn(jax.random.key(i), rep, (batch,), jnp.float32) - for i, rep in enumerate(e.inputs) -] - -# Returns a RepArray with representation metadata -out = cuex.equivariant_polynomial(e, inputs, method="naive") -out.array # the raw jax.Array -out.reps # dict mapping axes to Rep objects -``` - ## ir_dict interface -An alternative to `RepArray`. Uses `dict[Irrep, Array]` where each value has shape `(..., multiplicity, irrep_dim)`. Conceptually simpler: works naturally with `jax.tree` operations and is the standard representation for NNX layers. +Uses `dict[Irrep, Array]` where each value has shape `(..., multiplicity, irrep_dim)`. This is the standard representation for NNX layers and works naturally with `jax.tree` operations. -### Preparing a polynomial for ir_dict +### Getting an ir_dict-ready polynomial -Descriptors produce `EquivariantPolynomial` with dense operands. 
To use `ir_dict`, split operands by irrep: +Use `_ir_dict` descriptor variants, which return `IrDictPolynomial` with the polynomial already split by irrep: ```python -e = cue.descriptors.channelwise_tensor_product( +desc = cue.descriptors.channelwise_tensor_product_ir_dict( 32 * cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), - simplify_irreps3=True, ) -# Split irreps-typed operands into per-irrep pieces -# Order matters: split inner operands first to preserve operand indices -poly = ( - e.split_operand_by_irrep(2) # split input 2 - .split_operand_by_irrep(1) # split input 1 - .split_operand_by_irrep(-1) # split output - .polynomial -) -# After split: each operand has a single irrep type, mapping naturally to dict[Irrep, Array] +poly = desc.polynomial # SegmentedPolynomial, already split by irrep +weight_irreps, irreps1, irreps2 = desc.input_irreps +(irreps_out,) = desc.output_irreps # tuple unpacking to get the single output group ``` +Each polynomial operand corresponds to exactly one `(mul, ir)` block. The `input_irreps` and `output_irreps` tuples describe how operands group into logical operand groups (weights, node features, spherical harmonics, output). 
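The offset arithmetic behind this contract can be sketched with plain numpy, outside the library. The `(mul, dim)` pairs below are a hypothetical stand-in for a real `Irreps`; each block contributes one buffer of size `mul * ir.dim`, at offsets accumulated over the group (the same `itertools.accumulate` pattern that `split_polynomial_by_irreps` uses):

```python
import itertools

import numpy as np

# Hypothetical irreps "4x0e + 2x1o" as (multiplicity, irrep dimension) pairs.
irreps = [(4, 1), (2, 3)]

# Offsets at the (mul, ir) block boundaries.
offsets = list(itertools.accumulate((mul * dim for mul, dim in irreps), initial=0))
# offsets == [0, 4, 10]

flat = np.arange(offsets[-1], dtype=np.float64)  # one dense operand buffer

# One separate buffer per (mul, ir) block -- the ir_dict layout.
parts = [flat[a:b] for a, b in zip(offsets[:-1], offsets[1:])]
assert [p.size for p in parts] == [mul * dim for mul, dim in irreps]
```

Concatenating `parts` recovers the dense operand, which is exactly the round-trip the numpy tests above rely on.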
+ ### Executing with segmented_polynomial_uniform_1d ```python @@ -194,7 +150,7 @@ node_feats = { } x1 = jax.tree.map(lambda v: rearrange(v, "n m i -> n i m"), node_feats) -# Spherical harmonics: (edges, ir.dim) -- no multiplicity dimension +# Spherical harmonics: (edges, ir.dim) — no multiplicity dimension sph = { cue.SO3(0): jnp.ones((num_edges, 1)), cue.SO3(1): jnp.ones((num_edges, 3)), @@ -203,7 +159,6 @@ sph = { # Build output template senders = jax.random.randint(key, (num_edges,), 0, num_nodes) receivers = jax.random.randint(key, (num_edges,), 0, num_nodes) -irreps_out = e.outputs[0].irreps out_template = { ir: jax.ShapeDtypeStruct( (num_nodes, desc.num_segments) + desc.segment_shape, w.dtype @@ -242,6 +197,28 @@ z = cuex.ir_dict.irreps_zeros_like(x) template = cuex.ir_dict.mul_ir_dict(irreps, jax.ShapeDtypeStruct(shape, dtype)) ``` +## RepArray interface: equivariant_polynomial + +The original interface. Wraps `segmented_polynomial` with `RepArray` — a single contiguous array with representation metadata: + +```python +e = cue.descriptors.fully_connected_tensor_product( + 4 * cue.Irreps("SO3", "0 + 1"), + cue.Irreps("SO3", "0 + 1"), + 4 * cue.Irreps("SO3", "0 + 1"), +) + +inputs = [ + cuex.randn(jax.random.key(i), rep, (batch,), jnp.float32) + for i, rep in enumerate(e.inputs) +] + +# Returns a RepArray with representation metadata +out = cuex.equivariant_polynomial(e, inputs, method="naive") +out.array # the raw jax.Array +out.reps # dict mapping axes to Rep objects +``` + ## NNX layers ### IrrepsLinear @@ -273,6 +250,8 @@ Implementation uses `jnp.einsum("uv,...ui->...vi", w, x[ir])` per irrep with `1/ ### SphericalHarmonics +Uses `spherical_harmonics_ir_dict` internally for the `dict[Irrep, Array]` output: + ```python sh = cuex.nnx.SphericalHarmonics(max_degree=3, eps=0.0) @@ -338,43 +317,25 @@ For `equivariant_polynomial()` (RepArray interface): ```python e = cue.descriptors.channelwise_tensor_product(...) 
e = e.squeeze_modes().flatten_coefficient_modes() -# If still >1 mode: e = e.flatten_modes(["u", "w"]) out = cuex.equivariant_polynomial(e, inputs, method="uniform_1d") ``` -For `ir_dict` (dict[Irrep, Array] interface): +For `ir_dict` (dict[Irrep, Array] interface), use `_ir_dict` descriptors directly: ```python -e = cue.descriptors.channelwise_tensor_product(..., simplify_irreps3=True) -poly = ( - e.split_operand_by_irrep(2) # split input 2 - .split_operand_by_irrep(1) # split input 1 - .split_operand_by_irrep(-1) # split output - .polynomial +desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) +poly = desc.polynomial # Each operand has a single irrep type -> maps naturally to dict[Irrep, Array] ``` -### Why split_operand_by_irrep matters +### Why splitting by irrep matters -Without splitting, a dense operand like `32x0+32x1` requires all irreps packed into a single contiguous buffer. After `split_operand_by_irrep`, each irrep gets its own separate buffer passed to the CUDA kernel via FFI. The buffers no longer need to be contiguous with each other. +Without splitting, a dense operand like `32x0+32x1` requires all irreps packed into a single contiguous buffer. After splitting, each irrep gets its own separate buffer passed to the CUDA kernel via FFI. The buffers no longer need to be contiguous with each other. This is especially useful when the polynomial is preceded or followed by per-irrep linear layers (like `IrrepsLinear`). With split operands, no transpose or copy is needed between the linear layers and the polynomial — the `dict[Irrep, Array]` flows directly through the pipeline. 
-## RepArray - -Representation-aware JAX array: - -```python -rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) -x = cuex.RepArray(rep, jnp.ones((batch, rep.dim))) -x = cuex.randn(jax.random.key(0), rep, (batch,), jnp.float32) - -x.array # raw jax.Array -x.reps # {axis: Rep} -x.irreps # Irreps (if last axis is IrrepsAndLayout) -``` - ## Complete GNN message-passing example This pattern is used in NequIP, MACE, and similar equivariant GNN models: @@ -382,20 +343,13 @@ This pattern is used in NequIP, MACE, and similar equivariant GNN models: ```python class MessagePassing(nnx.Module): def __init__(self, irreps_in, irreps_sh, irreps_out, epsilon, *, name, dtype, rngs): - e = ( - cue.descriptors.channelwise_tensor_product( - irreps_in, irreps_sh, irreps_out, True - ) - * epsilon - ) - self.weight_numel = e.inputs[0].dim - self.irreps_out = e.outputs[0].irreps - self.poly = ( - e.split_operand_by_irrep(2) - .split_operand_by_irrep(1) - .split_operand_by_irrep(-1) - .polynomial + self.name = name + desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) + (self.irreps_out,) = desc.output_irreps + self.poly = desc.polynomial * epsilon + self.weight_numel = self.poly.inputs[0].size def __call__(self, weights, node_feats, sph, senders, receivers, num_nodes): # weights: (num_edges, weight_numel) @@ -425,6 +379,20 @@ class MessagePassing(nnx.Module): } ``` +## RepArray + +Representation-aware JAX array: + +```python +rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "4x0 + 2x1"), cue.ir_mul) +x = cuex.RepArray(rep, jnp.ones((batch, rep.dim))) +x = cuex.randn(jax.random.key(0), rep, (batch,), jnp.float32) + +x.array # raw jax.Array +x.reps # {axis: Rep} +x.irreps # Irreps (if last axis is IrrepsAndLayout) +``` + ## Key file locations | Component | Path | diff --git a/cuequivariance_jax/cuequivariance_jax/nnx.py b/cuequivariance_jax/cuequivariance_jax/nnx.py index 10435923..47169617 100644 --- 
a/cuequivariance_jax/cuequivariance_jax/nnx.py +++ b/cuequivariance_jax/cuequivariance_jax/nnx.py @@ -28,10 +28,8 @@ from . import ir_dict from .activation import normalize_function -from .rep_array.rep_array_ import RepArray from .segmented_polynomials.segmented_polynomial import segmented_polynomial from .segmented_polynomials.utils import Repeats -from .spherical_harmonics import spherical_harmonics try: from flax import nnx @@ -143,33 +141,31 @@ def __call__(self, x: dict[Irrep, Array]) -> dict[Irrep, Array]: class SphericalHarmonics(nnx.Module): def __init__(self, max_degree: int, eps: float = 0.0): self.eps = eps - self.max_degree = max_degree - self.irreps_in = cue.Irreps(cue.O3, "1o") - self.irreps_out = cue.Irreps( - cue.O3, [(1, cue.O3(L, (-1) ** L)) for L in range(max_degree + 1)] + desc = cue.descriptors.spherical_harmonics_ir_dict( + cue.O3(1, -1), list(range(max_degree + 1)) ) + self.poly = desc.polynomial + (self.irreps_out,) = desc.output_irreps def __call__(self, x: Array) -> dict[Irrep, Array]: assert x.shape[-1] == 3 shape = x.shape[:-1] - x = RepArray(self.irreps_in, x, cue.ir_mul) - x = jax.tree.map( - lambda v: v / _safe_norm(v, self.eps, axis=-1, keepdim=True), x + x = x / _safe_norm(x, self.eps, axis=-1, keepdim=True) + outputs = segmented_polynomial( + self.poly, + [x], + [ + jax.ShapeDtypeStruct(shape + (out.size,), x.dtype) + for out in self.poly.outputs + ], + method="naive", + name="spherical_harmonics", ) - y = spherical_harmonics(range(self.max_degree + 1), x, normalize=False) - - y = { - ir: rearrange(v, "... i m -> ... 
m i") - for (_, ir), v in zip(y.irreps, y.segments) - } - actual = jax.tree.map(lambda x: x.shape, y) - expected = { - cue.O3(L, (-1) ** L): shape + (1, 2 * L + 1) - for L in range(self.max_degree + 1) + return { + ir: y.reshape(shape + (1, ir.dim)) + for (_, ir), y in zip(self.irreps_out, outputs) } - assert actual == expected, f"y: {actual}, expected: {expected}" - return y class IrrepsNormalize(nnx.Module): @@ -245,12 +241,14 @@ def __init__( self.irreps_in = irreps_in self.irreps_out = irreps_out self.num_indices = num_indices - self.scale = scale / jnp.sqrt(num_indices) self.name = name - self.e = cue.descriptors.linear(irreps_in, irreps_out) * self.scale + scale = scale / jnp.sqrt(num_indices) + self.poly = cue.descriptors.linear(irreps_in, irreps_out).polynomial * scale self.w = nnx.Param( - jax.random.normal(rngs.params(), (num_indices, self.e.inputs[0].dim), dtype) + jax.random.normal( + rngs.params(), (num_indices, self.poly.inputs[0].size), dtype + ) ) def __call__( @@ -261,13 +259,17 @@ def __call__( # Convert dict (batch, mul, ir.dim) -> ir_mul flat order x_ir_mul = jax.tree.map(lambda v: rearrange(v, "... m i -> ... 
i m"), x) x_flat = ir_dict.dict_to_flat(self.irreps_in, x_ir_mul) + x_flat = x_flat.astype(self.w[...].dtype) num_elements = x_flat.shape[0] - p = self.e.polynomial [y_flat] = segmented_polynomial( - p, + self.poly, [self.w[...], x_flat], - [jax.ShapeDtypeStruct((num_elements, p.outputs[0].size), x_flat.dtype)], + [ + jax.ShapeDtypeStruct( + (num_elements, self.poly.outputs[0].size), x_flat.dtype + ) + ], [Repeats(num_index_counts), None, None], method="indexed_linear", name=self.name, diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py index e1439b4d..72aa5572 100644 --- a/cuequivariance_jax/examples/mace_nnx.py +++ b/cuequivariance_jax/examples/mace_nnx.py @@ -26,8 +26,8 @@ import jax import jax.numpy as jnp import numpy as np -from cuequivariance.group_theory.experimental.mace import ( - symmetric_contraction as mace_symmetric_contraction, +from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( + symmetric_contraction_ir_dict as mace_symmetric_contraction_ir_dict, ) from cuequivariance_jax.nnx import ( MLP, @@ -89,21 +89,12 @@ def __init__( rngs: nnx.Rngs, ): self.name = name - e = ( - cue.descriptors.channelwise_tensor_product( - irreps_in, irreps_sh, irreps_out, True - ) - * epsilon - ) - self.weight_numel = e.inputs[0].dim - self.irreps_out = e.outputs[0].irreps - - self.poly = ( - e.split_operand_by_irrep(2) - .split_operand_by_irrep(1) - .split_operand_by_irrep(-1) - .polynomial + desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) + (self.irreps_out,) = desc.output_irreps + self.poly = desc.polynomial * epsilon + self.weight_numel = self.poly.inputs[0].size def __call__( self, @@ -156,12 +147,12 @@ def __init__( self.irreps_out = irreps_out self.name = name - e, projection = mace_symmetric_contraction( + desc, projection = mace_symmetric_contraction_ir_dict( irreps_in, irreps_out, range(1, correlation + 1) ) self.projection = 
jnp.array(projection, dtype=dtype) - self.poly = e.split_operand_by_irrep(1).split_operand_by_irrep(-1).polynomial + self.poly = desc.polynomial self.w = nnx.Param( jax.random.normal( rngs.params(), diff --git a/cuequivariance_jax/examples/nequip_nnx.py b/cuequivariance_jax/examples/nequip_nnx.py index f1016da3..cef7cb00 100644 --- a/cuequivariance_jax/examples/nequip_nnx.py +++ b/cuequivariance_jax/examples/nequip_nnx.py @@ -222,20 +222,12 @@ def __init__( rngs: nnx.Rngs, ): self.name = name - e = ( - cue.descriptors.channelwise_tensor_product( - irreps_in, irreps_sh, irreps_out, True - ) - * epsilon - ) - self.weight_numel = e.inputs[0].dim - self.irreps_out = e.outputs[0].irreps - self.poly = ( - e.split_operand_by_irrep(2) - .split_operand_by_irrep(1) - .split_operand_by_irrep(-1) - .polynomial + desc = cue.descriptors.channelwise_tensor_product_ir_dict( + irreps_in, irreps_sh, irreps_out ) + (self.irreps_out,) = desc.output_irreps + self.poly = desc.polynomial * epsilon + self.weight_numel = self.poly.inputs[0].size def __call__( self, diff --git a/cuequivariance_torch/cuequivariance_torch/SKILL.md b/cuequivariance_torch/cuequivariance_torch/SKILL.md index 002765aa..e2a92e41 100644 --- a/cuequivariance_torch/cuequivariance_torch/SKILL.md +++ b/cuequivariance_torch/cuequivariance_torch/SKILL.md @@ -9,7 +9,7 @@ description: Execute equivariant tensor products in PyTorch using SegmentedPolyn `cuequivariance_torch` (imported as `cuet`) executes `cuequivariance` polynomials on GPU via PyTorch. It provides: -1. **Core primitive**: `cuet.SegmentedPolynomial` -- `torch.nn.Module` with multiple CUDA backends +1. **Core primitive**: `cuet.SegmentedPolynomial` — `torch.nn.Module` with multiple CUDA backends 2. **High-level operations** (`torch.nn.Module`): `ChannelWiseTensorProduct`, `FullyConnectedTensorProduct`, `Linear`, `SymmetricContraction`, `SphericalHarmonics`, `Rotation`, `Inversion` 3. 
**Layers**: `cuet.layers.BatchNorm`, `cuet.layers.FullyConnectedTensorProductConv` (message passing) 4. **Utilities**: `triangle_attention`, `triangle_multiplicative_update`, `attention_pair_bias` (AlphaFold2-style) @@ -96,8 +96,8 @@ All operations are `torch.nn.Module` subclasses. They wrap `SegmentedPolynomial` `IrrepsLayout` controls memory order within each `(mul, ir)` block: -- `cue.mul_ir`: data ordered as `(mul, ir.dim)` -- **default, compatible with e3nn** -- `cue.ir_mul`: data ordered as `(ir.dim, mul)` -- **used internally by descriptors** +- `cue.mul_ir`: data ordered as `(mul, ir.dim)` — **default, compatible with e3nn** +- `cue.ir_mul`: data ordered as `(ir.dim, mul)` — **used internally by descriptors** Operations accept `layout` (applies to all), or per-operand `layout_in1`, `layout_in2`, `layout_out`. @@ -150,7 +150,7 @@ tp = cuet.FullyConnectedTensorProduct( cue.Irreps("O3", "0e + 1o"), # irreps_in2 cue.Irreps("O3", "4x0e + 4x1o"), # irreps_out layout=cue.mul_ir, - internal_weights=True, # store weights as parameters + internal_weights=True, device="cuda", ) @@ -195,7 +195,7 @@ MACE-style symmetric contraction with element-indexed weights. sc = cuet.SymmetricContraction( cue.Irreps("O3", "32x0e + 32x1o"), # irreps_in (uniform mul required) cue.Irreps("O3", "32x0e"), # irreps_out (uniform mul required) - contraction_degree=3, # polynomial degree + contraction_degree=3, num_elements=95, # number of chemical elements layout=cue.ir_mul, dtype=torch.float32, @@ -213,8 +213,8 @@ Default method: `"uniform_1d"` if segments are uniform, else `"naive"`. 
```python sh = cuet.SphericalHarmonics( - ls=[0, 1, 2, 3], # degrees - normalize=True, # normalize input vectors + ls=[0, 1, 2, 3], + normalize=True, device="cuda", ) @@ -283,7 +283,7 @@ conv = cuet.layers.FullyConnectedTensorProductConv( in_irreps=cue.Irreps("O3", "4x0e + 4x1o"), sh_irreps=cue.Irreps("O3", "0e + 1o"), out_irreps=cue.Irreps("O3", "4x0e + 4x1o"), - mlp_channels=[16, 32, 32], # MLP for path weights + mlp_channels=[16, 32, 32], mlp_activation=torch.nn.ReLU(), batch_norm=True, layout=cue.ir_mul,