Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,9 @@ def call_binary_op(op, lhs, rhs):
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return call_binary_op(relax_op, lhs, rhs)
elif isinstance(lhs, relax.expr.Constant):
elif isinstance(lhs, relax.expr.Constant) and not isinstance(rhs, relax.expr.Constant):
return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
elif isinstance(rhs, relax.expr.Constant):
elif isinstance(rhs, relax.expr.Constant) and not isinstance(lhs, relax.expr.Constant):
return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
return intrinsic_op(lhs, rhs)
Comment on lines 506 to 512
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this change correctly fixes the issue with binary operations on two constants, the logic can be simplified. The call_binary_op function already contains promote_binary_op_args, which handles the promotion of Python scalars to Relax constants when one of the operands is a Relax expression. We can leverage this to make the code more concise and readable.

            if isinstance(lhs, relax.Expr) or isinstance(rhs, relax.Expr):
                return call_binary_op(relax_op, lhs, rhs)
            return intrinsic_op(lhs, rhs)


Expand Down
7 changes: 3 additions & 4 deletions python/tvm/relax/frontend/torch/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ def exec_tvm(*i_args):
args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)]
vm_args = list()
for arg in args:
if arg.dim() != 0:
if arg.requires_grad:
arg = arg.detach()
vm_args.append(to_tvm_tensor(arg))
if arg.requires_grad:
arg = arg.detach()
vm_args.append(to_tvm_tensor(arg))
outputs = vm["main"](*vm_args)
return to_torch_tensor(outputs)

Expand Down
19 changes: 19 additions & 0 deletions tests/python/relax/test_frontend_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ def main(
tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5)


def test_relax_dynamo_scalar_params():
class ScalarParams(torch.nn.Module):
def __init__(self):
super().__init__()
self.x = torch.nn.Parameter(torch.tensor(1.0))
self.y = torch.nn.Parameter(torch.tensor(2.0))

def forward(self):
return self.x + self.y

model = ScalarParams()

opt_model = torch.compile(model, backend=relax_dynamo())

default_output = model().detach().numpy()
optimized_output = opt_model().detach().numpy()
tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5)


def test_relax_dynamo_dynamic():
class Input1(torch.nn.Module):
def __init__(self):
Expand Down
Loading