Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion impl/ascend/aclnn/adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ inline void* getOpApiFuncAddr(const char* apiName) {
constexpr const char kOpApiLibName[] = "libopapi.so";
static void* opApiHandler = getOpApiLibHandler(kOpApiLibName);
if (opApiHandler == nullptr) {
return nullptr;
error(__FILE__, __LINE__, __FUNCTION__, "Error: Failed to get opApi handler for %s.", apiName);
}
return getOpApiFuncAddrInLib(opApiHandler, kOpApiLibName, apiName);
}
Expand Down
6 changes: 6 additions & 0 deletions impl/ascend/convert_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
- diopiConvolution2dBackward:
dtype: (float64)->float32

- diopiConvTranspose2d:
dtype: (float64)->float32

- diopiConvTranspose2dBackward:
dtype: (float64)->float32

- diopiAdaptiveAvgPool2d:
dtype: (float64)->float32

Expand Down
30 changes: 0 additions & 30 deletions impl/ascend/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,18 +470,6 @@
),
),

'conv_transpose2d': dict(
name=['conv_transpose2d'],
tensor_para=dict(
args=[
{
"ins": ['input'],
"dtype": [Skip(np.float32),Skip(np.float64),Skip(np.float16),],
},
]
),
),

'unfold': dict(
name=['unfold'],
tensor_para=dict(
Expand Down Expand Up @@ -1158,24 +1146,6 @@
],
),
),

# 'apply_penalty': dict(
# name=['apply_penalty'],
# tensor_para=dict(
# args=[
# {
# "ins": ['logits'],
# "dtype": [Skip(np.float64)],
# },
# ]
# )
# ),

# TODO(zhangqiu) Due to a bug in the software stack, this test will be skipped for now.
'apply_penalty': dict(
name=['apply_penalty'],
skip_all=True
),

# TODO(zhangqiu) Due to a bug in the software stack, this test will be skipped for now.
'embedding': dict(
Expand Down
100 changes: 100 additions & 0 deletions impl/ascend/functions/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,105 @@ diopiError_t diopiConvolution2dBackward(diopiContextHandle_t ctx, diopiTensorHan
return diopiSuccess;
}

DIOPI_API diopiError_t diopiConvTranspose2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
                                            diopiConstTensorHandle_t bias, diopiSize_t stride, diopiSize_t padding, diopiSize_t output_padding, int64_t groups,
                                            diopiSize_t dilation) {
    // Transposed 2D convolution, implemented through the generic aclnnConvolution
    // entry point with transposed = true.
    //
    // ctx            : DIOPI context handle.
    // out            : output tensor (written by aclnnConvolution).
    // input / weight / bias : forward operands; bias may be an optional handle.
    // stride/padding/dilation/output_padding : 1- or 2-element sizes; a single
    //                  element means "same value for H and W".
    // groups         : convolution group count.
    bool transposed = true;
    int8_t cubeMathType = 0;  // 0: use the backend's default cube math mode

    ASCEND_CHECK_ABORT(stride.len == 1 || stride.len == 2, "the dim of stride must be 1 or 2!");
    ASCEND_CHECK_ABORT(padding.len == 1 || padding.len == 2, "the dim of padding must be 1 or 2!");
    ASCEND_CHECK_ABORT(dilation.len == 1 || dilation.len == 2, "the dim of dilation must be 1 or 2!");
    // output_padding must be validated and expanded exactly like the other size
    // parameters; previously it was forwarded raw, so a 1-element form was passed
    // through unexpanded.
    ASCEND_CHECK_ABORT(output_padding.len == 1 || output_padding.len == 2, "the dim of output_padding must be 1 or 2!");

    // Expand the 1-element shorthand into the explicit {H, W} form.
    int64_t strideExpandData[2];
    int64_t paddingExpandData[2];
    int64_t dilationExpandData[2];
    int64_t outputPaddingExpandData[2];

    strideExpandData[0] = stride.data[0];
    strideExpandData[1] = (stride.len == 1) ? stride.data[0] : stride.data[1];

    paddingExpandData[0] = padding.data[0];
    paddingExpandData[1] = (padding.len == 1) ? padding.data[0] : padding.data[1];

    dilationExpandData[0] = dilation.data[0];
    dilationExpandData[1] = (dilation.len == 1) ? dilation.data[0] : dilation.data[1];

    outputPaddingExpandData[0] = output_padding.data[0];
    outputPaddingExpandData[1] = (output_padding.len == 1) ? output_padding.data[0] : output_padding.data[1];

    DIOPI_ASCEND_CALL_ACLNN(aclnnConvolution,
                            ctx,
                            input,
                            weight,
                            bias,
                            diopiSize_t{strideExpandData, 2},
                            diopiSize_t{paddingExpandData, 2},
                            diopiSize_t{dilationExpandData, 2},
                            transposed,
                            diopiSize_t{outputPaddingExpandData, 2},
                            groups,
                            out,
                            cubeMathType);
    return diopiSuccess;
}

DIOPI_API diopiError_t diopiConvTranspose2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight,
                                                    diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input,
                                                    diopiConstTensorHandle_t weight, diopiSize_t* bias_sizes, diopiSize_t stride, diopiSize_t padding,
                                                    diopiSize_t dilation, diopiSize_t output_padding, int64_t groups) {
    // Backward pass of transposed 2D convolution via aclnnConvolutionBackward
    // with transposed = true.
    //
    // grad_input / grad_weight / grad_bias : outputs; a nullptr handle means
    //   "this gradient is not requested" and the corresponding mask bit is cleared.
    // grad_output / input / weight         : forward operands and upstream grad.
    // stride/padding/dilation/output_padding : 1- or 2-element sizes; a single
    //   element means "same value for H and W".
    bool transposed = true;
    int8_t cubeMathType = 0;  // 0: use the backend's default cube math mode

    // Mask selecting which of {input, weight, bias} gradients to compute.
    std::array<bool, 3> gradMask = {true, true, true};
    if (nullptr == grad_input) {
        gradMask[0] = false;
    }
    if (nullptr == grad_weight) {
        gradMask[1] = false;
    }
    if (nullptr == grad_bias) {
        gradMask[2] = false;
    }

    ASCEND_CHECK_ABORT(stride.len == 1 || stride.len == 2, "the dim of stride must be 1 or 2!");
    ASCEND_CHECK_ABORT(padding.len == 1 || padding.len == 2, "the dim of padding must be 1 or 2!");
    ASCEND_CHECK_ABORT(dilation.len == 1 || dilation.len == 2, "the dim of dilation must be 1 or 2!");
    // output_padding must be validated and expanded like the other size
    // parameters; previously it was forwarded raw, so a 1-element form was
    // passed through unexpanded.
    ASCEND_CHECK_ABORT(output_padding.len == 1 || output_padding.len == 2, "the dim of output_padding must be 1 or 2!");

    // Expand the 1-element shorthand into the explicit {H, W} form.
    int64_t strideExpandData[2];
    int64_t paddingExpandData[2];
    int64_t dilationExpandData[2];
    int64_t outputPaddingExpandData[2];

    strideExpandData[0] = stride.data[0];
    strideExpandData[1] = (stride.len == 1) ? stride.data[0] : stride.data[1];

    paddingExpandData[0] = padding.data[0];
    paddingExpandData[1] = (padding.len == 1) ? padding.data[0] : padding.data[1];

    dilationExpandData[0] = dilation.data[0];
    dilationExpandData[1] = (dilation.len == 1) ? dilation.data[0] : dilation.data[1];

    outputPaddingExpandData[0] = output_padding.data[0];
    outputPaddingExpandData[1] = (output_padding.len == 1) ? output_padding.data[0] : output_padding.data[1];

    // Only materialize the bias shape when a bias gradient is actually requested;
    // constructing an AscendTensor from a nullptr handle is unnecessary otherwise.
    std::vector<int64_t> biasShape;
    if (grad_bias != nullptr) {
        AscendTensor gradBiasAt(grad_bias);
        biasShape = gradBiasAt.shape();
    }

    DIOPI_ASCEND_CALL_ACLNN(aclnnConvolutionBackward,
                            ctx,
                            grad_output,
                            input,
                            weight,
                            biasShape,
                            diopiSize_t{strideExpandData, 2},
                            diopiSize_t{paddingExpandData, 2},
                            diopiSize_t{dilationExpandData, 2},
                            transposed,
                            diopiSize_t{outputPaddingExpandData, 2},
                            groups,
                            gradMask,
                            cubeMathType,
                            grad_input,
                            grad_weight,
                            grad_bias);
    return diopiSuccess;
}

} // namespace ascend
} // namespace impl
93 changes: 93 additions & 0 deletions impl/ascend/functions_ext/apply_penalty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,103 @@
#include <vector>

#include "../aclnn/adaptor.hpp"
#include "../common/acloprunner.hpp"
#include "impl_functions.hpp"

namespace impl {
namespace ascend {

// Applies per-batch presence and frequency penalties to `logits` in place:
// for each batch row i, the logits at that row's token ids are reduced by
// (token_count * frequencyPenalty[i] + presencePenalty[i]).
// pCumsumSeqLen gives, per batch, the [start, end) range into pTokenIds /
// pTokenCounts (cumulative sequence lengths, batch+1 entries).
// NOTE(review): pMaxLenInBatch is accepted but unused here — confirm intended.
diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty,
diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t pTokenIds, diopiConstTensorHandle_t pTokenCounts,
diopiConstTensorHandle_t pCumsumSeqLen, int pMaxLenInBatch) {
AscendTensor logitsAt(logits);
AscendTensor pCumsumSeqLenAt(pCumsumSeqLen);
AscendTensor frequencyPenaltyAt(frequencyPenalty);
AscendTensor presencePenaltyAt(presencePenalty);
AscendTensor pTokenIdsAt(pTokenIds);
AscendTensor pTokenCountsAt(pTokenCounts);

int batch = logitsAt.shape(0);
const int64_t dim = 0;
diopiDtype_t logitsDtype = logitsAt.dtype();

// Copy the small control tensors to host so the loop below can read them
// directly (cumulative offsets and per-batch penalty scalars).
AscendTensor curBatchIndexHostAt = deviceToHostSync(ctx, pCumsumSeqLenAt);
AscendTensor frequencyPenaltyHostAt = deviceToHostSync(ctx, frequencyPenaltyAt);
AscendTensor presencePenaltyHostAt = deviceToHostSync(ctx, presencePenaltyAt);

// NOTE(review): assumes pCumsumSeqLen is int32 on device — confirm with callers.
const int* curBatchIndexData = reinterpret_cast<const int*>(curBatchIndexHostAt.data());

for (int i = 0; i < batch; ++i) {
// [start, end) range of this batch's tokens within pTokenIds/pTokenCounts.
int curBatchStartIndex = *(curBatchIndexData + i);
int curBatchEndIndex = *(curBatchIndexData + (i + 1));
// Build an index tensor [start, start+1, ..., end-1] on device via arange.
AscendTensor sliceAt;
std::vector<int64_t> sliceAtShape(1, curBatchEndIndex - curBatchStartIndex);
makeTensor(ctx, sliceAt, sliceAtShape, diopi_dtype_int32);
const diopiScalar_t curBatchStartIndexScalar = constructDiopiScalarT(diopi_dtype_int32, curBatchStartIndex);
const diopiScalar_t curBatchEndIndexScalar = constructDiopiScalarT(diopi_dtype_int32, curBatchEndIndex);
const diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1);
DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &curBatchStartIndexScalar, &curBatchEndIndexScalar, &stepScalar, sliceAt);

// Gather this batch's token ids and their occurrence counts.
diopiTensorHandle_t curTokenIds;
diopiConstTensorHandle_t sliceTensorHandle = sliceAt.tensorHandle();
ascend_npu::diopiIndex(ctx, &curTokenIds, pTokenIds, &sliceTensorHandle, 1);

diopiTensorHandle_t curTokenCounts;
ascend_npu::diopiIndex(ctx, &curTokenCounts, pTokenCounts, &sliceTensorHandle, 1);

AscendTensor curTokenIdsAt(curTokenIds);
AscendTensor curTokenCountsAt(curTokenCounts);
// curLogitsAt will hold logits[i][curTokenIds] (one value per penalized token).
AscendTensor curLogitsAt;
std::vector<int64_t> curLogitsAtShape(1);
curLogitsAtShape[dim] = curTokenIdsAt.shape()[0];
makeTensor(ctx, curLogitsAt, curLogitsAtShape, logitsDtype);
// Select row i of logits as a (1, vocab) tensor...
AscendTensor logitsAtI;
makeTensor(ctx, logitsAtI, {1, logitsAt.shape()[1]}, logitsDtype);
diopiScalar_t iScalar = constructDiopiScalarT(diopi_dtype_int32, i);
AscendTensor iTensorAt;
makeTensorFromScalar(ctx, iTensorAt, &iScalar, logitsAt.device());
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, logitsAt, dim, iTensorAt, logitsAtI);

// ...then flatten it to (vocab,) so token ids can index along dim 0.
logitsAtI.view({logitsAt.shape()[1]});

DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, logitsAtI, dim, curTokenIds, curLogitsAt);
// frequency adjustment = token_count * frequencyPenalty[i]
AscendTensor frequencyPenaltyAdjustmentAt;
makeTensor(ctx, frequencyPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype);

// Read the per-batch frequency penalty scalar from the host copy; the
// pointer type must match the logits dtype (float32 vs float16).
diopiScalar_t frequencyPenaltyAtIScalar;
if (logitsDtype == diopi_dtype_float32) {
const float* frequencyPenaltyData = reinterpret_cast<const float*>(frequencyPenaltyHostAt.data());
frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i));
} else {
// NOTE(review): non-float32 path assumes float16 (half) — other dtypes
// would be misread here; confirm upstream guarantees.
const half_float::half* frequencyPenaltyData = reinterpret_cast<const half_float::half*>(frequencyPenaltyHostAt.data());
frequencyPenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(frequencyPenaltyData + i));
}
DIOPI_ASCEND_CALL_ACLNN(aclnnMuls, ctx, curTokenCounts, &frequencyPenaltyAtIScalar, frequencyPenaltyAdjustmentAt);

// total adjustment = frequency adjustment + presencePenalty[i]
AscendTensor totalPenaltyAdjustmentAt;
makeTensor(ctx, totalPenaltyAdjustmentAt, curTokenCountsAt.shape(), logitsDtype);

diopiScalar_t presencePenaltyAtIScalar;
if (logitsDtype == diopi_dtype_float32) {
const float* presencePenaltyData = reinterpret_cast<const float*>(presencePenaltyHostAt.data());
presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i));
} else {
const half_float::half* presencePenaltyData = reinterpret_cast<const half_float::half*>(presencePenaltyHostAt.data());
presencePenaltyAtIScalar = constructDiopiScalarT(logitsDtype, *(presencePenaltyData + i));
}
// aclnnAdds computes adjustment + alpha * presence (alpha = 1 here).
diopiScalar_t oneScalar = constructDiopiScalarT(logitsDtype, 1);
DIOPI_ASCEND_CALL_ACLNN(aclnnAdds, ctx, frequencyPenaltyAdjustmentAt, &presencePenaltyAtIScalar, &oneScalar, totalPenaltyAdjustmentAt);

// curLogits -= totalAdjustment (alpha = 1), then scatter the penalized
// values back into logits[i][curTokenIds] in place.
DIOPI_ASCEND_CALL_ACLNN(aclnnSub, ctx, curLogitsAt, totalPenaltyAdjustmentAt, &oneScalar, curLogitsAt);
std::vector<AscendTensor> indices;
indices.emplace_back(iTensorAt);
indices.emplace_back(curTokenIdsAt);
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, logitsAt, indices, curLogitsAt, false, true);
}

return diopiSuccess;
}

diopiError_t diopiApplyPenaltyV2(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty,
diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t repetitionPenalty, diopiConstTensorHandle_t pTokenIds,
diopiConstTensorHandle_t pTokenCounts) {
Expand Down
Loading