Skip to content

Commit 98bd9ab

Browse files
enhance argsort for UT (#17573)
Co-authored-by: Neo Zhang <zhang.jianyu@outlook.com>
1 parent 746f9ee commit 98bd9ab

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,6 +1787,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
17871787
const sycl::range<3> block_dims(1, 1, nth);
17881788
const sycl::range<3> block_nums(1, nrows, 1);
17891789
const size_t shared_mem = ncols_pad * sizeof(int);
1790+
GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
17901791

17911792
if (order == GGML_SORT_ORDER_ASC) {
17921793
stream->submit([&](sycl::handler &cgh) {
@@ -4348,6 +4349,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
43484349
}
43494350

43504351
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4352+
ggml_backend_sycl_device_context *sycl_ctx =
4353+
(ggml_backend_sycl_device_context *)dev->context;
4354+
int device = sycl_ctx->device;
43514355
switch (op->op) {
43524356
case GGML_OP_CONV_TRANSPOSE_1D:
43534357
{
@@ -4601,8 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
46014605
case GGML_OP_SUM:
46024606
case GGML_OP_SUM_ROWS:
46034607
case GGML_OP_MEAN:
4604-
case GGML_OP_ARGSORT:
46054608
return ggml_is_contiguous(op->src[0]);
4609+
case GGML_OP_ARGSORT:
4610+
return op->src[0]->ne[0] * sizeof(int) <=
4611+
ggml_sycl_info().devices[device].smpbo;
46064612
case GGML_OP_POOL_2D:
46074613
case GGML_OP_ACC:
46084614
return true;

0 commit comments

Comments
 (0)