diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py index 1b09bb7..49fe671 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py @@ -148,6 +148,9 @@ def __init__( ): super().__init__() + self._polynomial_orig = polynomial + self._math_dtype_orig = math_dtype + self._output_dtype_map_orig = output_dtype_map self.num_inputs = polynomial.num_inputs self.num_outputs = polynomial.num_outputs self.method = method @@ -181,29 +184,67 @@ def __init__( ) if method == "uniform_1d": - self.m = SegmentedPolynomialFromUniform1dJit( - polynomial, math_dtype, output_dtype_map, name, self.options - ) - self.fallback = self.m + try: + self.m = SegmentedPolynomialFromUniform1dJit( + polynomial, math_dtype, output_dtype_map, name, self.options + ) + except ImportError: + method = "naive" + warnings.warn( + "uniform_1d backend is not available. " + "Falling back to naive implementation." + ) + self.m = SegmentedPolynomialNaive( + polynomial, math_dtype, output_dtype_map, name + ) elif method == "naive": self.m = SegmentedPolynomialNaive( polynomial, math_dtype, output_dtype_map, name ) - self.fallback = self.m elif method == "fused_tp": - self.m = SegmentedPolynomialFusedTP( - polynomial, math_dtype, output_dtype_map, name - ) - self.fallback = SegmentedPolynomialNaive( - polynomial, math_dtype, output_dtype_map, name - ) + try: + self.m = SegmentedPolynomialFusedTP( + polynomial, math_dtype, output_dtype_map, name + ) + except ImportError: + method = "naive" + warnings.warn( + "fused_tp backend is not available. " + "Falling back to naive implementation." + ) + self.m = SegmentedPolynomialNaive( + polynomial, math_dtype, output_dtype_map, name + ) elif method == "indexed_linear": - self.m = SegmentedPolynomialIndexedLinear( - polynomial, math_dtype, output_dtype_map, name - ) - self.fallback = self.m + try: + self.m = SegmentedPolynomialIndexedLinear( + polynomial, math_dtype, output_dtype_map, name + ) + except ImportError: + method = "naive" + warnings.warn( + "indexed_linear backend is not available. " + "Falling back to naive implementation." + ) + self.m = SegmentedPolynomialNaive( + polynomial, math_dtype, output_dtype_map, name + ) else: raise ValueError(f"Invalid method: {method}") + self.method = method + + def __reduce__(self): + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + self.method, + self._math_dtype_orig, + self._output_dtype_map_orig, + "segmented_polynomial", + self.options, + ), + ) def __repr__(self): return self.repr + f"\n{super().__repr__()}" @@ -265,7 +306,7 @@ def forward( if ( not torch.jit.is_tracing() and not torch.compiler.is_compiling() - and not torch.fx._symbolic_trace.is_fx_tracing() + and not torch.fx._symbolic_trace.is_fx_symbolic_tracing() ): torch._assert( len(inputs) == self.num_inputs, @@ -304,7 +345,17 @@ def forward( warnings.warn( "Fused TP is not supported on CPU. Falling back to naive implementation." ) - return self.fallback( + if not hasattr(self, "_cpu_fallback"): + object.__setattr__( + self, + "_cpu_fallback", + SegmentedPolynomialNaive( + self._polynomial_orig, + self._math_dtype_orig, + self._output_dtype_map_orig, + ), + ) + return self._cpu_fallback( inputs, input_indices, output_shapes, output_indices ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py index 90f3416..8d1d072 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py @@ -81,6 +81,8 @@ def __init__( name: str = "segmented_polynomial", ): super().__init__() + self._polynomial_orig = polynomial + self._output_dtype_map_orig = output_dtype_map self.num_inputs = polynomial.num_inputs self.num_outputs = polynomial.num_outputs self.input_sizes = [o.size for o in polynomial.inputs] @@ -164,6 +166,22 @@ def __init__( ) self.b_outs.append(b_out - self.num_inputs) + def __reduce__(self): + from cuequivariance_torch.primitives.segmented_polynomial import ( + SegmentedPolynomial, + ) + + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + "fused_tp", + self.math_dtype, + self._output_dtype_map_orig, + self.name, + ), + ) + def forward( self, inputs: List[torch.Tensor], diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py index 6dd7a4e..6c98daf 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py @@ -165,6 +165,9 @@ def __init__( name: str = "segmented_polynomial", ): super().__init__() + self._polynomial_orig = polynomial + self._output_dtype_map_orig = output_dtype_map + self._math_dtype_orig = math_dtype self.num_inputs = polynomial.num_inputs self.num_outputs = polynomial.num_outputs self.input_sizes = [o.size for o in polynomial.inputs] @@ -225,6 +228,22 @@ def __init__( ) self.tps.append(IndexedLinear(d, signature[0], signature[1:], math_dtype)) + def __reduce__(self): + from cuequivariance_torch.primitives.segmented_polynomial import ( + SegmentedPolynomial, + ) + + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + "indexed_linear", + self._math_dtype_orig, + self._output_dtype_map_orig, + self.name, + ), + ) + def forward( self, inputs: List[torch.Tensor], diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py index 522cc1c..704aa28 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_naive.py @@ -188,6 +188,18 @@ class _no_input(torch.nn.Module): def __init__(self, descriptor: cue.SegmentedTensorProduct): super().__init__() + self._output_size = descriptor.operands[-1].size + self._math_dtype = math_dtype + self._num_paths = descriptor.num_paths + self._path_output_indices = [ + path.indices[-1] for path in descriptor.paths + ] + self._einsum_eq = ( + str(descriptor.coefficient_subscripts) + + "->" + + str(descriptor.subscripts.operands[-1]) + ) + for pid, path in enumerate(descriptor.paths): if math_dtype is not None: self.register_buffer( @@ -205,23 +217,21 @@ def __init__(self, descriptor: cue.SegmentedTensorProduct): ) def forward(self): - if math_dtype is not None: + if self._math_dtype is not None: output = torch.zeros( - (descriptor.operands[-1].size,), + (self._output_size,), device=self.c0.device, - dtype=math_dtype, + dtype=self._math_dtype, ) else: output = torch.zeros( - (descriptor.operands[-1].size,), + (self._output_size,), device=self.c0.device, dtype=self.c0.dtype, ) - for pid in range(descriptor.num_paths): - output[descriptor.paths[pid].indices[-1]] += torch.einsum( - descriptor.coefficient_subscripts - + "->" - + descriptor.subscripts.operands[-1], + for pid in range(self._num_paths): + output[self._path_output_indices[pid]] += torch.einsum( + self._einsum_eq, getattr(self, f"c{pid}"), ) return output diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py index dfefe8a..7662b5e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py @@ -90,6 +90,10 @@ def __init__( ): super().__init__() + self._polynomial_orig = polynomial + self._output_dtype_map_orig = output_dtype_map + self._options_orig = options + if uniform_1d is None: raise ImportError( "Failed to construct SegmentedPolynomialFromUniform1dJit: " @@ -218,6 +222,23 @@ def __init__( self.BATCH_DIM_BATCHED = BATCH_DIM_BATCHED self.BATCH_DIM_INDEXED = BATCH_DIM_INDEXED + def __reduce__(self): + from cuequivariance_torch.primitives.segmented_polynomial import ( + SegmentedPolynomial, + ) + + return ( + SegmentedPolynomial, + ( + self._polynomial_orig, + "uniform_1d", + self.math_dtype, + self._output_dtype_map_orig, + self.name, + self._options_orig, + ), + ) + def forward( self, inputs: List[torch.Tensor],