4 | 4 |  * @copyright (c) 2023, DeepLink.
5 | 5 |  */
6 | 6 |
7 | | -#include "../common/acloprunner.hpp"
| 7 | +#include "../aclnn/acl_scalar.hpp"
| 8 | +#include "../aclnn/adaptor.hpp"
8 | 9 |
9 | 10 | namespace impl {
10 | 11 | namespace ascend {
11 | 12 |
12 | | -void updateInputAscendTensorDim(AscendTensor& inputAt, bool training) {
13 | | -    int64_t dim = inputAt.dim();
14 | | -    if (2 == dim) {
15 | | -        inputAt.unsqueeze(2);
16 | | -        inputAt.unsqueeze(3);
17 | | -    } else if (3 == dim) {
18 | | -        inputAt.unsqueeze(3);
19 | | -    } else if (5 == dim && !training) {
20 | | -        std::vector<int64_t> shape4d{inputAt.shape(0), inputAt.shape(1), inputAt.shape(2), inputAt.shape(3) * inputAt.shape(4)};
21 | | -        inputAt.view(shape4d);
22 | | -    }
23 | | -}
24 | | - |
25 | | -void batchNormBackwardTrainingUpdate(diopiContextHandle_t ctx, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias, AscendTensor gradOutputAt,
26 | | -                                     AscendTensor inputAt, diopiConstTensorHandle_t saveMean, diopiConstTensorHandle_t saveInvstd, double eps) {
27 | | -    std::string name = (inputAt.dim() == 5) ? "BN3DTrainingUpdateGrad" : "BNTrainingUpdateGrad";
28 | | -    AclOpRunner<4, 2>(name, ctx)
29 | | -        .addInput(gradOutputAt)
30 | | -        .addInput(inputAt)
31 | | -        .addInput(saveMean)
32 | | -        .addInput(saveInvstd)
33 | | -        .addOutput(gradWeight)
34 | | -        .addOutput(gradBias)
35 | | -        .setAttr<float>("epsilon", static_cast<float>(eps))
36 | | -        .run();
37 | | -}
38 | | - |
39 | | -void batchNormBackwardTrainingReduceNocheck(diopiContextHandle_t ctx, AscendTensor gradInputAt, diopiConstTensorHandle_t gradWeight,
40 | | -                                            diopiConstTensorHandle_t gradBias, AscendTensor gradOutputAt, AscendTensor inputAt, diopiConstTensorHandle_t weight,
41 | | -                                            diopiConstTensorHandle_t saveMean, diopiConstTensorHandle_t saveInvstd, double eps) {
42 | | -    std::string name = (inputAt.dim() == 5) ? "BN3DTrainingReduceGrad" : "BNTrainingReduceGrad";
43 | | -    AclOpRunner<7, 1>(name, ctx)
44 | | -        .addInput(gradOutputAt)
45 | | -        .addInput(inputAt)
46 | | -        .addInput(gradWeight)
47 | | -        .addInput(gradBias)
48 | | -        .addInput(weight)
49 | | -        .addInput(saveMean)
50 | | -        .addInput(saveInvstd)
51 | | -        .addOutput(gradInputAt)
52 | | -        .setAttr<float>("epsilon", static_cast<float>(eps))
53 | | -        .run();
54 | | -}
55 | | - |
56 | 13 | diopiError_t diopiBatchNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t saveMean, diopiTensorHandle_t saveInvstd,
57 | 14 |                             diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, diopiTensorHandle_t runningMean,
58 | 15 |                             diopiTensorHandle_t runningVar, bool training, double momentum, double eps) {
59 | | -    AscendTensor inputAt(input), outputAt(out);
60 | | -    updateInputAscendTensorDim(inputAt, training);
61 | | -    outputAt.view(inputAt.getAclMemShape());
62 | | -
63 | | -    std::vector<int64_t> batchShapeV{inputAt.shape(1)};
64 | | -    diopiSize_t batchShapeSizeT{batchShapeV.data(), static_cast<int64_t>(batchShapeV.size())};
65 | | -    diopiTensorHandle_t weightTemp = createTensorIfNullptrOrConstCast(ctx, weight, batchShapeSizeT, inputAt.dtype(), true, 1);
66 | | -    diopiTensorHandle_t biasTemp = createTensorIfNullptrOrConstCast(ctx, bias, batchShapeSizeT, inputAt.dtype(), true, 0);
67 | | -    diopiTensorHandle_t runningMeanTemp = createTensorIfNullptrOrConstCast(ctx, runningMean, batchShapeSizeT, inputAt.dtype(), true, 0);
68 | | -    diopiTensorHandle_t runningVarTemp = createTensorIfNullptrOrConstCast(ctx, runningVar, batchShapeSizeT, inputAt.dtype(), true, 1);
69 | | -
70 | | -    if (!training) {
71 | | -        AclOpRunner<5, 1>("BNInfer", ctx)
72 | | -            .addInput(inputAt)
73 | | -            .addInput(weightTemp)
74 | | -            .addInput(biasTemp)
75 | | -            .addInput(runningMeanTemp)
76 | | -            .addInput(runningVarTemp)
77 | | -            .addOutput(outputAt)
78 | | -            .setAttr("epsilon", static_cast<float>(eps))
79 | | -            .run();
80 | | -
81 | | -        diopiTensorHandle_t runningVarBroadcasted;
82 | | -        makeTensorLike(ctx, &runningVarBroadcasted, input);
83 | | -        AscendTensor runningVarAt(runningVar);
84 | | -        runningVarAt.unsqueeze(0);
85 | | -        runningVarAt.unsqueeze(2);
86 | | -        runningVarAt.unsqueeze(3);
87 | | -        AclOpRunner<2, 1>("BroadcastTo", ctx).addInput(runningVarAt).addConstInput(inputAt.shape()).addOutput(runningVarBroadcasted).run();
88 | | -    } else {
89 | | -        diopiTensorHandle_t sum = nullptr, squareSum = nullptr;
90 | | -        diopiSize_t shape, stride;
91 | | -        diopiGetTensorShape(runningMeanTemp, &shape);
92 | | -        diopiGetTensorStride(runningMeanTemp, &stride);
93 | | -        diopiRequireTensor(ctx, &sum, &shape, &stride, diopiDtype_t::diopi_dtype_float32, diopi_device);
94 | | -        diopiRequireTensor(ctx, &squareSum, &shape, &stride, diopiDtype_t::diopi_dtype_float32, diopi_device);
95 | | -        AclOpRunner<1, 2>("BNTrainingReduce", ctx).addInput(inputAt).addOutput(sum).setAttr("epsilon", static_cast<float>(eps)).addOutput(squareSum).run();
96 | | -        AclOpRunner<7, 5>("BNTrainingUpdate", ctx)
97 | | -            .addInput(inputAt)
98 | | -            .addInput(sum)
99 | | -            .addInput(squareSum)
100 | | -            .addInput(weightTemp)
101 | | -            .addInput(biasTemp)
102 | | -            .addInput(runningMeanTemp)
103 | | -            .addInput(runningVarTemp)
104 | | -            .setAttr("epsilon", static_cast<float>(eps))
105 | | -            .setAttr("factor", static_cast<float>(momentum))
106 | | -            .addOutput(outputAt)
107 | | -            .addOutput(runningMeanTemp)
108 | | -            .addOutput(runningVarTemp)
109 | | -            .addOutput(saveMean)
110 | | -            .addOutput(saveInvstd)
111 | | -            .run();
112 | | -    }
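| | +    // One adaptor call replaces the hand-built BNInfer (eval) and BNTrainingReduce/BNTrainingUpdate (train) runner paths above.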
| 16 | +    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNorm, ctx, input, weight, bias, runningMean, runningVar, training, momentum, eps, out, saveMean, saveInvstd);
113 | 17 |     return diopiSuccess;
114 | 18 | }
115 | 19 |
116 | 20 | diopiError_t diopiBatchNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
117 | 21 |                                     diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
118 | | -                                    diopiConstTensorHandle_t runninMean, diopiConstTensorHandle_t runningVar, diopiConstTensorHandle_t saveMean,
| 22 | +                                    diopiConstTensorHandle_t runningMean, diopiConstTensorHandle_t runningVar, diopiConstTensorHandle_t saveMean,
119 | 23 |                                     diopiConstTensorHandle_t saveInvstd, bool training, double eps) {
120 | | -    AscendTensor inputAt(input), gradOutputAt(gradOutput), gradInputAt(gradInput);
121 | | -    updateInputAscendTensorDim(inputAt, training);
122 | | -    gradOutputAt.view(inputAt.getAclMemShape());
123 | | -    gradInputAt.view(inputAt.getAclMemShape());
124 | | -
125 | | -    if (!training) {
126 | | -        batchNormBackwardTrainingUpdate(ctx, gradWeight, gradBias, gradOutputAt, inputAt, runninMean, runningVar, eps);
127 | | -
128 | | -        AclOpRunner<3, 1>("BNInferGrad", ctx)
129 | | -            .addInput(gradOutputAt)
130 | | -            .addInput(weight)
131 | | -            .addInput(runningVar)
132 | | -            .addOutput(gradInputAt)
133 | | -            .setAttr<float>("epsilon", static_cast<float>(eps))
134 | | -            .run();
135 | | -
136 | | -        diopiTensorHandle_t runningVarBroadcasted;
137 | | -        makeTensorLike(ctx, &runningVarBroadcasted, input);
138 | | -        AscendTensor runningVarAt(runningVar);
139 | | -        runningVarAt.unsqueeze(0);
140 | | -        runningVarAt.unsqueeze(2);
141 | | -        runningVarAt.unsqueeze(3);
142 | | -        AclOpRunner<2, 1>("BroadcastTo", ctx).addInput(runningVarAt).addConstInput(inputAt.shape()).addOutput(runningVarBroadcasted).run();
143 | | -    } else {
144 | | -        batchNormBackwardTrainingUpdate(ctx, gradWeight, gradBias, gradOutputAt, inputAt, saveMean, saveInvstd, eps);
145 | | -        batchNormBackwardTrainingReduceNocheck(ctx, gradInputAt, gradWeight, gradBias, gradOutputAt, inputAt, weight, saveMean, saveInvstd, eps);
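| | +    // aclnnBatchNormBackward takes an output mask; clear the slot for any gradient the caller did not request (nullptr handle).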
| 24 | +    std::array<bool, 3> gradMask = {true, true, true};
| 25 | +    if (nullptr == gradInput) {
| 26 | +        gradMask[0] = false;
| 27 | +    }
| 28 | +    if (nullptr == gradWeight) {
| 29 | +        gradMask[1] = false;
| 30 | +    }
| 31 | +    if (nullptr == gradBias) {
| 32 | +        gradMask[2] = false;
| 33 | +    }
146 | 33 |     }
| 34 | +    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNormBackward,
| 35 | +                            ctx,
| 36 | +                            gradOutput,
| 37 | +                            input,
| 38 | +                            weight,
| 39 | +                            runningMean,
| 40 | +                            runningVar,
| 41 | +                            saveMean,
| 42 | +                            saveInvstd,
| 43 | +                            training,
| 44 | +                            eps,
| 45 | +                            gradMask,
| 46 | +                            gradInput,
| 47 | +                            gradWeight,
| 48 | +                            gradBias);
147 | 49 |     return diopiSuccess;
148 | 50 | }
149 | 51 |
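Note: both entry points now route through the DIOPI_ASCEND_CALL_ACLNN adaptor macro instead of hand-built AclOpRunner chains. As a rough sketch, the macro wraps the standard two-phase aclnn calling convention shown below; the aclTensor variable names and the omitted error handling are illustrative assumptions, not the adaptor's actual internals.

// Sketch only (assumed shape of the aclnn two-phase convention; names like
// inputAcl and stream are placeholders, not the adaptor's real internals).
uint64_t workspaceSize = 0;
aclOpExecutor* executor = nullptr;
// Phase 1: query the workspace size and build an executor from aclTensor handles.
aclnnBatchNormGetWorkspaceSize(inputAcl, weightAcl, biasAcl, runningMeanAcl, runningVarAcl,
                               training, momentum, eps, outAcl, saveMeanAcl, saveInvstdAcl,
                               &workspaceSize, &executor);
void* workspace = nullptr;
if (workspaceSize > 0) {
    aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
}
// Phase 2: launch the kernel on the device stream associated with the context.
aclnnBatchNorm(workspace, workspaceSize, executor, stream);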