Skip to content

Commit 7b59b56

Browse files
[Ascend] fuj / fix-index-error-for-ascend (#1353)
* fix index error for ascend * fix range-based for loop in index * add unittest for index * Update diopi_configs.py * fix dim check for AscendTensor squeeze --------- Co-authored-by: zhangzefeng92 <zhang_zefeng@foxmail.com>
1 parent 36af33a commit 7b59b56

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4953,6 +4953,33 @@
49534953
),
49544954
),
49554955

4956+
'index_mask': dict(
4957+
name=["index"],
4958+
interface=["CustomizedTest"],
4959+
# input[:, mask]
4960+
tensor_para=dict(
4961+
args=[
4962+
{
4963+
"ins": ['input'],
4964+
"shape": ((3,4,5), (4,5,6,7), (5,6,7,8,9)),
4965+
"gen_fn": 'Genfunc.randn',
4966+
"dtype": [np.int16, np.int32, np.int64, np.uint8, np.int8, np.bool_, np.float16, np.float32, np.float64],
4967+
},
4968+
{
4969+
"ins": ['idx1'],
4970+
"shape": (None, None, None),
4971+
"dtype": [np.int64]
4972+
},
4973+
{
4974+
"ins": ['idx2'],
4975+
"shape": ((4,5), (5,6,7), (6,7,8,9)),
4976+
"gen_fn": 'Genfunc.mask',
4977+
"dtype": [np.bool_]
4978+
}
4979+
]
4980+
)
4981+
),
4982+
49564983
'sgd': dict(
49574984
name=["sgd"],
49584985
interface=["CustomizedTest"],

impl/ascend/ascend_tensor.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,20 @@ AscendTensor& AscendTensor::unsqueeze(int dim) {
199199
return *this;
200200
}
201201

202+
AscendTensor& AscendTensor::squeeze(int dim) {
203+
auto shape = this->shape();
204+
if (shape[dim] != 1) {
205+
return *this;
206+
}
207+
auto strides = this->stride();
208+
209+
shape.erase(shape.begin() + dim);
210+
strides.erase(strides.begin() + dim);
211+
212+
this->asStrided(shape, strides);
213+
return *this;
214+
}
215+
202216
AscendTensor& AscendTensor::view(const std::vector<int64_t>& shape) {
203217
// must be contiguous
204218
ASCEND_CHECK_ABORT(this->isContiguous(), "now only contiguous tensor support view by shape.");

impl/ascend/ascend_tensor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ class AscendTensor final {
244244
// Those methods may change the class attribute.
245245
AscendTensor& asStrided(const std::vector<int64_t>& shape, const std::vector<int64_t>& stride);
246246
AscendTensor& unsqueeze(int dim);
247+
AscendTensor& squeeze(int dim);
247248
AscendTensor& view(const std::vector<int64_t>& shape);
248249
AscendTensor& resize(const std::vector<int64_t>& shape);
249250
AscendTensor& select(int64_t dim, int64_t index);

impl/ascend/functions/index.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ static AscendTensor nonZeroTensor(diopiContextHandle_t ctx, const AscendTensor&
6262

6363
auto aclNZTensor = ::aclCreateTensor(
6464
nShape.data(), nShape.size(), aclDataType::ACL_INT64, nStride.data(), 0, aclFormat::ACL_FORMAT_ND, &numELem, 1, const_cast<void*>(nzTensor.data()));
65-
DIOPI_ASCEND_CALL_ACLNN(aclnnNonzero, ctx, self, aclNZTensor);
65+
DIOPI_ASCEND_CALL_ACLNN_SYNC(aclnnNonzero, ctx, self, aclNZTensor);
6666

6767
int64_t* vDims = nullptr;
6868
uint64_t vDimsNum = 0;
@@ -106,8 +106,23 @@ static std::vector<AscendTensor> expandIndicesTensors(diopiContextHandle_t ctx,
106106
srcIdx);
107107
}
108108
AscendTensor non = nonZeroTensor(ctx, t);
109-
for (int64_t j = 0; j < t.dim(); j++) {
110-
result.push_back(non.select(0, j));
109+
110+
auto shape = non.shape();
111+
shape[0] = 1;
112+
diopiSize_t size = vectorToDiopiSize(shape);
113+
std::vector<diopiTensorHandle_t> nons;
114+
115+
for (int i = 0; i < non.shape(0); i++) {
116+
diopiTensorHandle_t tmp = nullptr;
117+
diopiRequireTensor(ctx, &tmp, &size, nullptr, diopi_dtype_int64, diopi_device);
118+
nons.push_back(tmp);
119+
}
120+
std::vector<int64_t> splitSize(non.shape(0), 1);
121+
diopiSize_t splitSizeDiopi = vectorToDiopiSize(splitSize);
122+
DIOPI_ASCEND_CALL_ACLNN(aclnnSplitWithSize, ctx, non, splitSizeDiopi, 0, nons);
123+
for (const auto nj : nons) {
124+
AscendTensor njTensor(nj);
125+
result.push_back(njTensor.squeeze(0));
111126
}
112127
} else {
113128
result.push_back(t);

0 commit comments

Comments
 (0)