
Commit ec68771

[Ascend]zq/update batch_norm by aclnn (#1181)
1 parent 160952d commit ec68771

3 files changed: +31 additions, -128 deletions

impl/ascend/aclnn/adaptor.hpp

Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ inline aclTensor* createAclTensorFromDiopiTensor(diopiConstTensorHandle_t tensor
     if (tensor == nullptr) {
         return nullptr;
     }
+
     diopiSize_t shape{};
     diopiGetTensorShape(tensor, &shape);
     diopiSize_t stride{};
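
The hunk above sits inside createAclTensorFromDiopiTensor, the helper that wraps a DIOPI tensor handle as a CANN aclTensor for the aclnn calls introduced below. For orientation, here is a minimal standalone sketch of that conversion pattern; the name wrapDiopiTensor and the dtype-mapping helper toAclDataType are hypothetical, and the zero offset and ND format are assumptions rather than this repository's actual choices:

// Sketch only: view a DIOPI tensor's device memory and metadata as an aclTensor.
inline aclTensor* wrapDiopiTensor(diopiConstTensorHandle_t tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }
    diopiSize_t shape{};
    diopiGetTensorShape(tensor, &shape);
    diopiSize_t stride{};
    diopiGetTensorStride(tensor, &stride);
    diopiDtype_t dtype;
    diopiGetTensorDtype(tensor, &dtype);
    const void* data = nullptr;
    diopiGetTensorDataConst(tensor, &data);
    // aclCreateTensor wraps existing device memory without copying;
    // toAclDataType is a hypothetical diopiDtype_t -> aclDataType mapping.
    return aclCreateTensor(shape.data, shape.len, toAclDataType(dtype),
                           stride.data, /*offset=*/0, ACL_FORMAT_ND,
                           shape.data, shape.len, const_cast<void*>(data));
}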

impl/ascend/functions/batch_norm.cpp

Lines changed: 28 additions & 126 deletions

@@ -4,146 +4,48 @@
  * @copyright (c) 2023, DeepLink.
  */

-#include "../common/acloprunner.hpp"
+#include "../aclnn/acl_scalar.hpp"
+#include "../aclnn/adaptor.hpp"

 namespace impl {
 namespace ascend {

-void updateInputAscendTensorDim(AscendTensor& inputAt, bool training) {
-    int64_t dim = inputAt.dim();
-    if (2 == dim) {
-        inputAt.unsqueeze(2);
-        inputAt.unsqueeze(3);
-    } else if (3 == dim) {
-        inputAt.unsqueeze(3);
-    } else if (5 == dim && !training) {
-        std::vector<int64_t> shape4d{inputAt.shape(0), inputAt.shape(1), inputAt.shape(2), inputAt.shape(3) * inputAt.shape(4)};
-        inputAt.view(shape4d);
-    }
-}
-
-void batchNormBackwardTrainingUpdate(diopiContextHandle_t ctx, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias, AscendTensor gradOutputAt,
-                                     AscendTensor inputAt, diopiConstTensorHandle_t saveMean, diopiConstTensorHandle_t saveInvstd, double eps) {
-    std::string name = (inputAt.dim() == 5) ? "BN3DTrainingUpdateGrad" : "BNTrainingUpdateGrad";
-    AclOpRunner<4, 2>(name, ctx)
-        .addInput(gradOutputAt)
-        .addInput(inputAt)
-        .addInput(saveMean)
-        .addInput(saveInvstd)
-        .addOutput(gradWeight)
-        .addOutput(gradBias)
-        .setAttr<float>("epsilon", static_cast<float>(eps))
-        .run();
-}
-
-void batchNormBackwardTrainingReduceNocheck(diopiContextHandle_t ctx, AscendTensor gradInputAt, diopiConstTensorHandle_t gradWeight,
-                                            diopiConstTensorHandle_t gradBias, AscendTensor gradOutputAt, AscendTensor inputAt, diopiConstTensorHandle_t weight,
-                                            diopiConstTensorHandle_t saveMean, diopiConstTensorHandle_t saveInvstd, double eps) {
-    std::string name = (inputAt.dim() == 5) ? "BN3DTrainingReduceGrad" : "BNTrainingReduceGrad";
-    AclOpRunner<7, 1>(name, ctx)
-        .addInput(gradOutputAt)
-        .addInput(inputAt)
-        .addInput(gradWeight)
-        .addInput(gradBias)
-        .addInput(weight)
-        .addInput(saveMean)
-        .addInput(saveInvstd)
-        .addOutput(gradInputAt)
-        .setAttr<float>("epsilon", static_cast<float>(eps))
-        .run();
-}
-
 diopiError_t diopiBatchNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t saveMean, diopiTensorHandle_t saveInvstd,
                             diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, diopiTensorHandle_t runningMean,
                             diopiTensorHandle_t runningVar, bool training, double momentum, double eps) {
-    AscendTensor inputAt(input), outputAt(out);
-    updateInputAscendTensorDim(inputAt, training);
-    outputAt.view(inputAt.getAclMemShape());
-
-    std::vector<int64_t> batchShapeV{inputAt.shape(1)};
-    diopiSize_t batchShapeSizeT{batchShapeV.data(), static_cast<int64_t>(batchShapeV.size())};
-    diopiTensorHandle_t weightTemp = createTensorIfNullptrOrConstCast(ctx, weight, batchShapeSizeT, inputAt.dtype(), true, 1);
-    diopiTensorHandle_t biasTemp = createTensorIfNullptrOrConstCast(ctx, bias, batchShapeSizeT, inputAt.dtype(), true, 0);
-    diopiTensorHandle_t runningMeanTemp = createTensorIfNullptrOrConstCast(ctx, runningMean, batchShapeSizeT, inputAt.dtype(), true, 0);
-    diopiTensorHandle_t runningVarTemp = createTensorIfNullptrOrConstCast(ctx, runningVar, batchShapeSizeT, inputAt.dtype(), true, 1);
-
-    if (!training) {
-        AclOpRunner<5, 1>("BNInfer", ctx)
-            .addInput(inputAt)
-            .addInput(weightTemp)
-            .addInput(biasTemp)
-            .addInput(runningMeanTemp)
-            .addInput(runningVarTemp)
-            .addOutput(outputAt)
-            .setAttr("epsilon", static_cast<float>(eps))
-            .run();
-
-        diopiTensorHandle_t runningVarBroadcasted;
-        makeTensorLike(ctx, &runningVarBroadcasted, input);
-        AscendTensor runningVarAt(runningVar);
-        runningVarAt.unsqueeze(0);
-        runningVarAt.unsqueeze(2);
-        runningVarAt.unsqueeze(3);
-        AclOpRunner<2, 1>("BroadcastTo", ctx).addInput(runningVarAt).addConstInput(inputAt.shape()).addOutput(runningVarBroadcasted).run();
-    } else {
-        diopiTensorHandle_t sum = nullptr, squareSum = nullptr;
-        diopiSize_t shape, stride;
-        diopiGetTensorShape(runningMeanTemp, &shape);
-        diopiGetTensorStride(runningMeanTemp, &stride);
-        diopiRequireTensor(ctx, &sum, &shape, &stride, diopiDtype_t::diopi_dtype_float32, diopi_device);
-        diopiRequireTensor(ctx, &squareSum, &shape, &stride, diopiDtype_t::diopi_dtype_float32, diopi_device);
-        AclOpRunner<1, 2>("BNTrainingReduce", ctx).addInput(inputAt).addOutput(sum).setAttr("epsilon", static_cast<float>(eps)).addOutput(squareSum).run();
-        AclOpRunner<7, 5>("BNTrainingUpdate", ctx)
-            .addInput(inputAt)
-            .addInput(sum)
-            .addInput(squareSum)
-            .addInput(weightTemp)
-            .addInput(biasTemp)
-            .addInput(runningMeanTemp)
-            .addInput(runningVarTemp)
-            .setAttr("epsilon", static_cast<float>(eps))
-            .setAttr("factor", static_cast<float>(momentum))
-            .addOutput(outputAt)
-            .addOutput(runningMeanTemp)
-            .addOutput(runningVarTemp)
-            .addOutput(saveMean)
-            .addOutput(saveInvstd)
-            .run();
-    }
+    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNorm, ctx, input, weight, bias, runningMean, runningVar, training, momentum, eps, out, saveMean, saveInvstd);
     return diopiSuccess;
 }

 diopiError_t diopiBatchNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
                                     diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
-                                    diopiConstTensorHandle_t runninMean, diopiConstTensorHandle_t runningVar, diopiConstTensorHandle_t saveMean,
+                                    diopiConstTensorHandle_t runningMean, diopiConstTensorHandle_t runningVar, diopiConstTensorHandle_t saveMean,
                                     diopiConstTensorHandle_t saveInvstd, bool training, double eps) {
-    AscendTensor inputAt(input), gradOutputAt(gradOutput), gradInputAt(gradInput);
-    updateInputAscendTensorDim(inputAt, training);
-    gradOutputAt.view(inputAt.getAclMemShape());
-    gradInputAt.view(inputAt.getAclMemShape());
-
-    if (!training) {
-        batchNormBackwardTrainingUpdate(ctx, gradWeight, gradBias, gradOutputAt, inputAt, runninMean, runningVar, eps);
-
-        AclOpRunner<3, 1>("BNInferGrad", ctx)
-            .addInput(gradOutputAt)
-            .addInput(weight)
-            .addInput(runningVar)
-            .addOutput(gradInputAt)
-            .setAttr<float>("epsilon", static_cast<float>(eps))
-            .run();
-
-        diopiTensorHandle_t runningVarBroadcasted;
-        makeTensorLike(ctx, &runningVarBroadcasted, input);
-        AscendTensor runningVarAt(runningVar);
-        runningVarAt.unsqueeze(0);
-        runningVarAt.unsqueeze(2);
-        runningVarAt.unsqueeze(3);
-        AclOpRunner<2, 1>("BroadcastTo", ctx).addInput(runningVarAt).addConstInput(inputAt.shape()).addOutput(runningVarBroadcasted).run();
-    } else {
-        batchNormBackwardTrainingUpdate(ctx, gradWeight, gradBias, gradOutputAt, inputAt, saveMean, saveInvstd, eps);
-        batchNormBackwardTrainingReduceNocheck(ctx, gradInputAt, gradWeight, gradBias, gradOutputAt, inputAt, weight, saveMean, saveInvstd, eps);
+    std::array<bool, 3> gradMask = {true, true, true};
+    if (nullptr == gradInput) {
+        gradMask[0] = false;
+    }
+    if (nullptr == gradWeight) {
+        gradMask[1] = false;
+    }
+    if (nullptr == gradBias) {
+        gradMask[2] = false;
     }
+    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNormBackward,
+                            ctx,
+                            gradOutput,
+                            input,
+                            weight,
+                            runningMean,
+                            runningVar,
+                            saveMean,
+                            saveInvstd,
+                            training,
+                            eps,
+                            gradMask,
+                            gradInput,
+                            gradWeight,
+                            gradBias);
     return diopiSuccess;
 }
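
Both entry points now funnel through DIOPI_ASCEND_CALL_ACLNN. The CANN aclnn API follows a two-phase convention: an aclnnXxxGetWorkspaceSize call that sizes a scratch buffer and builds an op executor, then the launch itself. The sketch below shows that convention for the forward op; it assumes the aclTensor arguments have already been created from the DIOPI handles, the header path follows the usual aclnnop layout, and the real adaptor's workspace and error handling may differ:

#include <acl/acl.h>
#include <aclnnop/aclnn_batch_norm.h>  // assumed conventional header location

// Sketch only: run batch_norm through the two-phase aclnn calling convention.
aclnnStatus launchBatchNorm(aclTensor* input, aclTensor* weight, aclTensor* bias,
                            aclTensor* runningMean, aclTensor* runningVar,
                            bool training, double momentum, double eps,
                            aclTensor* out, aclTensor* saveMean, aclTensor* saveInvstd,
                            aclrtStream stream) {
    // Phase 1: size the workspace and build an executor for this exact call.
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    aclnnStatus ret = aclnnBatchNormGetWorkspaceSize(input, weight, bias, runningMean, runningVar,
                                                     training, momentum, eps, out, saveMean,
                                                     saveInvstd, &workspaceSize, &executor);
    if (ret != ACL_SUCCESS) {
        return ret;
    }
    // Phase 2: allocate scratch memory if needed, then launch on the stream.
    void* workspace = nullptr;
    if (workspaceSize > 0) {
        aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    }
    ret = aclnnBatchNorm(workspace, workspaceSize, executor, stream);
    if (workspace != nullptr) {
        aclrtFree(workspace);
    }
    return ret;
}

For the backward pass, the std::array<bool, 3> gradMask presumably reaches aclnnBatchNormBackwardGetWorkspaceSize as an aclBoolArray (created with aclCreateBoolArray), which lets the kernel skip whichever of gradInput, gradWeight, and gradBias the caller passed as nullptr.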

impl/ascend_npu/ascend_config.yaml

Lines changed: 2 additions & 2 deletions

The two batch-norm entries move from the ascend_npu list to the ascend list, so these ops are now routed to the aclnn-based implementation above.

@@ -20,6 +20,8 @@ ascend:
 - diopiAtanInp
 - diopiBaddbmm
 - diopiBaddbmmInp
+- diopiBatchNorm
+- diopiBatchNormBackward
 - diopiBitwiseNot
 - diopiBitwiseNotInp
 - diopiBitwiseAnd
@@ -219,8 +221,6 @@ ascend:
 - diopiZeros
 ascend_npu:
 - diopiAdamW
-- diopiBatchNorm
-- diopiBatchNormBackward
 - diopiNonzero
 - diopiCat
 - diopiCopyInp
