From 7319a30eb9a6001136e8c3446796b3de0f3767be Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 8 Feb 2026 00:27:47 +0900 Subject: [PATCH 1/2] add a testcase --- tests/python/relax/test_frontend_dynamo.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index a48907eae5dd..b6d23455711b 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -123,6 +123,25 @@ def main( tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5) +def test_relax_dynamo_scalar_params(): + class ScalarParams(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.nn.Parameter(torch.tensor(1.0)) + self.y = torch.nn.Parameter(torch.tensor(2.0)) + + def forward(self): + return self.x + self.y + + model = ScalarParams() + + opt_model = torch.compile(model, backend=relax_dynamo()) + + default_output = model().detach().numpy() + optimized_output = opt_model().detach().numpy() + tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5) + + def test_relax_dynamo_dynamic(): class Input1(torch.nn.Module): def __init__(self): From f0bbad5bae18f3c6feb54cafdf28ab4943a0fdf7 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 8 Feb 2026 00:34:10 +0900 Subject: [PATCH 2/2] fix --- .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 ++-- python/tvm/relax/frontend/torch/dynamo.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d04dfbb6c35a..447f4a4dc627 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -505,9 +505,9 @@ def call_binary_op(op, lhs, rhs): lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return call_binary_op(relax_op, lhs, rhs) - elif isinstance(lhs, relax.expr.Constant): + elif isinstance(lhs, relax.expr.Constant) and not isinstance(rhs, relax.expr.Constant): return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) - elif isinstance(rhs, relax.expr.Constant): + elif isinstance(rhs, relax.expr.Constant) and not isinstance(lhs, relax.expr.Constant): return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) return intrinsic_op(lhs, rhs) diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index 8dc9e2a55aef..dea08256de71 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -137,10 +137,9 @@ def exec_tvm(*i_args): args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)] vm_args = list() for arg in args: - if arg.dim() != 0: - if arg.requires_grad: - arg = arg.detach() - vm_args.append(to_tvm_tensor(arg)) + if arg.requires_grad: + arg = arg.detach() + vm_args.append(to_tvm_tensor(arg)) outputs = vm["main"](*vm_args) return to_torch_tensor(outputs)