diff --git a/impl/ascend/aclnn/adaptor.hpp b/impl/ascend/aclnn/adaptor.hpp index f0c4ff953..039daf7df 100644 --- a/impl/ascend/aclnn/adaptor.hpp +++ b/impl/ascend/aclnn/adaptor.hpp @@ -53,7 +53,7 @@ inline void* getOpApiFuncAddr(const char* apiName) { constexpr const char kOpApiLibName[] = "libopapi.so"; static void* opApiHandler = getOpApiLibHandler(kOpApiLibName); if (opApiHandler == nullptr) { - return nullptr; + error(__FILE__, __LINE__, __FUNCTION__, "Error: Failed to get opApi handler for %s.", apiName); } return getOpApiFuncAddrInLib(opApiHandler, kOpApiLibName, apiName); } diff --git a/impl/ascend/convert_config.yaml b/impl/ascend/convert_config.yaml index 50b78be98..6e58f8e4b 100755 --- a/impl/ascend/convert_config.yaml +++ b/impl/ascend/convert_config.yaml @@ -41,6 +41,12 @@ - diopiConvolution2dBackward: dtype: (float64)->float32 +- diopiConvTranspose2d: + dtype: (float64)->float32 + +- diopiConvTranspose2dBackward: + dtype: (float64)->float32 + - diopiAdaptiveAvgPool2d: dtype: (float64)->float32 diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py index 1377c420e..de4658e95 100755 --- a/impl/ascend/device_configs.py +++ b/impl/ascend/device_configs.py @@ -470,18 +470,6 @@ ), ), - 'conv_transpose2d': dict( - name=['conv_transpose2d'], - tensor_para=dict( - args=[ - { - "ins": ['input'], - "dtype": [Skip(np.float32),Skip(np.float64),Skip(np.float16),], - }, - ] - ), - ), - 'unfold': dict( name=['unfold'], tensor_para=dict( @@ -1158,24 +1146,6 @@ ], ), ), - - # 'apply_penalty': dict( - # name=['apply_penalty'], - # tensor_para=dict( - # args=[ - # { - # "ins": ['logits'], - # "dtype": [Skip(np.float64)], - # }, - # ] - # ) - # ), - - # TODO(zhangqiu) Due to a bug in the software stack, this test will be skipped for now. - 'apply_penalty': dict( - name=['apply_penalty'], - skip_all=True - ), # TODO(zhangqiu) Due to a bug in the software stack, this test will be skipped for now. 'embedding': dict( diff --git a/impl/ascend/functions/conv2d.cpp b/impl/ascend/functions/conv2d.cpp index 5a7395702..4e9fd176d 100644 --- a/impl/ascend/functions/conv2d.cpp +++ b/impl/ascend/functions/conv2d.cpp @@ -122,5 +122,105 @@ diopiError_t diopiConvolution2dBackward(diopiContextHandle_t ctx, diopiTensorHan return diopiSuccess; } +DIOPI_API diopiError_t diopiConvTranspose2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t bias, diopiSize_t stride, diopiSize_t padding, diopiSize_t output_padding, int64_t groups, + diopiSize_t dilation) { + bool transposed = true; + int8_t cubeMathType = 0; + + ASCEND_CHECK_ABORT(stride.len == 1 || stride.len == 2, "the dim of stride must be 1 or 2!"); + ASCEND_CHECK_ABORT(padding.len == 1 || padding.len == 2, "the dim of padding must be 1 or 2!"); + ASCEND_CHECK_ABORT(dilation.len == 1 || dilation.len == 2, "the dim of dilation must be 1 or 2!"); + + int64_t strideExpandData[2]; + int64_t paddingExpandData[2]; + int64_t dilationExpandData[2]; + + strideExpandData[0] = stride.data[0]; + strideExpandData[1] = (stride.len == 1) ? stride.data[0] : stride.data[1]; + + paddingExpandData[0] = padding.data[0]; + paddingExpandData[1] = (padding.len == 1) ? padding.data[0] : padding.data[1]; + + dilationExpandData[0] = dilation.data[0]; + dilationExpandData[1] = (dilation.len == 1) ? dilation.data[0] : dilation.data[1]; + + DIOPI_ASCEND_CALL_ACLNN(aclnnConvolution, + ctx, + input, + weight, + bias, + diopiSize_t{strideExpandData, 2}, + diopiSize_t{paddingExpandData, 2}, + diopiSize_t{dilationExpandData, 2}, + transposed, + output_padding, + groups, + out, + cubeMathType); + return diopiSuccess; +} + +DIOPI_API diopiError_t diopiConvTranspose2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, + diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t weight, diopiSize_t* bias_sizes, diopiSize_t stride, diopiSize_t padding, + diopiSize_t dilation, diopiSize_t output_padding, int64_t groups) { + bool transposed = true; + int8_t cubeMathType = 0; + + std::array gradMask = {true, true, true}; + if (nullptr == grad_input) { + gradMask[0] = false; + } + if (nullptr == grad_weight) { + gradMask[1] = false; + } + if (nullptr == grad_bias) { + gradMask[2] = false; + } + + ASCEND_CHECK_ABORT(stride.len == 1 || stride.len == 2, "the dim of stride must be 1 or 2!"); + ASCEND_CHECK_ABORT(padding.len == 1 || padding.len == 2, "the dim of padding must be 1 or 2!"); + ASCEND_CHECK_ABORT(dilation.len == 1 || dilation.len == 2, "the dim of dilation must be 1 or 2!"); + + int64_t strideExpandData[2]; + int64_t paddingExpandData[2]; + int64_t dilationExpandData[2]; + + strideExpandData[0] = stride.data[0]; + strideExpandData[1] = (stride.len == 1) ? stride.data[0] : stride.data[1]; + + paddingExpandData[0] = padding.data[0]; + paddingExpandData[1] = (padding.len == 1) ? padding.data[0] : padding.data[1]; + + dilationExpandData[0] = dilation.data[0]; + dilationExpandData[1] = (dilation.len == 1) ? dilation.data[0] : dilation.data[1]; + + AscendTensor gradBiasAt(grad_bias); + std::vector biasShape; + if (grad_bias != nullptr) { + biasShape = gradBiasAt.shape(); + } + + DIOPI_ASCEND_CALL_ACLNN(aclnnConvolutionBackward, + ctx, + grad_output, + input, + weight, + biasShape, + diopiSize_t{strideExpandData, 2}, + diopiSize_t{paddingExpandData, 2}, + diopiSize_t{dilationExpandData, 2}, + transposed, + output_padding, + groups, + gradMask, + cubeMathType, + grad_input, + grad_weight, + grad_bias); + return diopiSuccess; +} + } // namespace ascend } // namespace impl diff --git a/impl/ascend/functions_ext/apply_penalty.cpp b/impl/ascend/functions_ext/apply_penalty.cpp index ddcda9cf4..4fc3787d4 100644 --- a/impl/ascend/functions_ext/apply_penalty.cpp +++ b/impl/ascend/functions_ext/apply_penalty.cpp @@ -8,10 +8,103 @@ #include #include "../aclnn/adaptor.hpp" +#include "../common/acloprunner.hpp" +#include "impl_functions.hpp" namespace impl { namespace ascend { +diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty, + diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t pTokenIds, diopiConstTensorHandle_t pTokenCounts, + diopiConstTensorHandle_t pCumsumSeqLen, int pMaxLenInBatch) { + AscendTensor logitsAt(logits); + AscendTensor pCumsumSeqLenAt(pCumsumSeqLen); + AscendTensor frequencyPenaltyAt(frequencyPenalty); + AscendTensor presencePenaltyAt(presencePenalty); + AscendTensor pTokenIdsAt(pTokenIds); + AscendTensor pTokenCountsAt(pTokenCounts); + + int batch = logitsAt.shape(0); + const int64_t dim = 0; + diopiDtype_t logitsDtype = logitsAt.dtype(); + + AscendTensor curBatchIndexHostAt = deviceToHostSync(ctx, pCumsumSeqLenAt); + AscendTensor frequencyPenaltyHostAt = deviceToHostSync(ctx, frequencyPenaltyAt); + AscendTensor presencePenaltyHostAt = deviceToHostSync(ctx, presencePenaltyAt); + + const int* curBatchIndexData = reinterpret_cast(curBatchIndexHostAt.data()); + + for (int i = 0; i < batch; ++i) { + int curBatchStartIndex = *(curBatchIndexData + i); + int curBatchEndIndex = *(curBatchIndexData + (i + 1)); + AscendTensor sliceAt; + std::vector sliceAtShape(1, curBatchEndIndex - curBatchStartIndex); + makeTensor(ctx, sliceAt, sliceAtShape, diopi_dtype_int32); + const diopiScalar_t curBatchStartIndexScalar = constructDiopiScalarT(diopi_dtype_int32, curBatchStartIndex); + const diopiScalar_t curBatchEndIndexScalar = constructDiopiScalarT(diopi_dtype_int32, curBatchEndIndex); + const diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &curBatchStartIndexScalar, &curBatchEndIndexScalar, &stepScalar, sliceAt); + + diopiTensorHandle_t curTokenIds; + diopiConstTensorHandle_t sliceTensorHandle = sliceAt.tensorHandle(); + ascend_npu::diopiIndex(ctx, &curTokenIds, pTokenIds, &sliceTensorHandle, 1); + + diopiTensorHandle_t curTokenCounts; + ascend_npu::diopiIndex(ctx, &curTokenCounts, pTokenCounts, &sliceTensorHandle, 1); + + AscendTensor curTokenIdsAt(curTokenIds); + AscendTensor curTokenCountsAt(curTokenCounts); + AscendTensor curLogitsAt; + std::vector curLogitsAtShape(1); + curLogitsAtShape[dim] = curTokenIdsAt.shape()[0]; + makeTensor(ctx, curLogitsAt, curLogitsAtShape, logitsDtype); + AscendTensor logitsAtI; + makeTensor(ctx, logitsAtI, {1, logitsAt.shape()[1]}, logitsDtype); + diopiScalar_t iScalar = constructDiopiScalarT(diopi_dtype_int32, i); + AscendTensor iTensorAt; + makeTensorFromScalar(ctx, iTensorAt, &iScalar, logitsAt.device()); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, logitsAt, dim, iTensorAt, logitsAtI); + + logitsAtI.view({logitsAt.shape()[1]}); + + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, logitsAtI, dim, curTokenIds, curLogitsAt); + AscendTensor frequencyPenaltyAdjustmentAt; + makeTensor(ctx, frequencyPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype); + + diopiScalar_t frequencyPenaltyAtIScalar; + if (logitsDtype == diopi_dtype_float32) { + const float* frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); + frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i)); + } else { + const half_float::half* frequencyPenaltyData = reinterpret_cast(frequencyPenaltyHostAt.data()); + frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i)); + } + DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAtIScalar, frequencyPenaltyAdjustmentAt); + + AscendTensor totalPenaltyAdjustmentAt; + makeTensor(ctx, totalPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype); + + diopiScalar_t presencePenaltyAtIScalar; + if (logitsDtype == diopi_dtype_float32) { + const float* presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); + presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i)); + } else { + const half_float::half* presencePenaltyData = reinterpret_cast(presencePenaltyHostAt.data()); + presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i)); + } + diopiScalar_t oneScalar = constructDiopiScalarT(logitsDtype, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnAdds, ctx, frequencyPenaltyAdjustmentAt, &presencePenaltyAtIScalar, &oneScalar, totalPenaltyAdjustmentAt); + + DIOPI_ASCEND_CALL_ACLNN(aclnnSub, ctx, curLogitsAt, totalPenaltyAdjustmentAt, &oneScalar, curLogitsAt); + std::vector indices; + indices.emplace_back(iTensorAt); + indices.emplace_back(curTokenIdsAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, logitsAt, indices, curLogitsAt, false, true); + } + + return diopiSuccess; +} + diopiError_t diopiApplyPenaltyV2(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty, diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t repetitionPenalty, diopiConstTensorHandle_t pTokenIds, diopiConstTensorHandle_t pTokenCounts) { diff --git a/impl/ascend/functions_ext/context_attention_inference.cpp b/impl/ascend/functions_ext/context_attention_inference.cpp new file mode 100644 index 000000000..3bd88e9b5 --- /dev/null +++ b/impl/ascend/functions_ext/context_attention_inference.cpp @@ -0,0 +1,151 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include +#include + +#include "../aclnn/adaptor.hpp" +#include "../common/acloprunner.hpp" +#include "impl_functions.hpp" + +namespace impl { +namespace ascend { + +AscendTensor torchContextAttention(diopiContextHandle_t ctx, AscendTensor xq, AscendTensor xk, AscendTensor xv, int batchSize, int seqLen, int head, int dim) { + xq.view({batchSize, seqLen, head, dim}); + AscendTensor xqTransposeAt; + std::vector xqTransposeAtShape(xq.shape()); + int64_t tmp = xqTransposeAtShape[1]; + xqTransposeAtShape[1] = xqTransposeAtShape[2]; + xqTransposeAtShape[2] = tmp; + makeTensor(ctx, xqTransposeAt, xqTransposeAtShape, xq.dtype()); + std::vector xqTransposeDims = {0, 2, 1, 3}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xq, xqTransposeDims, xqTransposeAt); + + xk.view({batchSize, seqLen, head, dim}); + AscendTensor xkTransposeAt; + makeTensor(ctx, xkTransposeAt, xqTransposeAtShape, xk.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xk, xqTransposeDims, xkTransposeAt); + + xv.view({batchSize, seqLen, head, dim}); + AscendTensor xvTransposeAt; + makeTensor(ctx, xvTransposeAt, xqTransposeAtShape, xv.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xv, xqTransposeDims, xvTransposeAt); + + AscendTensor maskAt; + makeTensor(ctx, maskAt, {1, 1, seqLen, seqLen}, diopi_dtype_float32); + AscendTensor onesMatrixAt; + makeTensor(ctx, onesMatrixAt, {seqLen, seqLen}, diopi_dtype_float32); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, onesMatrixAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceTril, ctx, onesMatrixAt, 0); + maskAt = onesMatrixAt.unsqueeze(0).unsqueeze(0); + diopiScalar_t valueScalar = constructDiopiScalarT(diopi_dtype_float32, -100000000.0); + AscendTensor maskMatrixAt; + makeTensor(ctx, maskMatrixAt, maskAt.shape(), diopi_dtype_int32); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, maskMatrixAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceTriu, ctx, maskMatrixAt, 1); + AscendTensor maskBoolMatrixAt; + makeTensor(ctx, maskBoolMatrixAt, maskMatrixAt.shape(), diopi_dtype_bool); + DIOPI_ASCEND_CALL_ACLNN(aclnnCast, ctx, maskMatrixAt, diopi_dtype_bool, maskBoolMatrixAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceMaskedFillScalar, ctx, maskAt, maskBoolMatrixAt, &valueScalar); + AscendTensor maskRepeatAt; + std::vector maskRepeatAtShape(maskAt.shape()); + maskRepeatAtShape[0] *= batchSize; + maskRepeatAtShape[1] *= head; + makeTensor(ctx, maskRepeatAt, maskRepeatAtShape, maskAt.dtype()); + std::vector repeats = {batchSize, head, 1, 1}; + DIOPI_ASCEND_CALL_ACLNN(aclnnRepeat, ctx, maskAt, repeats, maskRepeatAt); + + AscendTensor scoresAt; + AscendTensor xkTransposeAt2; + std::vector xkTransposeAtShape(xkTransposeAt.shape()); + tmp = xkTransposeAtShape[2]; + xkTransposeAtShape[2] = xkTransposeAtShape[3]; + xkTransposeAtShape[3] = tmp; + makeTensor(ctx, xkTransposeAt2, xkTransposeAtShape, xk.dtype()); + std::vector xkTransposeAt2Dims = {0, 1, 3, 2}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, xkTransposeAt, xkTransposeAt2Dims, xkTransposeAt2); + std::vector scoresShapeAt = xqTransposeAt.shape(); + scoresShapeAt[3] = xkTransposeAtShape[3]; + makeTensor(ctx, scoresAt, scoresShapeAt, xq.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, xqTransposeAt, xkTransposeAt2, scoresAt, (int8_t)0); + diopiScalar_t otherScalar = constructDiopiScalarT(diopi_dtype_float32, std::sqrt(dim)); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceDivs, ctx, scoresAt, &otherScalar); + + AscendTensor adjustedScoresAt; + std::vector adjustedScoresAtShape = inferSize(scoresAt.shape(), maskRepeatAt.shape()); + makeTensor(ctx, adjustedScoresAt, adjustedScoresAtShape, scoresAt.dtype()); + diopiScalar_t alphaScalar = constructDiopiScalarT(diopi_dtype_float32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnAdd, ctx, scoresAt, maskRepeatAt, &alphaScalar, adjustedScoresAt); + DIOPI_ASCEND_CALL_ACLNN(aclnnSoftmax, ctx, adjustedScoresAt, adjustedScoresAt.dim() - 1, adjustedScoresAt); + AscendTensor outputAt; + std::vector outputAtShape = adjustedScoresAt.shape(); + outputAtShape[3] = xvTransposeAt.shape(3); + makeTensor(ctx, outputAt, outputAtShape, scoresAt.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, adjustedScoresAt, xvTransposeAt, outputAt, (int8_t)0); + tmp = outputAtShape[1]; + outputAtShape[1] = outputAtShape[2]; + outputAtShape[2] = tmp; + AscendTensor outputTransposeAt; + makeTensor(ctx, outputTransposeAt, outputAtShape, outputAt.dtype()); + std::vector outputTransposeDims = {0, 2, 1, 3}; + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, outputAt, outputTransposeDims, outputTransposeAt); + outputTransposeAt.view({outputTransposeAt.numel() / static_cast(head * dim), head, dim}); + + return outputTransposeAt; +} + +diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, + diopiConstTensorHandle_t v, diopiConstTensorHandle_t bStartLoc, diopiConstTensorHandle_t bSeqLen, int maxInputLen) { + AscendTensor bStartLocAt(bStartLoc); + AscendTensor qAt(q); + AscendTensor bSeqLenAt(bSeqLen); + + int batch = bStartLocAt.shape()[0]; + int head = qAt.shape()[1]; + int dim = qAt.shape()[2]; + + AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); + AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); + + const int* bStartLocData = reinterpret_cast(bStartLocHostAt.data()); + const int* bSeqLenData = reinterpret_cast(bSeqLenHostAt.data()); + + for (int i = 0; i < batch; ++i) { + int start = *(bStartLocData + i); + int end = start + *(bSeqLenData + i); + + AscendTensor sliceAt; + std::vector sliceAtShape(1, end - start); + makeTensor(ctx, sliceAt, sliceAtShape, diopi_dtype_int32); + diopiScalar_t startIndexScalar = constructDiopiScalarT(diopi_dtype_int32, start); + diopiScalar_t endIndexScalar = constructDiopiScalarT(diopi_dtype_int32, end); + diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &startIndexScalar, &endIndexScalar, &stepScalar, sliceAt); + + diopiTensorHandle_t qIndex; + diopiConstTensorHandle_t sliceTensorHandle = sliceAt.tensorHandle(); + ascend_npu::diopiIndex(ctx, &qIndex, q, &sliceTensorHandle, 1); + + diopiTensorHandle_t kIndex; + ascend_npu::diopiIndex(ctx, &kIndex, k, &sliceTensorHandle, 1); + + diopiTensorHandle_t vIndex; + ascend_npu::diopiIndex(ctx, &vIndex, v, &sliceTensorHandle, 1); + + AscendTensor valuesAt; + AscendTensor qIndexAt(qIndex), kIndexAt(kIndex), vIndexAt(vIndex); + valuesAt = torchContextAttention(ctx, qIndexAt, kIndexAt, vIndexAt, 1, *(bSeqLenData + i), head, dim); + + std::vector indices = {sliceAt}; + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, out, indices, valuesAt, false, true); + } + + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt index 79ce646e2..8c4116bd0 100644 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -201,6 +201,7 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions_ext/adamw.cpp ${OLD_IMPL_DIR}/functions_ext/destindex_copy_kv.cpp ${OLD_IMPL_DIR}/functions_ext/apply_penalty.cpp + ${OLD_IMPL_DIR}/functions_ext/context_attention_inference.cpp ${OLD_IMPL_DIR}/functions_ext/flash_attention.cpp ${OLD_IMPL_DIR}/functions_ext/flash_attention_varlen.cpp ${OLD_IMPL_DIR}/functions_ext/prompt_flash_attention.cpp diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 74b1c8733..f045f683f 100644 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -15,6 +15,7 @@ ascend: - diopiAddmm - diopiAll - diopiAny +- diopiApplyPenalty - diopiApplyPenaltyV2 - diopiArange - diopiArgmax @@ -58,9 +59,12 @@ ascend: - diopiClampMinScalar - diopiClampScalar - diopiCol2Im +- diopiContextAttentionInference - diopiContiguous - diopiConvolution2d - diopiConvolution2dBackward +- diopiConvTranspose2d +- diopiConvTranspose2dBackward - diopiCopyInp - diopiCos - diopiCosInp @@ -269,8 +273,6 @@ ascend: - diopiZeroInp - diopiZeros ascend_npu: -- diopiApplyPenalty -- diopiContextAttentionInference - diopiGetNativeMemoryFormat - diopiNLLLoss - diopiNLLLossBackward