From 7671721807e0cf8ff523d88e83ed60b9cf5700b0 Mon Sep 17 00:00:00 2001 From: Franco Pellegrini <16089353+phiandark@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:34:22 -0700 Subject: [PATCH 1/3] Fix checkpoint portability: GPU-saved models now load correctly on CPU Add __reduce__ to SegmentedPolynomialFromUniform1dJit, SegmentedPolynomialFusedTP, SegmentedPolynomialIndexedLinear, and SegmentedPolynomial so that unpickling re-delegates to SegmentedPolynomial.__init__, which selects the appropriate backend for the loading machine. Also wrap backend construction in try/except ImportError to handle the case where cuequivariance_ops_torch is importable but specific extensions (e.g. uniform_1d) are not available. Remove the pre-built fallback submodule for fused_tp in favor of lazy construction from the stored polynomial when CPU fallback is needed. Made-with: Cursor --- .../primitives/segmented_polynomial.py | 83 +++++++++++++++---- .../segmented_polynomial_fused_tp.py | 18 ++++ .../segmented_polynomial_indexed_linear.py | 19 +++++ .../segmented_polynomial_uniform_1d.py | 21 +++++ 4 files changed, 125 insertions(+), 16 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py index 1b09bb7d..e3c7e576 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py @@ -148,6 +148,9 @@ def __init__( ): super().__init__() + self._polynomial_orig = polynomial + self._math_dtype_orig = math_dtype + self._output_dtype_map_orig = output_dtype_map self.num_inputs = polynomial.num_inputs self.num_outputs = polynomial.num_outputs self.method = method @@ -181,29 +184,67 @@ def __init__( ) if method == "uniform_1d": - self.m = SegmentedPolynomialFromUniform1dJit( - polynomial, math_dtype, output_dtype_map, name, self.options - ) - self.fallback = self.m + try: + self.m = SegmentedPolynomialFromUniform1dJit( + polynomial, math_dtype, output_dtype_map, name, self.options + ) + except ImportError: + method = "naive" + warnings.warn( + "uniform_1d backend is not available. " + "Falling back to naive implementation." + ) + self.m = SegmentedPolynomialNaive( + polynomial, math_dtype, output_dtype_map, name + ) elif method == "naive": self.m = SegmentedPolynomialNaive( polynomial, math_dtype, output_dtype_map, name ) - self.fallback = self.m elif method == "fused_tp": - self.m = SegmentedPolynomialFusedTP( - polynomial, math_dtype, output_dtype_map, name - ) - self.fallback = SegmentedPolynomialNaive( - polynomial, math_dtype, output_dtype_map, name - ) + try: + self.m = SegmentedPolynomialFusedTP( + polynomial, math_dtype, output_dtype_map, name + ) + except ImportError: + method = "naive" + warnings.warn( + "fused_tp backend is not available. " + "Falling back to naive implementation." + ) + self.m = SegmentedPolynomialNaive( + polynomial, math_dtype, output_dtype_map, name + ) elif method == "indexed_linear": - self.m = SegmentedPolynomialIndexedLinear( - polynomial, math_dtype, output_dtype_map, name - ) - self.fallback = self.m + try: + self.m = SegmentedPolynomialIndexedLinear( + polynomial, math_dtype, output_dtype_map, name + ) + except ImportError: + method = "naive" + warnings.warn( + "indexed_linear backend is not available. " + "Falling back to naive implementation." + ) + self.m = SegmentedPolynomialNaive( + polynomial, math_dtype, output_dtype_map, name + ) else: raise ValueError(f"Invalid method: {method}") + self.method = method + + def __reduce__(self): + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + self.method, + self._math_dtype_orig, + self._output_dtype_map_orig, + "segmented_polynomial", + self.options, + ), + ) def __repr__(self): return self.repr + f"\n{super().__repr__()}" @@ -304,7 +345,17 @@ def forward( warnings.warn( "Fused TP is not supported on CPU. Falling back to naive implementation." ) - return self.fallback( + if not hasattr(self, "_cpu_fallback"): + object.__setattr__( + self, + "_cpu_fallback", + SegmentedPolynomialNaive( + self._polynomial_orig, + self._math_dtype_orig, + self._output_dtype_map_orig, + ), + ) + return self._cpu_fallback( inputs, input_indices, output_shapes, output_indices ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py index 90f34162..8d1d072e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py @@ -81,6 +81,8 @@ def __init__( name: str = "segmented_polynomial", ): super().__init__() + self._polynomial_orig = polynomial + self._output_dtype_map_orig = output_dtype_map self.num_inputs = polynomial.num_inputs self.num_outputs = polynomial.num_outputs self.input_sizes = [o.size for o in polynomial.inputs] @@ -164,6 +166,22 @@ def __init__( ) self.b_outs.append(b_out - self.num_inputs) + def __reduce__(self): + from cuequivariance_torch.primitives.segmented_polynomial import ( + SegmentedPolynomial, + ) + + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + "fused_tp", + self.math_dtype, + self._output_dtype_map_orig, + self.name, + ), + ) + def forward( self, inputs: List[torch.Tensor], diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py index 6dd7a4e5..6c98dafc 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py @@ -165,6 +165,9 @@ def __init__( name: str = "segmented_polynomial", ): super().__init__() + self._polynomial_orig = polynomial + self._output_dtype_map_orig = output_dtype_map + self._math_dtype_orig = math_dtype self.num_inputs = polynomial.num_inputs self.num_outputs = polynomial.num_outputs self.input_sizes = [o.size for o in polynomial.inputs] @@ -225,6 +228,22 @@ def __init__( ) self.tps.append(IndexedLinear(d, signature[0], signature[1:], math_dtype)) + def __reduce__(self): + from cuequivariance_torch.primitives.segmented_polynomial import ( + SegmentedPolynomial, + ) + + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + "indexed_linear", + self._math_dtype_orig, + self._output_dtype_map_orig, + self.name, + ), + ) + def forward( self, inputs: List[torch.Tensor], diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py index dfefe8aa..7662b5eb 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py @@ -90,6 +90,10 @@ def __init__( ): super().__init__() + self._polynomial_orig = polynomial + self._output_dtype_map_orig = output_dtype_map + self._options_orig = options + if uniform_1d is None: raise ImportError( "Failed to construct SegmentedPolynomialFromUniform1dJit: " @@ -218,6 +222,23 @@ def __init__( self.BATCH_DIM_BATCHED = BATCH_DIM_BATCHED self.BATCH_DIM_INDEXED = BATCH_DIM_INDEXED + def __reduce__(self): + from cuequivariance_torch.primitives.segmented_polynomial import ( + SegmentedPolynomial, + ) + + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + "uniform_1d", + self.math_dtype, + self._output_dtype_map_orig, + self.name, + self._options_orig, + ), + ) + def forward( self, inputs: List[torch.Tensor], From 7f28ba0773daa235bc77e9faa60c27ac513ee9ee Mon Sep 17 00:00:00 2001 From: Franco Pellegrini <16089353+phiandark@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:41:51 -0700 Subject: [PATCH 2/3] Replace deprecated is_fx_tracing with is_fx_symbolic_tracing Made-with: Cursor --- .../cuequivariance_torch/primitives/segmented_polynomial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py index e3c7e576..49fe671c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py @@ -306,7 +306,7 @@ def forward( if ( not torch.jit.is_tracing() and not torch.compiler.is_compiling() - and not torch.fx._symbolic_trace.is_fx_tracing() + and not torch.fx._symbolic_trace.is_fx_symbolic_tracing() ): torch._assert( len(inputs) == self.num_inputs, From 323ef72991fe35940cb30fa0a3e262d9ac3fa28b Mon Sep 17 00:00:00 2001 From: Franco Pellegrini <16089353+phiandark@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:46:12 -0700 Subject: [PATCH 3/3] Fix naive backend _no_input to support torch.compile (#265) Precompute descriptor-derived values (einsum equation, output size, path indices) in __init__ instead of referencing the descriptor closure at forward time. Dynamo cannot trace custom Subscripts operations. Made-with: Cursor --- .../primitives/segmented_polynomial_naive.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py index 522cc1cd..704aa28b 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py @@ -188,6 +188,18 @@ class _no_input(torch.nn.Module): def __init__(self, descriptor: cue.SegmentedTensorProduct): super().__init__() + self._output_size = descriptor.operands[-1].size + self._math_dtype = math_dtype + self._num_paths = descriptor.num_paths + self._path_output_indices = [ + path.indices[-1] for path in descriptor.paths + ] + self._einsum_eq = ( + str(descriptor.coefficient_subscripts) + + "->" + + str(descriptor.subscripts.operands[-1]) + ) + for pid, path in enumerate(descriptor.paths): if math_dtype is not None: self.register_buffer( @@ -205,23 +217,21 @@ def __init__(self, descriptor: cue.SegmentedTensorProduct): ) def forward(self): - if math_dtype is not None: + if self._math_dtype is not None: output = torch.zeros( - (descriptor.operands[-1].size,), + (self._output_size,), device=self.c0.device, - dtype=math_dtype, + dtype=self._math_dtype, ) else: output = torch.zeros( - (descriptor.operands[-1].size,), + (self._output_size,), device=self.c0.device, dtype=self.c0.dtype, ) - for pid in range(descriptor.num_paths): - output[descriptor.paths[pid].indices[-1]] += torch.einsum( - descriptor.coefficient_subscripts - + "->" - + descriptor.subscripts.operands[-1], + for pid in range(self._num_paths): + output[self._path_output_indices[pid]] += torch.einsum( + self._einsum_eq, getattr(self, f"c{pid}"), ) return output