Skip to content

Commit 608abcb

Browse files
authored
[Ascend] fuj/aclnn-replace (#1195)
* fix maximum and minimum
1 parent a536285 commit 608abcb

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

impl/ascend/functions/binary.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,21 @@ diopiError_t diopiDivInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t inp
9797
}
9898

9999
diopiError_t diopiMaximum(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other) {
100+
AscendTensor outAt(out);
101+
if (outAt.numel() == 0) {
102+
return diopiSuccess;
103+
}
104+
100105
DIOPI_ASCEND_CALL_ACLNN(aclnnMaximum, ctx, input, other, out);
101106
return diopiSuccess;
102107
}
103108

104109
diopiError_t diopiMinimum(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other) {
110+
AscendTensor outAt(out);
111+
if (outAt.numel() == 0) {
112+
return diopiSuccess;
113+
}
114+
105115
DIOPI_ASCEND_CALL_ACLNN(aclnnMinimum, ctx, input, other, out);
106116
return diopiSuccess;
107117
}

impl/ascend/functions/minmax.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,40 @@
44
* @copyright (c) 2023, DeepLink.
55
*/
66

7-
#include <numeric>
8-
9-
#include "../common/acloprunner.hpp"
7+
#include "../aclnn/acl_scalar.hpp"
8+
#include "../aclnn/adaptor.hpp"
109

1110
namespace impl {
1211
namespace ascend {
1312
diopiError_t diopiMax(diopiContextHandle_t ctx, diopiTensorHandle_t max, diopiTensorHandle_t maxIndices, diopiConstTensorHandle_t input, int64_t dim) {
14-
AclOpRunner<1, 2>("ArgMaxWithValue", ctx).setAttr<int>("dimension", static_cast<int>(dim)).addInput(input).addOutput(maxIndices).addOutput(max).run();
13+
AscendTensor inAt(input);
14+
AscendTensor maxAt(max);
15+
bool keepdim = false;
16+
if (inAt.dim() == maxAt.dim()) {
17+
keepdim = true;
18+
}
19+
DIOPI_ASCEND_CALL_ACLNN(aclnnMaxDim, ctx, input, dim, keepdim, max, maxIndices);
1520
return diopiSuccess;
1621
}
1722

1823
diopiError_t diopiMaxAll(diopiContextHandle_t ctx, diopiTensorHandle_t max, diopiConstTensorHandle_t input) {
19-
diopiSize_t inS;
20-
diopiGetTensorShape(input, &inS);
21-
std::vector<int64_t> dimAllVector(inS.len);
22-
std::iota(std::begin(dimAllVector), std::end(dimAllVector), 0);
23-
diopiSize_t dimAll = vectorToDiopiSize(dimAllVector);
24-
AclOpRunner<2, 1>("ReduceMax", ctx).addInput(input).addConstInput(dimAll).addOutput(max).run();
24+
DIOPI_ASCEND_CALL_ACLNN(aclnnMax, ctx, input, max);
2525
return diopiSuccess;
2626
}
2727

2828
diopiError_t diopiMin(diopiContextHandle_t ctx, diopiTensorHandle_t min, diopiTensorHandle_t minIndices, diopiConstTensorHandle_t input, int64_t dim) {
29-
AclOpRunner<1, 2>("ArgMinWithValue", ctx).setAttr<int>("dimension", static_cast<int>(dim)).addInput(input).addOutput(minIndices).addOutput(min).run();
29+
AscendTensor inAt(input);
30+
AscendTensor minAt(min);
31+
bool keepdim = false;
32+
if (inAt.dim() == minAt.dim()) {
33+
keepdim = true;
34+
}
35+
DIOPI_ASCEND_CALL_ACLNN(aclnnMinDim, ctx, input, dim, keepdim, min, minIndices);
3036
return diopiSuccess;
3137
}
3238

3339
diopiError_t diopiMinAll(diopiContextHandle_t ctx, diopiTensorHandle_t min, diopiConstTensorHandle_t input) {
34-
diopiSize_t inS;
35-
diopiGetTensorShape(input, &inS);
36-
std::vector<int64_t> dimAllVector(inS.len);
37-
std::iota(std::begin(dimAllVector), std::end(dimAllVector), 0);
38-
diopiSize_t dimAll = vectorToDiopiSize(dimAllVector);
39-
AclOpRunner<2, 1>("ReduceMin", ctx).addInput(input).addConstInput(dimAll).addOutput(min).run();
40+
DIOPI_ASCEND_CALL_ACLNN(aclnnMin, ctx, input, min);
4041
return diopiSuccess;
4142
}
4243

impl/ascend_npu/ascend_config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,19 @@ ascend:
9898
- diopiLtInp
9999
- diopiLtInpScalar
100100
- diopiLtScalar
101+
- diopiMax
102+
- diopiMaxAll
103+
- diopiMaximum
101104
- diopiMean
102105
- diopiMSELoss
103106
- diopiMSELossBackward
104107
- diopiMul
105108
- diopiMulInp
106109
- diopiMulInpScalar
107110
- diopiMulScalar
111+
- diopiMin
112+
- diopiMinAll
113+
- diopiMinimum
108114
- diopiNe
109115
- diopiNeInp
110116
- diopiNeInpScalar

0 commit comments

Comments
 (0)