66namespace impl {
77namespace camb {
88
9- diopiError_t diopiMaskedSelect (diopiContextHandle_t ctx, diopiTensorHandle_t * out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mask) {
9+ diopiError_t diopiMaskedSelect (diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mask) {
1010 cnnlHandle_t handle = cnnlHandlePool.get (ctx);
1111 DiopiTensor inputTensor (input);
1212 DiopiTensor maskTensor (mask);
1313
14- std::vector<DiopiTensor *> pmask{&maskTensor};
14+ DIOPI_CALL (contiguous (ctx, inputTensor));
15+ DIOPI_CALL (contiguous (ctx, maskTensor));
16+
17+ std::vector<DiopiTensor*> pmask{&maskTensor};
1518 std::set<diopiDtype_t> maskDtypes{diopi_dtype_bool};
1619 DIOPI_CALL (autoCastTensorType (ctx, pmask, maskDtypes));
1720 // When the data type of the masked tensor is not bool, the data type of the input
1821 // tensor must be the same as the data type of the masked tensor.
1922
20- std::vector<DiopiTensor *> pinput{&inputTensor};
23+ std::vector<DiopiTensor*> pinput{&inputTensor};
2124 std::set<diopiDtype_t> inputDtypes{
2225 diopi_dtype_bool, diopi_dtype_int8, diopi_dtype_uint8, diopi_dtype_int16, diopi_dtype_int32, diopi_dtype_float16, diopi_dtype_float32};
2326 DIOPI_CALL (autoCastTensorType (ctx, pinput, inputDtypes));
@@ -32,7 +35,7 @@ diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *ou
3235
3336 size_t workspaceSize = 0 ;
3437 DIOPI_CALL_CNNL (cnnlGetMaskedWorkspaceSize (handle, maskMode, inputDesc.get (), maskDesc.get (), nullptr , outDesc.get (), &workspaceSize));
35- void * workspace = nullptr ;
38+ void * workspace = nullptr ;
3639 if (0 != workspaceSize) {
3740 workspace = requiresBuffer (ctx, workspaceSize).data ();
3841 }
@@ -55,7 +58,7 @@ diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *ou
5558 workspaceSize,
5659 outDesc.get (),
5760 tempOutputTensor.data (),
58- reinterpret_cast <uint32_t *>(numTrue.data ())));
61+ reinterpret_cast <uint32_t *>(numTrue.data ())));
5962#else
6063 DIOPI_CALL_CNNL (cnnlMasked_v3 (handle,
6164 maskMode,
@@ -69,7 +72,7 @@ diopiError_t diopiMaskedSelect(diopiContextHandle_t ctx, diopiTensorHandle_t *ou
6972 workspaceSize,
7073 outDesc.get (),
7174 tempOutputTensor.data (),
72- reinterpret_cast <uint32_t *>(numTrue.data ())));
75+ reinterpret_cast <uint32_t *>(numTrue.data ())));
7376#endif
7477 syncStreamInCtx (ctx);
7578 uint32_t numTrueHost = 0 ;
@@ -96,11 +99,11 @@ DIOPI_API diopiError_t diopiMaskedSelectBackward(diopiContextHandle_t ctx, diopi
9699 return diopiSuccess;
97100 }
98101
99- std::vector<DiopiTensor *> pmask{&maskTensor};
102+ std::vector<DiopiTensor*> pmask{&maskTensor};
100103 std::set<diopiDtype_t> maskDtypes{diopi_dtype_bool};
101104 DIOPI_CALL (autoCastTensorType (ctx, pmask, maskDtypes));
102105
103- std::vector<DiopiTensor *> pGradInput{&tempGradInputTensor, &gradOutputTensor};
106+ std::vector<DiopiTensor*> pGradInput{&tempGradInputTensor, &gradOutputTensor};
104107 std::set<diopiDtype_t> gradInputDtypes{diopi_dtype_float16, diopi_dtype_float32};
105108 DIOPI_CALL (autoCastTensorType (ctx, pGradInput, gradInputDtypes));
106109
@@ -111,7 +114,7 @@ DIOPI_API diopiError_t diopiMaskedSelectBackward(diopiContextHandle_t ctx, diopi
111114
112115 size_t workspaceSize = 0 ;
113116 DIOPI_CALL_CNNL (cnnlGetMaskedWorkspaceSize (handle, maskMode, gradInputDesc.get (), maskDesc.get (), gradOutDesc.get (), gradInputDesc.get (), &workspaceSize));
114- void * workspace = nullptr ;
117+ void * workspace = nullptr ;
115118 if (0 != workspaceSize) {
116119 workspace = requiresBuffer (ctx, workspaceSize).data ();
117120 }
0 commit comments