diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py b/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
index 7f11d16b..92b6c1e6 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
@@ -95,7 +95,7 @@ def __init__(self, gm):
 
     @lru_cache
     def get_shared_zeros(self, size, dtype):
-        return self.get_proxy(atb_op.Zeros, (size, dtype))
+        return self.aten_zeros_default(size, dtype)
 
     def get_proxy(self, target, args, kwargs=immutable_dict()):
         proxy = super().get_proxy(target, args, kwargs)
@@ -791,11 +791,11 @@ def aten_new_empty(self, x, size, pin_memory=False):
     def aten_slice_scatter(self, x, data, dim=0, start=None, end=None, step=1):
         return self.get_proxy(atb_op.SliceScatter, (x, data, dim, start, end, step))
 
-    @register_conversion(torch.ops.dlinfer.dynamic_quant.default)
+    @register_conversion("torch.ops.dlinfer.dynamic_quant.default")
     def dlinfer_dynamic_quant(self, x, quant_dtype, quant_granularity):
         return self.get_proxy(atb_op.AclNnDynamicQuant, (x, quant_dtype))
 
-    @register_conversion(torch.ops.dlinfer.linear_w8a8.default)
+    @register_conversion("torch.ops.dlinfer.linear_w8a8.default")
     def dlinfer_linear_w8a8(
         self, x, y, rms_scale, linear_scale, out_type, quant_dtype, bias
     ):
@@ -840,7 +840,7 @@ def aten_topk(self, x, k, dim=-1, largest=True, sorted=True):
         assert largest == True
         return self.get_proxy(atb_op.Sort, (x, k))
 
-    @register_conversion(torch.ops.dlinfer.transdata.default)
+    @register_conversion("torch.ops.dlinfer.transdata.default")
     def dlinfer_transdata(self, x, transdata_type):
         return self.get_proxy(atb_op.Transdata, (x, transdata_type))
 