Skip to content

Commit 17d721b

Browse files
authored
[ascend]Zq/update clamp by aclnn (#1194)
1 parent 608abcb commit 17d721b

File tree

2 files changed

+133
-152
lines changed

2 files changed

+133
-152
lines changed

impl/ascend/functions/clamp.cpp

Lines changed: 121 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -4,207 +4,188 @@
44
* @copyright (c) 2023, DeepLink.
55
*/
66

7-
#include <cfloat>
8-
#include <climits>
9-
#include <limits>
10-
#include <map>
11-
#include <string>
12-
7+
#include "../aclnn/acl_scalar.hpp"
8+
#include "../aclnn/adaptor.hpp"
139
#include "../common/acloprunner.hpp"
10+
#include "../common/utils.hpp"
1411

1512
namespace impl {
1613
namespace ascend {
1714

18-
// to get the limit value according to diopiDtype
19-
std::pair<double, double> getFloatMinMaxFromDtype(diopiDtype_t tensorDtype) {
20-
switch (tensorDtype) {
21-
case diopi_dtype_float16:
22-
return std::make_pair(std::numeric_limits<half_float::half>::lowest(), std::numeric_limits<half_float::half>::max());
23-
case diopi_dtype_float32:
24-
return std::make_pair(std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
25-
case diopi_dtype_float64:
26-
return std::make_pair(std::numeric_limits<double>::lowest(), std::numeric_limits<double>::max());
27-
default:
28-
break;
29-
}
30-
}
31-
32-
std::pair<int64_t, int64_t> getIntMinMaxFromDtype(diopiDtype_t tensorDtype) {
33-
switch (tensorDtype) {
34-
case diopi_dtype_int8:
35-
return std::make_pair(std::numeric_limits<int8_t>::lowest(), std::numeric_limits<int8_t>::max());
36-
case diopi_dtype_uint8:
37-
return std::make_pair(std::numeric_limits<uint8_t>::lowest(), std::numeric_limits<uint8_t>::max());
38-
case diopi_dtype_int16:
39-
return std::make_pair(std::numeric_limits<int16_t>::lowest(), std::numeric_limits<int16_t>::max());
40-
case diopi_dtype_uint16:
41-
return std::make_pair(std::numeric_limits<uint16_t>::lowest(), std::numeric_limits<uint16_t>::max());
42-
case diopi_dtype_int32:
43-
return std::make_pair(std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
44-
case diopi_dtype_uint32:
45-
return std::make_pair(std::numeric_limits<uint32_t>::lowest(), std::numeric_limits<uint32_t>::max());
46-
case diopi_dtype_int64:
47-
return std::make_pair(std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max());
48-
case diopi_dtype_uint64:
49-
return std::make_pair(std::numeric_limits<uint64_t>::lowest(), std::numeric_limits<uint64_t>::max());
50-
case diopi_dtype_bool:
51-
return std::make_pair(std::numeric_limits<bool>::lowest(), std::numeric_limits<bool>::max());
52-
default:
53-
break;
54-
}
55-
}
56-
5715
diopiError_t diopiClamp(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t min,
5816
diopiConstTensorHandle_t max) {
59-
diopiDtype_t outDtype, inputDtype;
60-
diopiTensorHandle_t minTmp, maxTmp, boolOut;
61-
diopiScalar_t minScalar, maxScalar;
62-
6317
AscendTensor inputAt(input);
6418
AscendTensor outAt(out);
65-
const std::vector<int64_t>& sizes = inputAt.shape();
66-
inputDtype = inputAt.dtype();
67-
outDtype = outAt.dtype();
6819

69-
if (min != nullptr) {
70-
makeTensorLike(ctx, &minTmp, input, outDtype);
71-
broadcast(ctx, minTmp, min, sizes);
72-
} else {
73-
makeTensorLike(ctx, &minTmp, input, outDtype);
74-
if (isFloatingType(outDtype)) {
75-
double minVal = getFloatMinMaxFromDtype(outDtype).first;
76-
minScalar = constructDiopiScalarT(outDtype, minVal);
77-
} else {
78-
int64_t minVal = getIntMinMaxFromDtype(outDtype).first;
79-
minScalar = constructDiopiScalarT(outDtype, minVal);
80-
}
81-
diopiFill(ctx, minTmp, &minScalar);
20+
if (input == nullptr || inputAt.numel() == 0) {
21+
return diopiSuccess;
8222
}
8323

84-
if (max != nullptr) {
85-
makeTensorLike(ctx, &maxTmp, input, outDtype);
86-
broadcast(ctx, maxTmp, max, sizes);
24+
castTensor(ctx, inputAt, outAt.dtype());
25+
26+
if (min != nullptr && max != nullptr) {
27+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampTensor, ctx, inputAt, min, max, outAt);
8728
} else {
88-
makeTensorLike(ctx, &maxTmp, input, outDtype);
89-
if (isFloatingType(outDtype)) {
90-
double maxVal = getFloatMinMaxFromDtype(outDtype).second;
91-
maxScalar = constructDiopiScalarT(outDtype, maxVal);
29+
if (max != nullptr) {
30+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMaxTensor, ctx, inputAt, max, outAt);
9231
} else {
93-
int64_t maxVal = getIntMinMaxFromDtype(outDtype).second;
94-
maxScalar = constructDiopiScalarT(outDtype, maxVal);
32+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMinTensor, ctx, inputAt, min, outAt);
9533
}
96-
diopiFill(ctx, maxTmp, &maxScalar);
9734
}
9835

99-
// Perform a clamp operation according PyTorch's special handling of the case when max is less than min.
100-
// In this case, update the value of min to be equal to max to ensure correct behavior.
101-
makeTensorLike(ctx, &boolOut, input, diopi_dtype_bool);
102-
diopiLt(ctx, boolOut, maxTmp, minTmp);
103-
diopiMaskedFill(ctx, minTmp, minTmp, boolOut, maxTmp);
104-
105-
AclOpRunner<3, 1> runner("ClipByValue", ctx);
106-
runner.addInput(input, outDtype).addInput(minTmp, outDtype).addInput(maxTmp, outDtype).addOutput(out).run();
10736
return diopiSuccess;
10837
}
10938

11039
diopiError_t diopiClampScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* minPtr,
11140
const diopiScalar_t* maxPtr) {
11241
AscendTensor inputAt(input);
11342
AscendTensor outAt(out);
114-
diopiDtype_t inputDtype, outDtype;
115-
diopiGetTensorDtype(input, &inputDtype);
116-
diopiGetTensorDtype(out, &outDtype);
117-
diopiScalar_t min, max;
118-
double minVal, maxVal;
119-
120-
if (minPtr != nullptr) {
121-
min = *minPtr;
122-
if (isFloatingType(min.stype)) {
123-
minVal = min.fval;
124-
} else {
125-
minVal = min.ival;
126-
}
127-
} else {
128-
if (isFloatingType(outDtype)) {
129-
double minLimitVal = getFloatMinMaxFromDtype(outDtype).first;
130-
min = constructDiopiScalarT(outDtype, minLimitVal);
131-
minVal = minLimitVal;
132-
} else {
133-
int64_t minLimitVal = getIntMinMaxFromDtype(outDtype).first;
134-
min = constructDiopiScalarT(outDtype, minLimitVal);
135-
minVal = minLimitVal;
136-
}
43+
44+
if (input == nullptr || inputAt.numel() == 0) {
45+
return diopiSuccess;
13746
}
47+
castTensor(ctx, inputAt, outAt.dtype());
13848

139-
if (maxPtr != nullptr) {
140-
max = *maxPtr;
141-
if (isFloatingType(max.stype)) {
142-
maxVal = max.fval;
143-
} else {
144-
maxVal = max.ival;
145-
}
49+
if (minPtr != nullptr && maxPtr != nullptr) {
50+
DIOPI_ASCEND_CALL_ACLNN(aclnnClamp, ctx, inputAt, minPtr, maxPtr, outAt);
14651
} else {
147-
if (isFloatingType(outDtype)) {
148-
double maxLimitVal = getFloatMinMaxFromDtype(outDtype).second;
149-
max = constructDiopiScalarT(outDtype, maxLimitVal);
150-
maxVal = maxLimitVal;
151-
} else {
152-
int64_t maxLimitVal = getIntMinMaxFromDtype(outDtype).second;
153-
max = constructDiopiScalarT(outDtype, maxLimitVal);
154-
maxVal = maxLimitVal;
52+
if (minPtr != nullptr) {
53+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMin, ctx, inputAt, minPtr, outAt);
54+
}
55+
if (maxPtr != nullptr) {
56+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMax, ctx, inputAt, maxPtr, outAt);
15557
}
15658
}
15759

158-
// Perform a clamp operation according PyTorch's special handling of the case when max is less than min.
159-
// In this case, update the value of min to be equal to max to ensure correct behavior.
160-
if (maxVal < minVal) {
161-
min = constructDiopiScalarT(outDtype, maxVal);
162-
}
163-
164-
AclOpRunner<3, 1> runner("ClipByValue", ctx);
165-
runner.addInput(input, outDtype).addConstInput(min, outDtype).addConstInput(max, outDtype).addOutput(out).run();
16660
return diopiSuccess;
16761
}
16862

16963
diopiError_t diopiClampInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t min, diopiConstTensorHandle_t max) {
170-
return diopiClamp(ctx, input, input, min, max);
64+
AscendTensor inputAt(input);
65+
if (input == nullptr || inputAt.numel() == 0) {
66+
return diopiSuccess;
67+
}
68+
69+
if (min != nullptr && max != nullptr) {
70+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampTensor, ctx, input, min, max, input);
71+
} else {
72+
if (max != nullptr) {
73+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceClampMaxTensor, ctx, input, max);
74+
} else {
75+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceClampMinTensor, ctx, input, min);
76+
}
77+
}
78+
79+
return diopiSuccess;
17180
}
17281

17382
diopiError_t diopiClampInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* min, const diopiScalar_t* max) {
174-
return diopiClampScalar(ctx, input, input, min, max);
83+
AscendTensor inputAt(input);
84+
if (input == nullptr || inputAt.numel() == 0) {
85+
return diopiSuccess;
86+
}
87+
88+
if (min != nullptr && max != nullptr) {
89+
DIOPI_ASCEND_CALL_ACLNN(aclnnClamp, ctx, input, min, max, input);
90+
} else {
91+
if (max != nullptr) {
92+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMax, ctx, input, max, input);
93+
} else {
94+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMin, ctx, input, min, input);
95+
}
96+
}
97+
return diopiSuccess;
17598
}
17699

177100
DIOPI_API diopiError_t diopiClampMinInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* min) {
178-
return diopiClampMinScalar(ctx, input, input, min);
101+
AscendTensor inputAt(input);
102+
if (input == nullptr || inputAt.numel() == 0) {
103+
return diopiSuccess;
104+
}
105+
106+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMin, ctx, input, min, input);
107+
return diopiSuccess;
179108
}
180109

181110
DIOPI_API diopiError_t diopiClampMinInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t min) {
182-
return diopiClampMin(ctx, input, input, min);
111+
AscendTensor inputAt(input);
112+
if (input == nullptr || inputAt.numel() == 0) {
113+
return diopiSuccess;
114+
}
115+
116+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceClampMinTensor, ctx, input, min);
117+
return diopiSuccess;
183118
}
184119

185120
DIOPI_API diopiError_t diopiClampMinScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* min) {
186-
return diopiClampScalar(ctx, out, input, min, nullptr);
121+
AscendTensor inputAt(input);
122+
AscendTensor outAt(out);
123+
124+
if (input == nullptr || inputAt.numel() == 0) {
125+
return diopiSuccess;
126+
}
127+
128+
castTensor(ctx, inputAt, outAt.dtype());
129+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMin, ctx, inputAt, min, outAt);
130+
return diopiSuccess;
187131
}
188132

189133
DIOPI_API diopiError_t diopiClampMin(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t min) {
190-
return diopiClamp(ctx, out, input, min, nullptr);
134+
AscendTensor inputAt(input);
135+
AscendTensor outAt(out);
136+
137+
if (input == nullptr || inputAt.numel() == 0) {
138+
return diopiSuccess;
139+
}
140+
141+
castTensor(ctx, inputAt, outAt.dtype());
142+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMinTensor, ctx, inputAt, min, outAt);
143+
return diopiSuccess;
191144
}
192145

193146
DIOPI_API diopiError_t diopiClampMaxInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* max) {
194-
return diopiClampMaxScalar(ctx, input, input, max);
147+
AscendTensor inputAt(input);
148+
if (input == nullptr || inputAt.numel() == 0) {
149+
return diopiSuccess;
150+
}
151+
152+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceClampMax, ctx, input, max);
153+
return diopiSuccess;
195154
}
196155

197156
DIOPI_API diopiError_t diopiClampMaxInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t max) {
198-
return diopiClampMax(ctx, input, input, max);
157+
AscendTensor inputAt(input);
158+
if (input == nullptr || inputAt.numel() == 0) {
159+
return diopiSuccess;
160+
}
161+
162+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceClampMaxTensor, ctx, input, max);
163+
return diopiSuccess;
199164
}
200165

201166
DIOPI_API diopiError_t diopiClampMaxScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* max) {
202-
return diopiClampScalar(ctx, out, input, nullptr, max);
167+
AscendTensor inputAt(input);
168+
AscendTensor outAt(out);
169+
if (input == nullptr || inputAt.numel() == 0) {
170+
return diopiSuccess;
171+
}
172+
173+
castTensor(ctx, inputAt, outAt.dtype());
174+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMax, ctx, inputAt, max, outAt);
175+
return diopiSuccess;
203176
}
204177

205178
DIOPI_API diopiError_t diopiClampMax(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t max) {
206-
return diopiClamp(ctx, out, input, nullptr, max);
207-
}
179+
AscendTensor inputAt(input);
180+
AscendTensor outAt(out);
181+
182+
if (input == nullptr || inputAt.numel() == 0) {
183+
return diopiSuccess;
184+
}
208185

186+
castTensor(ctx, inputAt, outAt.dtype());
187+
DIOPI_ASCEND_CALL_ACLNN(aclnnClampMaxTensor, ctx, inputAt, max, outAt);
188+
return diopiSuccess;
189+
}
209190
} // namespace ascend
210191
} // namespace impl

impl/ascend_npu/ascend_config.yaml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ ascend:
2424
- diopiBitwiseOrScalar
2525
- diopiBitwiseOrInpScalar
2626
- diopiCastDtype
27+
- diopiClamp
28+
- diopiClampInp
29+
- diopiClampInpScalar
30+
- diopiClampMax
31+
- diopiClampMaxInp
32+
- diopiClampMaxInpScalar
33+
- diopiClampMaxScalar
34+
- diopiClampMin
35+
- diopiClampMinInp
36+
- diopiClampMinInpScalar
37+
- diopiClampMinScalar
38+
- diopiClampScalar
2739
- diopiCeil
2840
- diopiCeilInp
2941
- diopiCol2Im
@@ -244,18 +256,6 @@ ascend_npu:
244256
- diopiScatterInp
245257
- diopiScatterScalar
246258
- diopiScatterInpScalar
247-
- diopiClamp
248-
- diopiClampInp
249-
- diopiClampInpScalar
250-
- diopiClampMax
251-
- diopiClampMaxInp
252-
- diopiClampMaxInpScalar
253-
- diopiClampMaxScalar
254-
- diopiClampMin
255-
- diopiClampMinInp
256-
- diopiClampMinInpScalar
257-
- diopiClampMinScalar
258-
- diopiClampScalar
259259
- diopiUpsampleLinear
260260
- diopiUpsampleLinearBackward
261261
- diopiUpsampleNearest

0 commit comments

Comments
 (0)