diff --git a/impl/ascend/functions/unique.cpp b/impl/ascend/functions/unique.cpp index 3b1f87ebe..bfd033501 100644 --- a/impl/ascend/functions/unique.cpp +++ b/impl/ascend/functions/unique.cpp @@ -27,16 +27,8 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio makeTensor(ctx, outTmpAt, {inputAt.numel()}, inputAt.dtype()); } - // allocate temp inverse tensor - diopiTensorHandle_t inverseTmp = nullptr; - AscendTensor inverseTmpAt(inverseTmp); bool returnInverse = (indices != nullptr) ? true : false; std::vector zeroShape = {0}; - if (returnInverse || returnCounts) { - makeTensor(ctx, inverseTmpAt, inputAt.shape(), diopi_dtype_int64); - } else { - makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64); - } // allocate temp counts tensor diopiTensorHandle_t countsTmp = nullptr; @@ -48,8 +40,23 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio } // call aclnnUnique2 - auto params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, inverseTmpAt, countsTmpAt).params(); - DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params); + std::tuple params; + if (returnInverse) { + params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, indices, countsTmpAt).params(); + DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params); + } else { + // allocate temp inverse tensor + diopiTensorHandle_t inverseTmp = nullptr; + AscendTensor inverseTmpAt(inverseTmp); + makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64); + if (returnCounts) { + makeTensor(ctx, inverseTmpAt, inputAt.shape(), diopi_dtype_int64); + } else { + makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64); + } + params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, inverseTmpAt, countsTmpAt).params(); + DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params); + } // get true outShape by aclGetViewShape int64_t* viewDims = nullptr; @@ -65,11 +72,6 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio AscendTensor outReshapeAt = reshape(ctx, outTmpAt, {viewDims, viewDims + viewDimNum}); *out = const_cast(outReshapeAt.tensorHandle()); - // fill indices tensor - if (returnInverse) { - indices = const_cast(inverseTmpAt.tensorHandle()); - } - // fill counts tensor if (returnCounts) { // get counts tensor shape, counts tensor is the 7th tensor in aclnnUnique2, index = 6