Skip to content

Commit a45c104

Browse files
authored
[ascend] tyf/fixMaskedSelect (#1206)
1 parent 17d721b commit a45c104

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

impl/camb/functions/masked_select.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@
66
namespace impl {
77
namespace camb {
88

9-
diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mask) {
9+
diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mask) {
1010
cnnlHandle_t handle = cnnlHandlePool.get(ctx);
1111
DiopiTensor inputTensor(input);
1212
DiopiTensor maskTensor(mask);
1313

14-
std::vector<DiopiTensor *> pmask{&maskTensor};
14+
DIOPI_CALL(contiguous(ctx, inputTensor));
15+
DIOPI_CALL(contiguous(ctx, maskTensor));
16+
17+
std::vector<DiopiTensor*> pmask{&maskTensor};
1518
std::set<diopiDtype_t> maskDtypes{diopi_dtype_bool};
1619
DIOPI_CALL(autoCastTensorType(ctx, pmask, maskDtypes));
1720
// When the data type of masked tensor is not bool, the data type of input
1821
// tensor must be the same as the data type of the masked tensor.
1922

20-
std::vector<DiopiTensor *> pinput{&inputTensor};
23+
std::vector<DiopiTensor*> pinput{&inputTensor};
2124
std::set<diopiDtype_t> inputDtypes{
2225
diopi_dtype_bool, diopi_dtype_int8, diopi_dtype_uint8, diopi_dtype_int16, diopi_dtype_int32, diopi_dtype_float16, diopi_dtype_float32};
2326
DIOPI_CALL(autoCastTensorType(ctx, pinput, inputDtypes));
@@ -32,7 +35,7 @@ diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *ou
3235

3336
size_t workspaceSize = 0;
3437
DIOPI_CALL_CNNL(cnnlGetMaskedWorkspaceSize(handle, maskMode, inputDesc.get(), maskDesc.get(), nullptr, outDesc.get(), &workspaceSize));
35-
void *workspace = nullptr;
38+
void* workspace = nullptr;
3639
if (0 != workspaceSize) {
3740
workspace = requiresBuffer(ctx, workspaceSize).data();
3841
}
@@ -55,7 +58,7 @@ diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *ou
5558
workspaceSize,
5659
outDesc.get(),
5760
tempOutputTensor.data(),
58-
reinterpret_cast<uint32_t *>(numTrue.data())));
61+
reinterpret_cast<uint32_t*>(numTrue.data())));
5962
#else
6063
DIOPI_CALL_CNNL(cnnlMasked_v3(handle,
6164
maskMode,
@@ -69,7 +72,7 @@ diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *ou
6972
workspaceSize,
7073
outDesc.get(),
7174
tempOutputTensor.data(),
72-
reinterpret_cast<uint32_t *>(numTrue.data())));
75+
reinterpret_cast<uint32_t*>(numTrue.data())));
7376
#endif
7477
syncStreamInCtx(ctx);
7578
uint32_t numTrueHost = 0;
@@ -96,11 +99,11 @@ DIOPI_API diopiError_t diopiMaskedSelectBackward(diopiContextHandle_t ctx, diopi
9699
return diopiSuccess;
97100
}
98101

99-
std::vector<DiopiTensor *> pmask{&maskTensor};
102+
std::vector<DiopiTensor*> pmask{&maskTensor};
100103
std::set<diopiDtype_t> maskDtypes{diopi_dtype_bool};
101104
DIOPI_CALL(autoCastTensorType(ctx, pmask, maskDtypes));
102105

103-
std::vector<DiopiTensor *> pGradInput{&tempGradInputTensor, &gradOutputTensor};
106+
std::vector<DiopiTensor*> pGradInput{&tempGradInputTensor, &gradOutputTensor};
104107
std::set<diopiDtype_t> gradInputDtypes{diopi_dtype_float16, diopi_dtype_float32};
105108
DIOPI_CALL(autoCastTensorType(ctx, pGradInput, gradInputDtypes));
106109

@@ -111,7 +114,7 @@ DIOPI_API diopiError_t diopiMaskedSelectBackward(diopiContextHandle_t ctx, diopi
111114

112115
size_t workspaceSize = 0;
113116
DIOPI_CALL_CNNL(cnnlGetMaskedWorkspaceSize(handle, maskMode, gradInputDesc.get(), maskDesc.get(), gradOutDesc.get(), gradInputDesc.get(), &workspaceSize));
114-
void *workspace = nullptr;
117+
void* workspace = nullptr;
115118
if (0 != workspaceSize) {
116119
workspace = requiresBuffer(ctx, workspaceSize).data();
117120
}

0 commit comments

Comments
 (0)