From 97d35c7955096cead637270d7ef038026a00c431 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 20 Mar 2026 15:12:07 +0900 Subject: [PATCH] Use nb::ndarray for checking arrays --- python/src/indexing.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 7a7e8b1102..9521ad04ca 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -3,10 +3,11 @@ #include #include -#include "python/src/convert.h" -#include "python/src/indexing.h" +#include #include "mlx/ops.h" +#include "python/src/convert.h" +#include "python/src/indexing.h" bool is_none_slice(const nb::slice& in_slice) { return ( @@ -22,12 +23,8 @@ bool is_index_scalar(const nb::object& obj) { if (!PyIndex_Check(obj.ptr())) { return false; } - // Exclude multi-dimensional arrays (mx.array, np.ndarray) by checking ndim - if (nb::hasattr(obj, "ndim")) { - auto ndim = nb::getattr(obj, "ndim"); - if (nb::isinstance(ndim) && nb::cast(ndim) > 0) { - return false; - } + if (nb::ndarray_check(obj) && nb::cast>(obj).ndim() > 0) { + return false; } return true; }