diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index dea08256de71..21388dbef7a0 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -64,14 +64,6 @@ def to_torch_tensor(nd_tensor):
             else:
                 raise ValueError(f"Unsupported type {type(nd_tensor)}")
 
-        def to_tvm_tensor(torch_tensor):
-            """A helper function to transfer a torch.tensor to Tensor."""
-            if not isinstance(torch_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                return tvm.runtime.tensor(torch_tensor.numpy())
-            # Fake Tensor
-            real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype)
-            return tvm.runtime.tensor(real_tensor.numpy())
-
         graph_module.graph.eliminate_dead_code()
         device = device_from_inputs(example_inputs)
 
@@ -139,7 +131,10 @@ def exec_tvm(*i_args):
             for arg in args:
                 if arg.requires_grad:
                     arg = arg.detach()
-                vm_args.append(to_tvm_tensor(arg))
+                if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor):
+                    # Materialize a real (eager) Tensor
+                    arg = torch.randn(arg.shape, dtype=arg.dtype, device=device)
+                vm_args.append(arg)
             outputs = vm["main"](*vm_args)
             return to_torch_tensor(outputs)
 
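For context, a minimal standalone sketch of the FakeTensor handling that the second hunk inlines into exec_tvm. The helper name materialize_arg and the device default are illustrative assumptions, not part of the patch; in the patched code the device comes from device_from_inputs(example_inputs).

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode


def materialize_arg(arg: torch.Tensor, device=None) -> torch.Tensor:
    """Return `arg` unchanged, unless it is a FakeTensor.

    FakeTensors carry shape/dtype metadata but no real storage, so a
    random eager tensor with the same shape and dtype is created in its
    place, mirroring the branch added to exec_tvm in the patch above.
    """
    if isinstance(arg, FakeTensor):
        return torch.randn(arg.shape, dtype=arg.dtype, device=device)
    return arg


if __name__ == "__main__":
    with FakeTensorMode():
        fake = torch.empty(2, 3)  # a FakeTensor: metadata only, no data
    real = materialize_arg(fake)
    print(type(real), real.shape)  # <class 'torch.Tensor'> torch.Size([2, 3])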