  * @copyright (c) 2023, DeepLink.
  */

-#include "../common/acloprunner.hpp"
+#include "../aclnn/acl_scalar.hpp"
+#include "../aclnn/adaptor.hpp"

 namespace impl {
 namespace ascend {

 diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t batch1,
                           diopiConstTensorHandle_t batch2, double beta, double alpha) {
-    diopiDtype_t outDtype;
-    diopiGetTensorDtype(out, &outDtype);
+    AscendTensor inAt(input);
+    auto betas = constructDiopiScalarT(inAt.dtype(), beta);
+    auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);

-    AscendTensor inputAt(input);
-    AscendTensor outputAt(out);
-    AscendTensor batch1At(batch1);
-    AscendTensor batch2At(batch2);
-
-    // get the size of batch1 * batch2
-    std::vector<int64_t> batch1Shape = batch1At.shape();
-    std::vector<int64_t> batch2Shape = batch2At.shape();
-    std::vector<int64_t> vectorSizeBatchMatMulTensor = {batch1Shape[0], batch1Shape[1], batch2Shape[2]};
-
-    // init a tensor according to the size of batch1 * batch2 ;
-    diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize(vectorSizeBatchMatMulTensor);
-    AscendTensor batchMatMulTensorAt;
-    makeTensor(ctx, batchMatMulTensorAt, &diopiSizeBatchMatMulTensor, outDtype, diopiDevice_t::diopi_device);
-
-    // does batch1/batch2 need to transpose?
-    bool isSelfT = false;
-    bool isMat2T = false;
-
-    // do batch1 times batch2 -> BatchMatMulTensor
-    AclOpRunner<2, 1>("BatchMatMul", ctx)
-        .addInput(batch1At)
-        .addInput(batch2At)
-        .addOutput(batchMatMulTensorAt)
-        .setAttr("adj_x1", isSelfT)
-        .setAttr("adj_x2", isMat2T)
-        .run();
-
-    // init memory based on the size of alphaMulTensor and betaMulTensor
-    AscendTensor alphaMulTensor;
-    AscendTensor betaMulTensor;
-    makeTensorLike(ctx, alphaMulTensor, batchMatMulTensorAt, outDtype);
-    makeTensorLike(ctx, betaMulTensor, inputAt, outDtype);
-
-    diopiScalar_t alphaScalar = constructDiopiScalarT(outDtype, alpha);
-    diopiScalar_t betaScalar = constructDiopiScalarT(outDtype, beta);
-
-    // transform ascendTensor to diopiTensorHandle_t
-    diopiTensorHandle_t diopiAlphaMulTensor = const_cast<diopiTensorHandle_t>(alphaMulTensor.tensorHandle());
-    diopiTensorHandle_t diopiBateMulTensor = const_cast<diopiTensorHandle_t>(betaMulTensor.tensorHandle());
-    diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast<diopiTensorHandle_t>(batchMatMulTensorAt.tensorHandle());
-    diopiTensorHandle_t diopiInput = const_cast<diopiTensorHandle_t>(inputAt.tensorHandle());
-
-    // alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
-    diopiMulScalar(ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
-    diopiMulScalar(ctx, diopiBateMulTensor, diopiInput, &betaScalar);
-
-    diopiScalar_t otherScalar = constructDiopiScalarT(outDtype, 1);
-    diopiTensorHandle_t diopiOutput = const_cast<diopiTensorHandle_t>(outputAt.tensorHandle());
-    diopiAdd(ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &otherScalar);
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, out, cubeMathType);
     return diopiSuccess;
 }

 diopiError_t diopiBaddbmmInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t batch1, diopiConstTensorHandle_t batch2, double beta,
                              double alpha) {
-    return diopiBaddbmm(ctx, input, input, batch1, batch2, beta, alpha);
+    AscendTensor inAt(input);
+    auto betas = constructDiopiScalarT(inAt.dtype(), beta);
+    auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);
+
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, cubeMathType);
+    return diopiSuccess;
 }

 }  // namespace ascend
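
For reference, baddbmm computes out = beta * input + alpha * (batch1 @ batch2) for every batch: the old implementation assembled this from a BatchMatMul, two MulScalar calls, and an Add, while the new code hands the fused formula to a single aclnnBaddbmm call (with cubeMathType simply left at 0). The sketch below is illustrative only and not part of the patch; the baddbmmReference helper and its flat row-major layout are assumptions made for this example.

#include <cstddef>
#include <vector>

// Illustrative reference: out[b] = beta * input[b] + alpha * (batch1[b] x batch2[b])
// for every batch b, with input/out of shape [n, p], batch1 of shape [n, m],
// batch2 of shape [m, p], each stored as a flat row-major std::vector<float>.
void baddbmmReference(const std::vector<std::vector<float>>& input, const std::vector<std::vector<float>>& batch1,
                      const std::vector<std::vector<float>>& batch2, std::vector<std::vector<float>>& out,
                      std::size_t n, std::size_t m, std::size_t p, float beta, float alpha) {
    for (std::size_t b = 0; b < input.size(); ++b) {
        for (std::size_t i = 0; i < n; ++i) {
            for (std::size_t j = 0; j < p; ++j) {
                float acc = 0.0f;
                // accumulate one element of batch1[b] x batch2[b]
                for (std::size_t k = 0; k < m; ++k) {
                    acc += batch1[b][i * m + k] * batch2[b][k * p + j];
                }
                // scale the matmul result by alpha and blend in beta * input
                out[b][i * p + j] = beta * input[b][i * p + j] + alpha * acc;
            }
        }
    }
}

The in-place variant diopiBaddbmmInp follows the same formula with input doubling as the output, which is why it now calls aclnnInplaceBaddbmm directly instead of reusing the out-of-place path.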