@@ -148,6 +148,9 @@ def __init__(
     ):
         super().__init__()
 
+        self._polynomial_orig = polynomial
+        self._math_dtype_orig = math_dtype
+        self._output_dtype_map_orig = output_dtype_map
         self.num_inputs = polynomial.num_inputs
         self.num_outputs = polynomial.num_outputs
         self.method = method
@@ -181,29 +184,67 @@ def __init__(
        )
 
        if method == "uniform_1d":
-            self.m = SegmentedPolynomialFromUniform1dJit(
-                polynomial, math_dtype, output_dtype_map, name, self.options
-            )
-            self.fallback = self.m
+            try:
+                self.m = SegmentedPolynomialFromUniform1dJit(
+                    polynomial, math_dtype, output_dtype_map, name, self.options
+                )
+            except ImportError:
+                method = "naive"
+                warnings.warn(
+                    "uniform_1d backend is not available. "
+                    "Falling back to naive implementation."
+                )
+                self.m = SegmentedPolynomialNaive(
+                    polynomial, math_dtype, output_dtype_map, name
+                )
        elif method == "naive":
            self.m = SegmentedPolynomialNaive(
                polynomial, math_dtype, output_dtype_map, name
            )
-            self.fallback = self.m
        elif method == "fused_tp":
-            self.m = SegmentedPolynomialFusedTP(
-                polynomial, math_dtype, output_dtype_map, name
-            )
-            self.fallback = SegmentedPolynomialNaive(
-                polynomial, math_dtype, output_dtype_map, name
-            )
+            try:
+                self.m = SegmentedPolynomialFusedTP(
+                    polynomial, math_dtype, output_dtype_map, name
+                )
+            except ImportError:
+                method = "naive"
+                warnings.warn(
+                    "fused_tp backend is not available. "
+                    "Falling back to naive implementation."
+                )
+                self.m = SegmentedPolynomialNaive(
+                    polynomial, math_dtype, output_dtype_map, name
+                )
        elif method == "indexed_linear":
-            self.m = SegmentedPolynomialIndexedLinear(
-                polynomial, math_dtype, output_dtype_map, name
-            )
-            self.fallback = self.m
+            try:
+                self.m = SegmentedPolynomialIndexedLinear(
+                    polynomial, math_dtype, output_dtype_map, name
+                )
+            except ImportError:
+                method = "naive"
+                warnings.warn(
+                    "indexed_linear backend is not available. "
+                    "Falling back to naive implementation."
+                )
+                self.m = SegmentedPolynomialNaive(
+                    polynomial, math_dtype, output_dtype_map, name
+                )
        else:
            raise ValueError(f"Invalid method: {method}")
+        self.method = method
 
+    def __reduce__(self):
+        return (
+            SegmentedPolynomial,
+            (
+                self._polynomial_orig,
+                self.method,
+                self._math_dtype_orig,
+                self._output_dtype_map_orig,
+                "segmented_polynomial",
+                self.options,
+            ),
+        )
+
    def __repr__(self):
        return self.repr + f"\n{super().__repr__()}"
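Note on the hunk above: because `method` is reassigned inside each `except ImportError` branch before `self.method = method` runs, the instance records the backend actually in use, and `__reduce__` round-trips through this same constructor. A minimal self-contained sketch of the idiom, with hypothetical names (not the library's API):

```python
# Sketch of the optional-backend fallback idiom used above. Mirrored details:
# reassign `method` so the effective backend is recorded, and warn rather than
# raise when the optional backend cannot be constructed.
import warnings


class FastBackend:
    def __init__(self):
        # Stand-in for a backend whose optional dependency is missing.
        raise ImportError("optional kernel package not installed")


class NaiveBackend:
    def __call__(self, x):
        return 2 * x


def build(method="fast"):
    if method == "fast":
        try:
            impl = FastBackend()
        except ImportError:
            method = "naive"
            warnings.warn("fast backend unavailable; falling back to naive")
            impl = NaiveBackend()
    elif method == "naive":
        impl = NaiveBackend()
    else:
        raise ValueError(f"Invalid method: {method}")
    return impl, method


impl, method = build("fast")  # warns, then returns the naive implementation
assert method == "naive" and impl(3) == 6
```

A consequence worth noting: a module pickled with a CUDA backend selected can be unpickled on a machine without the kernels and will degrade to `naive` with a warning instead of failing to load.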
@@ -265,7 +306,7 @@ def forward(
        if (
            not torch.jit.is_tracing()
            and not torch.compiler.is_compiling()
-            and not torch.fx._symbolic_trace.is_fx_tracing()
+            and not torch.fx._symbolic_trace.is_fx_symbolic_tracing()
        ):
            torch._assert(
                len(inputs) == self.num_inputs,
@@ -304,7 +345,17 @@ def forward(
                warnings.warn(
                    "Fused TP is not supported on CPU. Falling back to naive implementation."
                )
-                return self.fallback(
+                if not hasattr(self, "_cpu_fallback"):
+                    object.__setattr__(
+                        self,
+                        "_cpu_fallback",
+                        SegmentedPolynomialNaive(
+                            self._polynomial_orig,
+                            self._math_dtype_orig,
+                            self._output_dtype_map_orig,
+                        ),
+                    )
+                return self._cpu_fallback(
                    inputs, input_indices, output_shapes, output_indices
                )
 
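`object.__setattr__` is used above so the lazily built naive module bypasses `nn.Module.__setattr__`, which would otherwise register it as a submodule and change `state_dict()`, `parameters()`, and `named_modules()`. A small self-contained sketch of the pattern (hypothetical module, not this PR's code):

```python
import torch


class LazyFallback(torch.nn.Module):
    """Build a fallback on first use, without registering it as a submodule."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.device.type == "cpu":
            if not hasattr(self, "_cpu_fallback"):
                # Bypass nn.Module.__setattr__: the fallback stays out of
                # _modules, so checkpoints and parameter lists are unaffected.
                object.__setattr__(self, "_cpu_fallback", torch.nn.Identity())
            return self._cpu_fallback(x)
        return x * 2.0  # stand-in for the fast GPU path


m = LazyFallback()
y = m(torch.ones(3))  # CPU input takes the fallback path
assert "_cpu_fallback" not in dict(m.named_modules())
assert len(list(m.named_modules())) == 1  # only the root module
```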
@@ -81,6 +81,8 @@ def __init__(
        name: str = "segmented_polynomial",
    ):
        super().__init__()
+        self._polynomial_orig = polynomial
+        self._output_dtype_map_orig = output_dtype_map
        self.num_inputs = polynomial.num_inputs
        self.num_outputs = polynomial.num_outputs
        self.input_sizes = [o.size for o in polynomial.inputs]
@@ -164,6 +166,22 @@ def __init__(
            )
            self.b_outs.append(b_out - self.num_inputs)
 
+    def __reduce__(self):
+        from cuequivariance_torch.primitives.segmented_polynomial import (
+            SegmentedPolynomial,
+        )
+
+        return (
+            SegmentedPolynomial,
+            (
+                self._polynomial_orig,
+                "fused_tp",
+                self.math_dtype,
+                self._output_dtype_map_orig,
+                self.name,
+            ),
+        )
+
    def forward(
        self,
        inputs: List[torch.Tensor],
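Each backend's `__reduce__` returns the public `SegmentedPolynomial` wrapper plus the original constructor arguments (the local import presumably avoids a circular dependency at module load time). This is the standard pickle hook: unpickling calls the returned callable with the returned tuple. A self-contained toy version of the pattern:

```python
# Toy module mirroring the __reduce__ pattern in this PR: the live backend
# object may be unpicklable, so pickling rebuilds from stashed constructor args.
import pickle

import torch


class Wrapper(torch.nn.Module):
    def __init__(self, size: int, method: str = "naive"):
        super().__init__()
        self._size_orig = size  # stash args, like _polynomial_orig above
        self.method = method
        self.weight = torch.nn.Parameter(torch.zeros(size))

    def __reduce__(self):
        # Unpickling calls Wrapper(self._size_orig, self.method).
        return (Wrapper, (self._size_orig, self.method))


m = pickle.loads(pickle.dumps(Wrapper(4, "naive")))
assert m.method == "naive" and m.weight.shape == (4,)
```

Note that the rebuilt object simply re-runs `__init__`, discarding any in-place state; that is lossless in this PR because the backends derive their buffers deterministically from the polynomial descriptor, but it would not suit a module carrying trained parameters.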
@@ -165,6 +165,9 @@ def __init__(
        name: str = "segmented_polynomial",
    ):
        super().__init__()
+        self._polynomial_orig = polynomial
+        self._output_dtype_map_orig = output_dtype_map
+        self._math_dtype_orig = math_dtype
        self.num_inputs = polynomial.num_inputs
        self.num_outputs = polynomial.num_outputs
        self.input_sizes = [o.size for o in polynomial.inputs]
@@ -225,6 +228,22 @@ def __init__(
            )
            self.tps.append(IndexedLinear(d, signature[0], signature[1:], math_dtype))
 
+    def __reduce__(self):
+        from cuequivariance_torch.primitives.segmented_polynomial import (
+            SegmentedPolynomial,
+        )
+
+        return (
+            SegmentedPolynomial,
+            (
+                self._polynomial_orig,
+                "indexed_linear",
+                self._math_dtype_orig,
+                self._output_dtype_map_orig,
+                self.name,
+            ),
+        )
+
    def forward(
        self,
        inputs: List[torch.Tensor],
@@ -188,6 +188,18 @@ class _no_input(torch.nn.Module):
    def __init__(self, descriptor: cue.SegmentedTensorProduct):
        super().__init__()
 
+        self._output_size = descriptor.operands[-1].size
+        self._math_dtype = math_dtype
+        self._num_paths = descriptor.num_paths
+        self._path_output_indices = [
+            path.indices[-1] for path in descriptor.paths
+        ]
+        self._einsum_eq = (
+            str(descriptor.coefficient_subscripts)
+            + "->"
+            + str(descriptor.subscripts.operands[-1])
+        )
+
        for pid, path in enumerate(descriptor.paths):
            if math_dtype is not None:
                self.register_buffer(
Expand All @@ -205,23 +217,21 @@ def __init__(self, descriptor: cue.SegmentedTensorProduct):
)

def forward(self):
if math_dtype is not None:
if self._math_dtype is not None:
output = torch.zeros(
(descriptor.operands[-1].size,),
(self._output_size,),
device=self.c0.device,
dtype=math_dtype,
dtype=self._math_dtype,
)
else:
output = torch.zeros(
(descriptor.operands[-1].size,),
(self._output_size,),
device=self.c0.device,
dtype=self.c0.dtype,
)
for pid in range(descriptor.num_paths):
output[descriptor.paths[pid].indices[-1]] += torch.einsum(
descriptor.coefficient_subscripts
+ "->"
+ descriptor.subscripts.operands[-1],
for pid in range(self._num_paths):
output[self._path_output_indices[pid]] += torch.einsum(
self._einsum_eq,
getattr(self, f"c{pid}"),
)
return output
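The `_no_input` hunks replace reads of `descriptor` and `math_dtype` from the enclosing function's scope with attributes captured in `__init__`, so `forward` depends only on `self` rather than on closure variables, which serialization and tracing tooling generally cannot see. A generic sketch of the same refactor (hypothetical names, not this PR's code):

```python
import torch


def make_module(size: int, scale: float) -> torch.nn.Module:
    class Closing(torch.nn.Module):
        # Anti-pattern mirrored by the old code: forward reads `size` and
        # `scale` from the enclosing function's scope.
        def forward(self) -> torch.Tensor:
            return torch.full((size,), scale)

    class SelfContained(torch.nn.Module):
        # Pattern adopted in the diff: capture everything in __init__ so
        # forward depends only on attributes of `self`.
        def __init__(self) -> None:
            super().__init__()
            self._size = size
            self._scale = scale

        def forward(self) -> torch.Tensor:
            return torch.full((self._size,), self._scale)

    return SelfContained()


m = make_module(3, 2.0)
assert torch.equal(m(), torch.tensor([2.0, 2.0, 2.0]))
```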
@@ -90,6 +90,10 @@ def __init__(
    ):
        super().__init__()
 
+        self._polynomial_orig = polynomial
+        self._output_dtype_map_orig = output_dtype_map
+        self._options_orig = options
+
        if uniform_1d is None:
            raise ImportError(
                "Failed to construct SegmentedPolynomialFromUniform1dJit: "
@@ -218,6 +222,23 @@ def __init__(
        self.BATCH_DIM_BATCHED = BATCH_DIM_BATCHED
        self.BATCH_DIM_INDEXED = BATCH_DIM_INDEXED
 
+    def __reduce__(self):
+        from cuequivariance_torch.primitives.segmented_polynomial import (
+            SegmentedPolynomial,
+        )
+
+        return (
+            SegmentedPolynomial,
+            (
+                self._polynomial_orig,
+                "uniform_1d",
+                self.math_dtype,
+                self._output_dtype_map_orig,
+                self.name,
+                self._options_orig,
+            ),
+        )
+
    def forward(
        self,
        inputs: List[torch.Tensor],