diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e82b51206e..a264ade0b7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1787,6 +1787,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); + GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo); if (order == GGML_SORT_ORDER_ASC) { stream->submit([&](sycl::handler &cgh) { @@ -4348,6 +4349,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_ } static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_sycl_device_context *sycl_ctx = + (ggml_backend_sycl_device_context *)dev->context; + int device = sycl_ctx->device; switch (op->op) { case GGML_OP_CONV_TRANSPOSE_1D: { @@ -4601,8 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: - case GGML_OP_ARGSORT: return ggml_is_contiguous(op->src[0]); + case GGML_OP_ARGSORT: + return op->src[0]->ne[0] * sizeof(int) <= + ggml_sycl_info().devices[device].smpbo; case GGML_OP_POOL_2D: case GGML_OP_ACC: return true;