Run the following script. Note that my code hits the same issue without setting torch._dynamo.config.capture_dynamic_output_shape_ops = True (so the failure is not limited to that case), but it was significantly easier to create a repro that did use it
"""Minimal repro: cuequivariance fails with unbacked SymInts from torch.compile.
Bug: segmented_polynomial_fused_tp.py:205 does `if size != 1:` with unbacked SymInt.
"""
import torch
from cuequivariance import Irreps
from cuequivariance.group_theory.descriptors import full_tensor_product
from cuequivariance_torch import SegmentedPolynomial
# Let Dynamo trace ops whose output shape depends on tensor data (here the
# torch.nonzero call in test_with_cuequivariance) instead of graph-breaking,
# which is what produces the unbacked SymInt this repro relies on.
torch._dynamo.config.capture_dynamic_output_shape_ops = True
def create_cg_module(device):
    """Build the cuequivariance module exercised by the repro.

    Forms the full tensor product of two SO3 ``Irreps`` and wraps the
    resulting polynomial in a ``SegmentedPolynomial`` that uses the
    ``fused_tp`` method (the code path named in the bug report).

    Args:
        device: Device the module is moved to via ``.to(device)``.

    Returns:
        The ``SegmentedPolynomial`` module on ``device``.
    """
    lhs_irreps = Irreps("SO3", [(4, 0), (2, 1)])
    rhs_irreps = Irreps("SO3", [(3, 0)])
    equation = full_tensor_product(lhs_irreps, rhs_irreps, None)
    return SegmentedPolynomial(equation.polynomial, method="fused_tp").to(device)
@torch.compile(fullgraph=True)
def test_with_cuequivariance(positions, radius, cg_module, call_cg):
    """Compiled repro body: feed tensors with an unbacked leading dim to cuequivariance.

    The number of edges comes from ``torch.nonzero``, so under
    ``torch.compile`` it is a data-dependent (unbacked) SymInt rather than a
    concrete integer.

    Args:
        positions: Point coordinates; pairwise distances are taken with
            ``torch.cdist(positions, positions)``.
        radius: Distance cutoff; pairs with ``dist < radius`` become edges.
        cg_module: Module called as ``cg_module([lhs, rhs])`` when
            ``call_cg`` is true.
        call_cg: If false, skip the module and return a plain reduction over
            the dynamically-shaped tensors (control case).

    Returns:
        ``lhs.sum() + rhs.sum()`` when ``call_cg`` is false, otherwise the
        first element of the module's output.
    """
    # Create unbacked SymInt via nonzero
    dist = torch.cdist(positions, positions)
    mask = dist < radius
    edge_indices = torch.nonzero(mask)
    num_edges = edge_indices.shape[0]  # unbacked SymInt 'u0'
    # Create features with dynamic shape
    # NOTE(review): widths 10 and 3 presumably match the operand dimensions of
    # the irreps used in create_cg_module -- confirm if those irreps change.
    lhs = torch.randn(num_edges, 10, device=positions.device)
    rhs = torch.randn(num_edges, 3, device=positions.device)
    if not call_cg:
        return lhs.sum() + rhs.sum()  # Works fine with unbacked SymInts
    # FAILS HERE: cuequivariance can't handle unbacked SymInt
    result = cg_module([lhs, rhs])
    return result[0]
if __name__ == "__main__":
    # Run the same compiled function twice: once as a control that never
    # touches cuequivariance, once through the module that triggers the bug.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cg_module = create_cg_module(device)
    positions = torch.randn(1, 50, 3, device=device)
    # Test 1: Without cuequivariance - works
    print("Test 1 (no cuequivariance):", end=" ")
    result = test_with_cuequivariance(positions, 0.5, cg_module, False)
    print("✓ Pass")
    # Test 2: With cuequivariance - fails
    print("Test 2 (with cuequivariance):", end=" ")
    result = test_with_cuequivariance(positions, 0.5, cg_module, True)
    # Only reached if the call above did not raise, i.e. the bug is gone.
    print("✗ Unexpected pass")
ryan@ryan-dev-box:~/src/env$ TORCHDYNAMO_VERBOSE=1 python /home/ryan/src/env/cuequivariance_bug_final_repro.py 2>&1
/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py:89: UserWarning: `math_dtype` is not provided for method `fused_tp`: using float32.
warnings.warn(
Test 1 (no cuequivariance): ✓ Pass
Test 2 (with cuequivariance): Traceback (most recent call last):
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py", line 1410, in evaluate_expr
return guard_scalar(self.sym_num)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1519, in guard_scalar
return guard_bool(a)
^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1711, in guard_bool
return a.node.guard_bool("", 0) # NB: uses Python backtrace
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
r = self.evaluate()
^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7223, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7323, in evaluate_expr
return self._inner_evaluate_expr(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7346, in _inner_evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7570, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Caused by: if size != 1: # cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py:205 in forward (_dynamo/variables/tensor.py:1410 in evaluate_expr)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 37, in test_with_cuequivariance
result = cg_module([lhs, rhs])
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial.py", line 283, in forward
return self.m(inputs, input_indices, output_shapes, output_indices)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py", line 205, in forward
if size != 1:
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 53, in <module>
result = test_with_cuequivariance(positions, 0.5, cg_module, True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
return _compile(
^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
transformations(instructions, code_options)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
tracer.run()
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
super().run()
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
while self.step():
^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
self._call(inst)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
self.call_function(fn, args, kwargs)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 1000, in call_function
return variables.UserFunctionVariable(fn, source=source).call_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 529, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1210, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3698, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3901, in inline_call_
self.run()
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
while self.step():
^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
self._call(inst)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
self.call_function(fn, args, kwargs)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 1000, in call_function
return variables.UserFunctionVariable(fn, source=source).call_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 529, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1210, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3698, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3901, in inline_call_
self.run()
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
while self.step():
^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 790, in inner
eval_result = value.evaluate_expr(self.output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py", line 1415, in evaluate_expr
raise UserError( # noqa: B904
torch._dynamo.exc.UserError: Consider annotating your code using torch._check*(). Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Caused by: if size != 1: # cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py:205 in forward (_dynamo/variables/tensor.py:1410 in evaluate_expr)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 37, in test_with_cuequivariance
result = cg_module([lhs, rhs])
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial.py", line 283, in forward
return self.m(inputs, input_indices, output_shapes, output_indices)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py", line 205, in forward
if size != 1:
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example
from user code:
File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 37, in test_with_cuequivariance
result = cg_module([lhs, rhs])
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial.py", line 283, in forward
return self.m(inputs, input_indices, output_shapes, output_indices)
File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py", line 205, in forward
if size != 1:
Describe the bug
Unbacked `SymInt`s (i.e. when the size isn't known at trace time) cause `SegmentedPolynomial` to crash. This was just the first call I tried it on; I would expect the same issue is present elsewhere too.
To Reproduce
Run the following script. Note that my code hits the same issue without setting `torch._dynamo.config.capture_dynamic_output_shape_ops = True` (so the failure is not limited to that case), but it was significantly easier to create a repro that did use it.
It crashes with the following error:
Expected behavior
A clear and concise description of what you expected to happen.
Screenshots
If applicable, add screenshots to help explain your problem.
GPU HW/SW(please complete the following information):
Here is the full details on my system, as previously reported for a pytorch bug.