From 513aa3e3fdeb95082ec6deaa9c4bf1eb8c601fbf Mon Sep 17 00:00:00 2001 From: maibaodexiaohangjiaya Date: Tue, 30 Jul 2024 06:56:56 +0000 Subject: [PATCH 1/5] replace acllnnOp penalty and context --- impl/ascend/aclnn/adaptor.hpp | 2 +- impl/ascend/common/utils.cpp | 11 ++ impl/ascend/common/utils.hpp | 2 + impl/ascend/device_configs.py | 18 --- impl/ascend/functions_ext/apply_penalty.cpp | 88 +++++++++++ .../context_attention_inference.cpp | 146 ++++++++++++++++++ impl/ascend_npu/CMakeLists.txt | 1 + impl/ascend_npu/ascend_config.yaml | 4 +- 8 files changed, 251 insertions(+), 21 deletions(-) create mode 100644 impl/ascend/functions_ext/context_attention_inference.cpp diff --git a/impl/ascend/aclnn/adaptor.hpp b/impl/ascend/aclnn/adaptor.hpp index 0d67d4093..88356f5dc 100644 --- a/impl/ascend/aclnn/adaptor.hpp +++ b/impl/ascend/aclnn/adaptor.hpp @@ -53,7 +53,7 @@ inline void* getOpApiFuncAddr(const char* apiName) { constexpr const char kOpApiLibName[] = "libopapi.so"; static void* opApiHandler = getOpApiLibHandler(kOpApiLibName); if (opApiHandler == nullptr) { - return nullptr; + error(__FILE__, __LINE__, __FUNCTION__, "Error: Failed to get opApi handler for %s.", apiName); } return getOpApiFuncAddrInLib(opApiHandler, kOpApiLibName, apiName); } diff --git a/impl/ascend/common/utils.cpp b/impl/ascend/common/utils.cpp index 5ba21de71..462beecf7 100644 --- a/impl/ascend/common/utils.cpp +++ b/impl/ascend/common/utils.cpp @@ -760,5 +760,16 @@ diopiError_t autoCastTensorType(diopiContextHandle_t ctx, const std::vector(stream))); + CALL_ACLRT(aclrtMemcpyAsync( + ptrHost, at.numel() * at.elemsize(), at.data(), at.numel() * at.elemsize(), ACL_MEMCPY_DEVICE_TO_HOST, reinterpret_cast(stream))); + CALL_ACLRT(aclrtSynchronizeStream(reinterpret_cast(stream))); + return ptrHost; +} + } // namespace ascend } // namespace impl diff --git a/impl/ascend/common/utils.hpp b/impl/ascend/common/utils.hpp index 80b1ce056..af86baf1b 100644 --- a/impl/ascend/common/utils.hpp +++ b/impl/ascend/common/utils.hpp @@ -123,6 +123,8 @@ diopiError_t fillNan(diopiContextHandle_t ctx, AscendTensor& src); diopiError_t autoCastTensorType(diopiContextHandle_t ctx, const std::vector& pTensors, const std::set& opSupportedDtype); +void* AscendTensorDeviceToHost(diopiContextHandle_t ctx, AscendTensor at); + } // namespace ascend } // namespace impl diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py index 441352a02..a6d2049c6 100755 --- a/impl/ascend/device_configs.py +++ b/impl/ascend/device_configs.py @@ -1166,24 +1166,6 @@ ], ), ), - - # 'apply_penalty': dict( - # name=['apply_penalty'], - # tensor_para=dict( - # args=[ - # { - # "ins": ['logits'], - # "dtype": [Skip(np.float64)], - # }, - # ] - # ) - # ), - - # TODO(zhangqiu) Due to a bug in the software stack, this test will be skipped for now. - 'apply_penalty': dict( - name=['apply_penalty'], - skip_all=True - ), # TODO(zhangqiu) Due to a bug in the software stack, this test will be skipped for now. 'embedding': dict( diff --git a/impl/ascend/functions_ext/apply_penalty.cpp b/impl/ascend/functions_ext/apply_penalty.cpp index ddcda9cf4..67ca3270b 100644 --- a/impl/ascend/functions_ext/apply_penalty.cpp +++ b/impl/ascend/functions_ext/apply_penalty.cpp @@ -8,10 +8,98 @@ #include #include "../aclnn/adaptor.hpp" +#include "impl_functions.hpp" namespace impl { namespace ascend { +diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty, + diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t pTokenIds, diopiConstTensorHandle_t pTokenCounts, + diopiConstTensorHandle_t pCumsumSeqLen, int pMaxLenInBatch) { + AscendTensor logitsAt(logits); + AscendTensor pCumsumSeqLenAt(pCumsumSeqLen); + AscendTensor frequencyPenaltyAt(frequencyPenalty); + AscendTensor presencePenaltyAt(presencePenalty); + AscendTensor pTokenIdsAt(pTokenIds); + AscendTensor pTokenCountsAt(pTokenCounts); + + int batch = logitsAt.shape(0); + const int64_t dim = 0; + diopiDtype_t logitsDtype = logitsAt.dtype(); + + void* curBatchIndexDataPtrHost = AscendTensorDeviceToHost(ctx, pCumsumSeqLenAt); + void* frequencyPenaltyDataPtrHost = AscendTensorDeviceToHost(ctx, frequencyPenaltyAt); + void* presencePenaltyDataPtrHost = AscendTensorDeviceToHost(ctx, presencePenaltyAt); + for (int i = 0; i < batch; ++i) { + int curBatchStartIndex = reinterpret_cast(curBatchIndexDataPtrHost)[i]; + int curBatchEndIndex = reinterpret_cast(curBatchIndexDataPtrHost)[i + 1]; + AscendTensor sliceAt; + std::vector sliceAtShape(1, curBatchEndIndex - curBatchStartIndex); + makeTensor(ctx, sliceAt, sliceAtShape, diopi_dtype_int32); + const diopiScalar_t curBatchStartIndexScalar = constructDiopiScalarT(diopi_dtype_int32, curBatchStartIndex); + const diopiScalar_t curBatchEndIndexScalar = constructDiopiScalarT(diopi_dtype_int32, curBatchEndIndex); + const diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &curBatchStartIndexScalar, &curBatchEndIndexScalar, &stepScalar, sliceAt); + + diopiTensorHandle_t curTokenIds; + diopiConstTensorHandle_t sliceTensorHandle = sliceAt.tensorHandle(); + ascend_npu::diopiIndex(ctx, &curTokenIds, pTokenIds, &sliceTensorHandle, 1); + + diopiTensorHandle_t curTokenCounts; + ascend_npu::diopiIndex(ctx, &curTokenCounts, pTokenCounts, &sliceTensorHandle, 1); + + AscendTensor curTokenIdsAt(curTokenIds); + AscendTensor curTokenCountsAt(curTokenCounts); + AscendTensor curLogitsAt; + std::vector curLogitsAtShape(1); + curLogitsAtShape[dim] = curTokenIdsAt.shape()[0]; + makeTensor(ctx, curLogitsAt, curLogitsAtShape, logitsDtype); + AscendTensor logitsAtI; + makeTensor(ctx, logitsAtI, {1, logitsAt.shape()[1]}, logitsDtype); + diopiScalar_t iScalar = constructDiopiScalarT(diopi_dtype_int32, i); + AscendTensor iTensorAt; + makeTensorFromScalar(ctx, iTensorAt, &iScalar, logitsAt.device()); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, logitsAt, dim, iTensorAt, logitsAtI); + + logitsAtI.view({logitsAt.shape()[1]}); + + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, logitsAtI, dim, curTokenIds, curLogitsAt); + AscendTensor frequencyPenaltyAdjustmentAt; + makeTensor(ctx, frequencyPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype); + + diopiScalar_t frequencyPenaltyAt_i_Scalar; + if (logitsDtype == diopi_dtype_float32) { + frequencyPenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); + } else { + frequencyPenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); + } + DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAt_i_Scalar, frequencyPenaltyAdjustmentAt); + + AscendTensor totalPenaltyAdjustmentAt; + makeTensor(ctx, totalPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype); + + diopiScalar_t presencePenaltyAt_i_Scalar; + if (logitsDtype == diopi_dtype_float32) { + presencePenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); + } else { + presencePenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); + } + diopiScalar_t oneScalar = constructDiopiScalarT(logitsDtype, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnAdds, ctx, frequencyPenaltyAdjustmentAt, &presencePenaltyAt_i_Scalar, &oneScalar, totalPenaltyAdjustmentAt); + + DIOPI_ASCEND_CALL_ACLNN(aclnnSub, ctx, curLogitsAt, totalPenaltyAdjustmentAt, &oneScalar, curLogitsAt); + std::vector indices; + indices.emplace_back(iTensorAt); + indices.emplace_back(curTokenIdsAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, logitsAt, indices, curLogitsAt, false, true); + } + free(curBatchIndexDataPtrHost); + free(frequencyPenaltyDataPtrHost); + free(presencePenaltyDataPtrHost); + + return diopiSuccess; +} + diopiError_t diopiApplyPenaltyV2(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty, diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t repetitionPenalty, diopiConstTensorHandle_t pTokenIds, diopiConstTensorHandle_t pTokenCounts) { diff --git a/impl/ascend/functions_ext/context_attention_inference.cpp b/impl/ascend/functions_ext/context_attention_inference.cpp new file mode 100644 index 000000000..cfbf628fc --- /dev/null +++ b/impl/ascend/functions_ext/context_attention_inference.cpp @@ -0,0 +1,146 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include +#include + +#include "../aclnn/adaptor.hpp" +#include "impl_functions.hpp" + +namespace impl { +namespace ascend { + +AscendTensor torchContextAttention(diopiContextHandle_t ctx, AscendTensor xq, AscendTensor xk, AscendTensor xv, int batchSize, int seqLen, int head, int dim) { + xq.view({batchSize, seqLen, head, dim}); + AscendTensor xqTransposeAt; + std::vector xqTransposeAtShape(xq.shape()); + int64_t tmp = xqTransposeAtShape[1]; + xqTransposeAtShape[1] = xqTransposeAtShape[2]; + xqTransposeAtShape[2] = tmp; + makeTensor(ctx, xqTransposeAt, xqTransposeAtShape, xq.dtype()); + std::vector xqTransposeDims = {0, 2, 1, 3}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xq, xqTransposeDims, xqTransposeAt); + + xk.view({batchSize, seqLen, head, dim}); + AscendTensor xkTransposeAt; + makeTensor(ctx, xkTransposeAt, xqTransposeAtShape, xk.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xk, xqTransposeDims, xkTransposeAt); + + xv.view({batchSize, seqLen, head, dim}); + AscendTensor xvTransposeAt; + makeTensor(ctx, xvTransposeAt, xqTransposeAtShape, xv.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xv, xqTransposeDims, xvTransposeAt); + + AscendTensor maskAt; + makeTensor(ctx, maskAt, {1, 1, seqLen, seqLen}, diopi_dtype_float32); + AscendTensor onesMatrixAt; + makeTensor(ctx, onesMatrixAt, {seqLen, seqLen}, diopi_dtype_float32); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, onesMatrixAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceTril, ctx, onesMatrixAt, 0); + maskAt = onesMatrixAt.unsqueeze(0).unsqueeze(0); + diopiScalar_t valueScalar = constructDiopiScalarT(diopi_dtype_float32, -100000000.0); + AscendTensor maskMatrixAt; + makeTensor(ctx, maskMatrixAt, maskAt.shape(), diopi_dtype_int32); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, maskMatrixAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceTriu, ctx, maskMatrixAt, 1); + AscendTensor maskBoolMatrixAt; + makeTensor(ctx, maskBoolMatrixAt, maskMatrixAt.shape(), diopi_dtype_bool); + DIOPI_ASCEND_CALL_ACLNN(aclnnCast, ctx, maskMatrixAt, diopi_dtype_bool, maskBoolMatrixAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceMaskedFillScalar, ctx, maskAt, maskBoolMatrixAt, &valueScalar); + AscendTensor maskRepeatAt; + std::vector maskRepeatAtShape(maskAt.shape()); + maskRepeatAtShape[0] *= batchSize; + maskRepeatAtShape[1] *= head; + makeTensor(ctx, maskRepeatAt, maskRepeatAtShape, maskAt.dtype()); + std::vector repeats = {batchSize, head, 1, 1}; + DIOPI_ASCEND_CALL_ACLNN(aclnnRepeat, ctx, maskAt, repeats, maskRepeatAt); + + AscendTensor scoresAt; + AscendTensor xkTransposeAt2; + std::vector xkTransposeAtShape(xkTransposeAt.shape()); + tmp = xkTransposeAtShape[2]; + xkTransposeAtShape[2] = xkTransposeAtShape[3]; + xkTransposeAtShape[3] = tmp; + makeTensor(ctx, xkTransposeAt2, xkTransposeAtShape, xk.dtype()); + std::vector xkTransposeAt_2Dims = {0, 1, 3, 2}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xkTransposeAt, xkTransposeAt_2Dims, xkTransposeAt2); + std::vector scoresShapeAt = xqTransposeAt.shape(); + scoresShapeAt[3] = xkTransposeAtShape[3]; + makeTensor(ctx, scoresAt, scoresShapeAt, xq.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, xqTransposeAt, xkTransposeAt2, scoresAt, (int8_t)0); + diopiScalar_t otherScalar = constructDiopiScalarT(diopi_dtype_float32, std::sqrt(dim)); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceDivs, ctx, scoresAt, &otherScalar); + + AscendTensor adjustedScoresAt; + std::vector adjusted_scoresAtShape = inferSize(scoresAt.shape(), maskRepeatAt.shape()); + makeTensor(ctx, adjustedScoresAt, adjusted_scoresAtShape, scoresAt.dtype()); + diopiScalar_t alphaScalar = constructDiopiScalarT(diopi_dtype_float32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnAdd, ctx, scoresAt, maskRepeatAt, &alphaScalar, adjustedScoresAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnSoftmax, ctx, adjustedScoresAt, adjustedScoresAt.dim() - 1, adjustedScoresAt); + AscendTensor outputAt; + std::vector outputAtShape = adjustedScoresAt.shape(); + outputAtShape[3] = xvTransposeAt.shape(3); + makeTensor(ctx, outputAt, outputAtShape, scoresAt.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, adjustedScoresAt, xvTransposeAt, outputAt, (int8_t)0); + tmp = outputAtShape[1]; + outputAtShape[1] = outputAtShape[2]; + outputAtShape[2] = tmp; + AscendTensor outputTransposeAt; + makeTensor(ctx, outputTransposeAt, outputAtShape, outputAt.dtype()); + std::vector outputTransposeDims = {0, 2, 1, 3}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, outputAt, outputTransposeDims, outputTransposeAt); + outputTransposeAt.view({outputTransposeAt.numel() / static_cast(head * dim), head, dim}); + + return outputTransposeAt; +} + +diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, + diopiConstTensorHandle_t v, diopiConstTensorHandle_t bStartLoc, diopiConstTensorHandle_t bSeqLen, int maxInputLen) { + AscendTensor bStartLocAt(bStartLoc); + AscendTensor qAt(q); + AscendTensor bSeqLenAt(bSeqLen); + + int batch = bStartLocAt.shape()[0]; + int head = qAt.shape()[1]; + int dim = qAt.shape()[2]; + + void* bStartLocDataPtr = AscendTensorDeviceToHost(ctx, bStartLocAt); + void* bSeqLenDataPtr = AscendTensorDeviceToHost(ctx, bSeqLenAt); + for (int i = 0; i < batch; ++i) { + int start = reinterpret_cast(bStartLocDataPtr)[i]; + int end = start + reinterpret_cast(bSeqLenDataPtr)[i]; + + AscendTensor sliceAt; + std::vector sliceAtShape(1, end - start); + makeTensor(ctx, sliceAt, sliceAtShape, diopi_dtype_int32); + diopiScalar_t startIndexScalar = constructDiopiScalarT(diopi_dtype_int32, start); + diopiScalar_t endIndexScalar = constructDiopiScalarT(diopi_dtype_int32, end); + diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &startIndexScalar, &endIndexScalar, &stepScalar, sliceAt); + + diopiTensorHandle_t qIndex; + diopiConstTensorHandle_t sliceTensorHandle = sliceAt.tensorHandle(); + ascend_npu::diopiIndex(ctx, &qIndex, q, &sliceTensorHandle, 1); + + diopiTensorHandle_t kIndex; + ascend_npu::diopiIndex(ctx, &kIndex, k, &sliceTensorHandle, 1); + + diopiTensorHandle_t vIndex; + ascend_npu::diopiIndex(ctx, &vIndex, v, &sliceTensorHandle, 1); + + AscendTensor valuesAt; + AscendTensor qIndexAt(qIndex), kIndexAt(kIndex), vIndexAt(vIndex); + valuesAt = torchContextAttention(ctx, qIndexAt, kIndexAt, vIndexAt, 1, reinterpret_cast(bSeqLenDataPtr)[i], head, dim); + + std::vector indices = {sliceAt}; + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, out, indices, valuesAt, false, true); + } + + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt index 44da6db15..21b21b220 100755 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -198,6 +198,7 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions_ext/adamw.cpp ${OLD_IMPL_DIR}/functions_ext/destindex_copy_kv.cpp ${OLD_IMPL_DIR}/functions_ext/apply_penalty.cpp + ${OLD_IMPL_DIR}/functions_ext/context_attention_inference.cpp ${OLD_IMPL_DIR}/functions_ext/flash_attention.cpp ${OLD_IMPL_DIR}/functions_ext/flash_attention_varlen.cpp ${OLD_IMPL_DIR}/functions_ext/prompt_flash_attention.cpp diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 4adf5db16..aabd43fdf 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -15,6 +15,7 @@ ascend: - diopiAddmm - diopiAll - diopiAny +- diopiApplyPenalty - diopiApplyPenaltyV2 - diopiArange - diopiArgmax @@ -55,6 +56,7 @@ ascend: - diopiClampMinScalar - diopiClampScalar - diopiCol2Im +- diopiContextAttentionInference - diopiContiguous - diopiConvolution2d - diopiConvolution2dBackward @@ -261,8 +263,6 @@ ascend: - diopiZeroInp - diopiZeros ascend_npu: -- diopiApplyPenalty -- diopiContextAttentionInference - diopiGetNativeMemoryFormat - diopiIndex - diopiIndexBackward From 207210aeffbdb94c59d078c8ea4602a64cb56b6a Mon Sep 17 00:00:00 2001 From: maibaodexiaohangjiaya Date: Tue, 30 Jul 2024 07:20:32 +0000 Subject: [PATCH 2/5] fix style --- impl/ascend/common/utils.cpp | 2 +- impl/ascend/common/utils.hpp | 2 +- impl/ascend/functions_ext/apply_penalty.cpp | 22 +++++++++---------- .../context_attention_inference.cpp | 12 +++++----- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/impl/ascend/common/utils.cpp b/impl/ascend/common/utils.cpp index 462beecf7..2cac7727b 100644 --- a/impl/ascend/common/utils.cpp +++ b/impl/ascend/common/utils.cpp @@ -760,7 +760,7 @@ diopiError_t autoCastTensorType(diopiContextHandle_t ctx, const std::vector& pTensors, const std::set& opSupportedDtype); -void* AscendTensorDeviceToHost(diopiContextHandle_t ctx, AscendTensor at); +void* ascendTensorDeviceToHost(diopiContextHandle_t ctx, AscendTensor at); } // namespace ascend } // namespace impl diff --git a/impl/ascend/functions_ext/apply_penalty.cpp b/impl/ascend/functions_ext/apply_penalty.cpp index 67ca3270b..8defbfe9b 100644 --- a/impl/ascend/functions_ext/apply_penalty.cpp +++ b/impl/ascend/functions_ext/apply_penalty.cpp @@ -27,9 +27,9 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log const int64_t dim = 0; diopiDtype_t logitsDtype = logitsAt.dtype(); - void* curBatchIndexDataPtrHost = AscendTensorDeviceToHost(ctx, pCumsumSeqLenAt); - void* frequencyPenaltyDataPtrHost = AscendTensorDeviceToHost(ctx, frequencyPenaltyAt); - void* presencePenaltyDataPtrHost = AscendTensorDeviceToHost(ctx, presencePenaltyAt); + void* curBatchIndexDataPtrHost = ascendTensorDeviceToHost(ctx, pCumsumSeqLenAt); + void* frequencyPenaltyDataPtrHost = ascendTensorDeviceToHost(ctx, frequencyPenaltyAt); + void* presencePenaltyDataPtrHost = ascendTensorDeviceToHost(ctx, presencePenaltyAt); for (int i = 0; i < batch; ++i) { int curBatchStartIndex = reinterpret_cast(curBatchIndexDataPtrHost)[i]; int curBatchEndIndex = reinterpret_cast(curBatchIndexDataPtrHost)[i + 1]; @@ -67,25 +67,25 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log AscendTensor frequencyPenaltyAdjustmentAt; makeTensor(ctx, frequencyPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype); - diopiScalar_t frequencyPenaltyAt_i_Scalar; + diopiScalar_t frequencyPenaltyAtIScalar; if (logitsDtype == diopi_dtype_float32) { - frequencyPenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); + frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); } else { - frequencyPenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); + frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); } - DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAt_i_Scalar, frequencyPenaltyAdjustmentAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAtIScalar, frequencyPenaltyAdjustmentAt); AscendTensor totalPenaltyAdjustmentAt; makeTensor(ctx, totalPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype); - diopiScalar_t presencePenaltyAt_i_Scalar; + diopiScalar_t presencePenaltyAtIScalar; if (logitsDtype == diopi_dtype_float32) { - presencePenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); + presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); } else { - presencePenaltyAt_i_Scalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); + presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); } diopiScalar_t oneScalar = constructDiopiScalarT(logitsDtype, 1); - DIOPI_ASCEND_CALL_ACLNN(aclnnAdds, ctx, frequencyPenaltyAdjustmentAt, &presencePenaltyAt_i_Scalar, &oneScalar, totalPenaltyAdjustmentAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnAdds, ctx, frequencyPenaltyAdjustmentAt, &presencePenaltyAtIScalar, &oneScalar, totalPenaltyAdjustmentAt); DIOPI_ASCEND_CALL_ACLNN(aclnnSub, ctx, curLogitsAt, totalPenaltyAdjustmentAt, &oneScalar, curLogitsAt); std::vector indices; diff --git a/impl/ascend/functions_ext/context_attention_inference.cpp b/impl/ascend/functions_ext/context_attention_inference.cpp index cfbf628fc..6f4c7c353 100644 --- a/impl/ascend/functions_ext/context_attention_inference.cpp +++ b/impl/ascend/functions_ext/context_attention_inference.cpp @@ -65,8 +65,8 @@ AscendTensor torchContextAttention(diopiContextHandle_t ctx, AscendTensor xq, As xkTransposeAtShape[2] = xkTransposeAtShape[3]; xkTransposeAtShape[3] = tmp; makeTensor(ctx, xkTransposeAt2, xkTransposeAtShape, xk.dtype()); - std::vector xkTransposeAt_2Dims = {0, 1, 3, 2}; - DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xkTransposeAt, xkTransposeAt_2Dims, xkTransposeAt2); + std::vector xkTransposeAt2Dims = {0, 1, 3, 2}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xkTransposeAt, xkTransposeAt2Dims, xkTransposeAt2); std::vector scoresShapeAt = xqTransposeAt.shape(); scoresShapeAt[3] = xkTransposeAtShape[3]; makeTensor(ctx, scoresAt, scoresShapeAt, xq.dtype()); @@ -75,8 +75,8 @@ AscendTensor torchContextAttention(diopiContextHandle_t ctx, AscendTensor xq, As DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceDivs, ctx, scoresAt, &otherScalar); AscendTensor adjustedScoresAt; - std::vector adjusted_scoresAtShape = inferSize(scoresAt.shape(), maskRepeatAt.shape()); - makeTensor(ctx, adjustedScoresAt, adjusted_scoresAtShape, scoresAt.dtype()); + std::vector adjustedScoresAtShape = inferSize(scoresAt.shape(), maskRepeatAt.shape()); + makeTensor(ctx, adjustedScoresAt, adjustedScoresAtShape, scoresAt.dtype()); diopiScalar_t alphaScalar = constructDiopiScalarT(diopi_dtype_float32, 1); DIOPI_ASCEND_CALL_ACLNN(aclnnAdd, ctx, scoresAt, maskRepeatAt, &alphaScalar, adjustedScoresAt); DIOPI_ASCEND_CALL_ACLNN(aclnnSoftmax, ctx, adjustedScoresAt, adjustedScoresAt.dim() - 1, adjustedScoresAt); @@ -107,8 +107,8 @@ diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTenso int head = qAt.shape()[1]; int dim = qAt.shape()[2]; - void* bStartLocDataPtr = AscendTensorDeviceToHost(ctx, bStartLocAt); - void* bSeqLenDataPtr = AscendTensorDeviceToHost(ctx, bSeqLenAt); + void* bStartLocDataPtr = ascendTensorDeviceToHost(ctx, bStartLocAt); + void* bSeqLenDataPtr = ascendTensorDeviceToHost(ctx, bSeqLenAt); for (int i = 0; i < batch; ++i) { int start = reinterpret_cast(bStartLocDataPtr)[i]; int end = start + reinterpret_cast(bSeqLenDataPtr)[i]; From 8e41845f526bb6d374b31484a37f46390bc96ebe Mon Sep 17 00:00:00 2001 From: maibaodexiaohangjiaya Date: Wed, 14 Aug 2024 04:31:05 +0000 Subject: [PATCH 3/5] merge main --- impl/ascend/common/utils.cpp | 11 ------- impl/ascend/common/utils.hpp | 2 -- impl/ascend/functions_ext/apply_penalty.cpp | 29 +++++++++++-------- .../context_attention_inference.cpp | 15 ++++++---- 4 files changed, 27 insertions(+), 30 deletions(-) diff --git a/impl/ascend/common/utils.cpp b/impl/ascend/common/utils.cpp index 8520f4af5..b29465270 100644 --- a/impl/ascend/common/utils.cpp +++ b/impl/ascend/common/utils.cpp @@ -821,16 +821,5 @@ diopiError_t autoCastTensorType(diopiContextHandle_t ctx, const std::vector(stream))); - CALL_ACLRT(aclrtMemcpyAsync( - ptrHost, at.numel() * at.elemsize(), at.data(), at.numel() * at.elemsize(), ACL_MEMCPY_DEVICE_TO_HOST, reinterpret_cast(stream))); - CALL_ACLRT(aclrtSynchronizeStream(reinterpret_cast(stream))); - return ptrHost; -} - } // namespace ascend } // namespace impl diff --git a/impl/ascend/common/utils.hpp b/impl/ascend/common/utils.hpp index 28accbbe5..05314907d 100644 --- a/impl/ascend/common/utils.hpp +++ b/impl/ascend/common/utils.hpp @@ -125,8 +125,6 @@ diopiError_t fillNan(diopiContextHandle_t ctx, AscendTensor& src); diopiError_t autoCastTensorType(diopiContextHandle_t ctx, const std::vector& pTensors, const std::set& opSupportedDtype); -void* ascendTensorDeviceToHost(diopiContextHandle_t ctx, AscendTensor at); - } // namespace ascend } // namespace impl diff --git a/impl/ascend/functions_ext/apply_penalty.cpp b/impl/ascend/functions_ext/apply_penalty.cpp index 8defbfe9b..df8eea89b 100644 --- a/impl/ascend/functions_ext/apply_penalty.cpp +++ b/impl/ascend/functions_ext/apply_penalty.cpp @@ -8,6 +8,7 @@ #include #include "../aclnn/adaptor.hpp" +#include "../common/acloprunner.hpp" #include "impl_functions.hpp" namespace impl { @@ -27,12 +28,15 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log const int64_t dim = 0; diopiDtype_t logitsDtype = logitsAt.dtype(); - void* curBatchIndexDataPtrHost = ascendTensorDeviceToHost(ctx, pCumsumSeqLenAt); - void* frequencyPenaltyDataPtrHost = ascendTensorDeviceToHost(ctx, frequencyPenaltyAt); - void* presencePenaltyDataPtrHost = ascendTensorDeviceToHost(ctx, presencePenaltyAt); + AscendTensor curBatchIndexHostAt = deviceToHostSync(ctx, pCumsumSeqLenAt); + AscendTensor frequencyPenaltyHostAt = deviceToHostSync(ctx, frequencyPenaltyAt); + AscendTensor presencePenaltyHostAt = deviceToHostSync(ctx, presencePenaltyAt); + + const int *curBatchIndexData = reinterpret_cast(curBatchIndexHostAt.data()); + for (int i = 0; i < batch; ++i) { - int curBatchStartIndex = reinterpret_cast(curBatchIndexDataPtrHost)[i]; - int curBatchEndIndex = reinterpret_cast(curBatchIndexDataPtrHost)[i + 1]; + int curBatchStartIndex = *(curBatchIndexData + i); + int curBatchEndIndex = *(curBatchIndexData + (i + 1)); AscendTensor sliceAt; std::vector sliceAtShape(1, curBatchEndIndex - curBatchStartIndex); makeTensor(ctx, sliceAt, sliceAtShape, diopi_dtype_int32); @@ -69,9 +73,11 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log diopiScalar_t frequencyPenaltyAtIScalar; if (logitsDtype == diopi_dtype_float32) { - frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); + const float *frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); + frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i)); } else { - frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(frequencyPenaltyDataPtrHost)[i]); + const half_float::half *frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); + frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i)); } DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAtIScalar, frequencyPenaltyAdjustmentAt); @@ -80,9 +86,11 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log diopiScalar_t presencePenaltyAtIScalar; if (logitsDtype == diopi_dtype_float32) { - presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); + const float *presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); + presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i)); } else { - presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, reinterpret_cast(presencePenaltyDataPtrHost)[i]); + const half_float::half *presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); + presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i)); } diopiScalar_t oneScalar = constructDiopiScalarT(logitsDtype, 1); DIOPI_ASCEND_CALL_ACLNN(aclnnAdds, ctx, frequencyPenaltyAdjustmentAt, &presencePenaltyAtIScalar, &oneScalar, totalPenaltyAdjustmentAt); @@ -93,9 +101,6 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log indices.emplace_back(curTokenIdsAt); DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, logitsAt, indices, curLogitsAt, false, true); } - free(curBatchIndexDataPtrHost); - free(frequencyPenaltyDataPtrHost); - free(presencePenaltyDataPtrHost); return diopiSuccess; } diff --git a/impl/ascend/functions_ext/context_attention_inference.cpp b/impl/ascend/functions_ext/context_attention_inference.cpp index 6f4c7c353..51311b4bd 100644 --- a/impl/ascend/functions_ext/context_attention_inference.cpp +++ b/impl/ascend/functions_ext/context_attention_inference.cpp @@ -9,6 +9,7 @@ #include "../aclnn/adaptor.hpp" #include "impl_functions.hpp" +#include "../common/acloprunner.hpp" namespace impl { namespace ascend { @@ -107,11 +108,15 @@ diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTenso int head = qAt.shape()[1]; int dim = qAt.shape()[2]; - void* bStartLocDataPtr = ascendTensorDeviceToHost(ctx, bStartLocAt); - void* bSeqLenDataPtr = ascendTensorDeviceToHost(ctx, bSeqLenAt); + AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); + AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); + + const int *bStartLocData = reinterpret_cast(bStartLocAt.data()); + const int *bSeqLenData = reinterpret_cast(bSeqLenAt.data()); + for (int i = 0; i < batch; ++i) { - int start = reinterpret_cast(bStartLocDataPtr)[i]; - int end = start + reinterpret_cast(bSeqLenDataPtr)[i]; + int start = *(bStartLocData + i); + int end = start + *(bSeqLenData + i); AscendTensor sliceAt; std::vector sliceAtShape(1, end - start); @@ -133,7 +138,7 @@ diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTenso AscendTensor valuesAt; AscendTensor qIndexAt(qIndex), kIndexAt(kIndex), vIndexAt(vIndex); - valuesAt = torchContextAttention(ctx, qIndexAt, kIndexAt, vIndexAt, 1, reinterpret_cast(bSeqLenDataPtr)[i], head, dim); + valuesAt = torchContextAttention(ctx, qIndexAt, kIndexAt, vIndexAt, 1, *(bSeqLenData + i), head, dim); std::vector indices = {sliceAt}; DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, out, indices, valuesAt, false, true); From 22b25e162f5dccd8b68068cf2e865627140cc71b Mon Sep 17 00:00:00 2001 From: maibaodexiaohangjiaya Date: Wed, 14 Aug 2024 05:31:58 +0000 Subject: [PATCH 4/5] add conv2d_transpose --- impl/ascend/device_configs.py | 12 --- impl/ascend/functions/conv2d.cpp | 100 ++++++++++++++++++ .../context_attention_inference.cpp | 4 +- impl/ascend_npu/ascend_config.yaml | 2 + 4 files changed, 104 insertions(+), 14 deletions(-) diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py index fca196da5..de4658e95 100755 --- a/impl/ascend/device_configs.py +++ b/impl/ascend/device_configs.py @@ -470,18 +470,6 @@ ), ), - 'conv_transpose2d': dict( - name=['conv_transpose2d'], - tensor_para=dict( - args=[ - { - "ins": ['input'], - "dtype": [Skip(np.float32),Skip(np.float64),Skip(np.float16),], - }, - ] - ), - ), - 'unfold': dict( name=['unfold'], tensor_para=dict( diff --git a/impl/ascend/functions/conv2d.cpp b/impl/ascend/functions/conv2d.cpp index 5a7395702..4e9fd176d 100644 --- a/impl/ascend/functions/conv2d.cpp +++ b/impl/ascend/functions/conv2d.cpp @@ -122,5 +122,105 @@ diopiError_t diopiConvolution2dBackward(diopiContextHandle_t ctx, diopiTensorHan return diopiSuccess; } +DIOPI_API diopiError_t diopiConvTranspose2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t bias, diopiSize_t stride, diopiSize_t padding, diopiSize_t output_padding, int64_t groups, + diopiSize_t dilation) { + bool transposed = true; + int8_t cubeMathType = 0; + + ASCEND_CHECK_ABORT(stride.len == 1 || stride.len == 2, "the dim of stride must be 1 or 2!"); + ASCEND_CHECK_ABORT(padding.len == 1 || padding.len == 2, "the dim of padding must be 1 or 2!"); + ASCEND_CHECK_ABORT(dilation.len == 1 || dilation.len == 2, "the dim of dilation must be 1 or 2!"); + + int64_t strideExpandData[2]; + int64_t paddingExpandData[2]; + int64_t dilationExpandData[2]; + + strideExpandData[0] = stride.data[0]; + strideExpandData[1] = (stride.len == 1) ? stride.data[0] : stride.data[1]; + + paddingExpandData[0] = padding.data[0]; + paddingExpandData[1] = (padding.len == 1) ? padding.data[0] : padding.data[1]; + + dilationExpandData[0] = dilation.data[0]; + dilationExpandData[1] = (dilation.len == 1) ? dilation.data[0] : dilation.data[1]; + + DIOPI_ASCEND_CALL_ACLNN(aclnnConvolution, + ctx, + input, + weight, + bias, + diopiSize_t{strideExpandData, 2}, + diopiSize_t{paddingExpandData, 2}, + diopiSize_t{dilationExpandData, 2}, + transposed, + output_padding, + groups, + out, + cubeMathType); + return diopiSuccess; +} + +DIOPI_API diopiError_t diopiConvTranspose2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, + diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t weight, diopiSize_t* bias_sizes, diopiSize_t stride, diopiSize_t padding, + diopiSize_t dilation, diopiSize_t output_padding, int64_t groups) { + bool transposed = true; + int8_t cubeMathType = 0; + + std::array gradMask = {true, true, true}; + if (nullptr == grad_input) { + gradMask[0] = false; + } + if (nullptr == grad_weight) { + gradMask[1] = false; + } + if (nullptr == grad_bias) { + gradMask[2] = false; + } + + ASCEND_CHECK_ABORT(stride.len == 1 || stride.len == 2, "the dim of stride must be 1 or 2!"); + ASCEND_CHECK_ABORT(padding.len == 1 || padding.len == 2, "the dim of padding must be 1 or 2!"); + ASCEND_CHECK_ABORT(dilation.len == 1 || dilation.len == 2, "the dim of dilation must be 1 or 2!"); + + int64_t strideExpandData[2]; + int64_t paddingExpandData[2]; + int64_t dilationExpandData[2]; + + strideExpandData[0] = stride.data[0]; + strideExpandData[1] = (stride.len == 1) ? stride.data[0] : stride.data[1]; + + paddingExpandData[0] = padding.data[0]; + paddingExpandData[1] = (padding.len == 1) ? padding.data[0] : padding.data[1]; + + dilationExpandData[0] = dilation.data[0]; + dilationExpandData[1] = (dilation.len == 1) ? dilation.data[0] : dilation.data[1]; + + AscendTensor gradBiasAt(grad_bias); + std::vector biasShape; + if (grad_bias != nullptr) { + biasShape = gradBiasAt.shape(); + } + + DIOPI_ASCEND_CALL_ACLNN(aclnnConvolutionBackward, + ctx, + grad_output, + input, + weight, + biasShape, + diopiSize_t{strideExpandData, 2}, + diopiSize_t{paddingExpandData, 2}, + diopiSize_t{dilationExpandData, 2}, + transposed, + output_padding, + groups, + gradMask, + cubeMathType, + grad_input, + grad_weight, + grad_bias); + return diopiSuccess; +} + } // namespace ascend } // namespace impl diff --git a/impl/ascend/functions_ext/context_attention_inference.cpp b/impl/ascend/functions_ext/context_attention_inference.cpp index 51311b4bd..57e80f5e9 100644 --- a/impl/ascend/functions_ext/context_attention_inference.cpp +++ b/impl/ascend/functions_ext/context_attention_inference.cpp @@ -111,8 +111,8 @@ diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTenso AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); - const int *bStartLocData = reinterpret_cast(bStartLocAt.data()); - const int *bSeqLenData = reinterpret_cast(bSeqLenAt.data()); + const int *bStartLocData = reinterpret_cast(bStartLocHostAt.data()); + const int *bSeqLenData = reinterpret_cast(bSeqLenHostAt.data()); for (int i = 0; i < batch; ++i) { int start = *(bStartLocData + i); diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index d0f47e6b1..f045f683f 100644 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -63,6 +63,8 @@ ascend: - diopiContiguous - diopiConvolution2d - diopiConvolution2dBackward +- diopiConvTranspose2d +- diopiConvTranspose2dBackward - diopiCopyInp - diopiCos - diopiCosInp From 4252aa63788413b6b089423892fc8ac9c62f182f Mon Sep 17 00:00:00 2001 From: maibaodexiaohangjiaya Date: Wed, 14 Aug 2024 06:08:41 +0000 Subject: [PATCH 5/5] fix style and type --- impl/ascend/convert_config.yaml | 6 ++++++ impl/ascend/functions_ext/apply_penalty.cpp | 10 +++++----- .../functions_ext/context_attention_inference.cpp | 6 +++--- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/impl/ascend/convert_config.yaml b/impl/ascend/convert_config.yaml index 50b78be98..6e58f8e4b 100755 --- a/impl/ascend/convert_config.yaml +++ b/impl/ascend/convert_config.yaml @@ -41,6 +41,12 @@ - diopiConvolution2dBackward: dtype: (float64)->float32 +- diopiConvTranspose2d: + dtype: (float64)->float32 + +- diopiConvTranspose2dBackward: + dtype: (float64)->float32 + - diopiAdaptiveAvgPool2d: dtype: (float64)->float32 diff --git a/impl/ascend/functions_ext/apply_penalty.cpp b/impl/ascend/functions_ext/apply_penalty.cpp index df8eea89b..4fc3787d4 100644 --- a/impl/ascend/functions_ext/apply_penalty.cpp +++ b/impl/ascend/functions_ext/apply_penalty.cpp @@ -32,7 +32,7 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log AscendTensor frequencyPenaltyHostAt = deviceToHostSync(ctx, frequencyPenaltyAt); AscendTensor presencePenaltyHostAt = deviceToHostSync(ctx, presencePenaltyAt); - const int *curBatchIndexData = reinterpret_cast(curBatchIndexHostAt.data()); + const int* curBatchIndexData = reinterpret_cast(curBatchIndexHostAt.data()); for (int i = 0; i < batch; ++i) { int curBatchStartIndex = *(curBatchIndexData + i); @@ -73,10 +73,10 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log diopiScalar_t frequencyPenaltyAtIScalar; if (logitsDtype == diopi_dtype_float32) { - const float *frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); + const float* frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i)); } else { - const half_float::half *frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); + const half_float::half* frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i)); } DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAtIScalar, frequencyPenaltyAdjustmentAt); @@ -86,10 +86,10 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log diopiScalar_t presencePenaltyAtIScalar; if (logitsDtype == diopi_dtype_float32) { - const float *presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); + const float* presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i)); } else { - const half_float::half *presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); + const half_float::half* presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i)); } diopiScalar_t oneScalar = constructDiopiScalarT(logitsDtype, 1); diff --git a/impl/ascend/functions_ext/context_attention_inference.cpp b/impl/ascend/functions_ext/context_attention_inference.cpp index 57e80f5e9..3bd88e9b5 100644 --- a/impl/ascend/functions_ext/context_attention_inference.cpp +++ b/impl/ascend/functions_ext/context_attention_inference.cpp @@ -8,8 +8,8 @@ #include #include "../aclnn/adaptor.hpp" -#include "impl_functions.hpp" #include "../common/acloprunner.hpp" +#include "impl_functions.hpp" namespace impl { namespace ascend { @@ -111,8 +111,8 @@ diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTenso AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); - const int *bStartLocData = reinterpret_cast(bStartLocHostAt.data()); - const int *bSeqLenData = reinterpret_cast(bSeqLenHostAt.data()); + const int* bStartLocData = reinterpret_cast(bStartLocHostAt.data()); + const int* bSeqLenData = reinterpret_cast(bSeqLenHostAt.data()); for (int i = 0; i < batch; ++i) { int start = *(bStartLocData + i);