From 31250651447370d412a6edbe29a0b771b8ed6d97 Mon Sep 17 00:00:00 2001 From: AN Long Date: Fri, 6 Mar 2026 00:30:09 +0900 Subject: [PATCH 1/4] Support indexing with any type which implmented __index__ --- python/src/indexing.cpp | 61 ++++++++++++++++++++++---------------- python/tests/test_array.py | 8 +++++ 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 564c4cb45b..49bd2c94d8 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -15,8 +15,17 @@ bool is_none_slice(const nb::slice& in_slice) { nb::getattr(in_slice, "step").is_none()); } +bool is_index_scalar(const nb::object& obj) { + return !nb::isinstance(obj) && PyIndex_Check(obj.ptr()); +} + int safe_to_int32(nb::object obj) { - auto val = nb::cast(nb::cast(obj)); + auto idx = nb::steal(PyNumber_Index(obj.ptr())); + if (!idx.is_valid()) { + throw nb::python_error(); + } + + auto val = nb::cast(nb::cast(idx)); if (val > INT32_MAX || val < INT32_MIN) { throw std::invalid_argument("Slice indices must be 32-bit integers."); } @@ -25,7 +34,7 @@ int safe_to_int32(nb::object obj) { int get_slice_int(nb::object obj, int default_val) { if (!obj.is_none()) { - if (!nb::isinstance(obj)) { + if (!is_index_scalar(obj)) { throw std::invalid_argument("Slice indices must be integers or None."); } return safe_to_int32(obj); @@ -60,9 +69,9 @@ mx::array get_int_index(nb::object idx, int axis_size) { } bool is_valid_index_type(const nb::object& obj) { - return nb::isinstance(obj) || nb::isinstance(obj) || - nb::isinstance(obj) || obj.is_none() || - nb::ellipsis().is(obj) || nb::isinstance(obj); + return nb::isinstance(obj) || is_index_scalar(obj) || + nb::isinstance(obj) || obj.is_none() || nb::ellipsis().is(obj) || + nb::isinstance(obj); } mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) { @@ -102,7 +111,7 @@ mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) { return take(src, indices, 0); } -mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) { +mx::array mlx_get_item_int(const mx::array& src, const nb::object& idx) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -139,7 +148,7 @@ mx::array mlx_gather_nd( gather_indices.push_back(arange(start, end, stride, mx::uint32)); num_slices++; is_slice[i] = true; - } else if (nb::isinstance(idx)) { + } else if (is_index_scalar(idx)) { gather_indices.push_back(get_int_index(idx, src.shape(i))); } else if (nb::isinstance(idx)) { auto arr = nb::cast(idx); @@ -289,7 +298,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { bool have_non_array = false; bool gather_first = false; for (auto& idx : indices) { - if (nb::isinstance(idx) || (nb::isinstance(idx))) { + if (nb::isinstance(idx) || is_index_scalar(idx)) { if (have_array && have_non_array) { gather_first = true; break; @@ -312,7 +321,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { // Then find the last array for (last_array = indices.size() - 1; last_array >= 0; last_array--) { auto& idx = indices[last_array]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || is_index_scalar(idx)) { break; } } @@ -348,7 +357,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { } else { for (int i = 0; i < indices.size(); i++) { auto& idx = indices[i]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || is_index_scalar(idx)) { break; } else if (idx.is_none()) { remaining_indices.push_back(idx); @@ -385,8 +394,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { int axis = 0; for (auto& idx : remaining_indices) { if (!idx.is_none()) { - if (!have_array && nb::isinstance(idx)) { - int st = nb::cast(idx); + if (!have_array && is_index_scalar(idx)) { + int st = safe_to_int32(idx); st = (st < 0) ? st + src.shape(axis) : st; starts[axis] = st; @@ -419,7 +428,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { auto& idx = remaining_indices[axis]; if (unsqueeze_needed && idx.is_none()) { unsqueeze_axes.push_back(axis - squeeze_axes.size()); - } else if (squeeze_needed && nb::isinstance(idx)) { + } else if (squeeze_needed && is_index_scalar(idx)) { squeeze_axes.push_back(axis - unsqueeze_axes.size()); } } @@ -439,8 +448,8 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { return mlx_get_item_slice(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { return mlx_get_item_array(src, nb::cast(obj)); - } else if (nb::isinstance(obj)) { - return mlx_get_item_int(src, nb::cast(obj)); + } else if (is_index_scalar(obj)) { + return mlx_get_item_int(src, obj); } else if (nb::isinstance(obj)) { return mlx_get_item_nd(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { @@ -457,7 +466,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { std::tuple, mx::array, std::vector> mlx_scatter_args_int( const mx::array& src, - const nb::int_& idx, + const nb::object& idx, const mx::array& update) { if (src.ndim() == 0) { throw std::invalid_argument( @@ -685,7 +694,7 @@ mlx_scatter_args_nd( // Add the shape to the update update_shape[ax - 1] = 1; } - } else if (nb::isinstance(pyidx)) { + } else if (is_index_scalar(pyidx)) { // Add index to arrays arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); // Add the shape to the update @@ -755,8 +764,8 @@ mlx_compute_scatter_args( return mlx_scatter_args_slice(src, nb::cast(obj), vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_array(src, nb::cast(obj), vals); - } else if (nb::isinstance(obj)) { - return mlx_scatter_args_int(src, nb::cast(obj), vals); + } else if (is_index_scalar(obj)) { + return mlx_scatter_args_int(src, obj, vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_nd(src, nb::cast(obj), vals); } else if (obj.is_none()) { @@ -774,9 +783,9 @@ auto mlx_slice_update( const nb::object& obj, const ScalarOrArray& v) { // Can't route to slice update if not slice, tuple, or int - if (src.ndim() == 0 || nb::isinstance(obj) || - (!nb::isinstance(obj) && !nb::isinstance(obj) && - !nb::isinstance(obj))) { + if (src.ndim() == 0 || (!nb::isinstance(obj) && + !nb::isinstance(obj) && + !is_index_scalar(obj))) { return std::make_pair(false, src); } if (nb::isinstance(obj)) { @@ -806,13 +815,13 @@ auto mlx_slice_update( mx::Shape starts(src.ndim(), 0); mx::Shape stops = src.shape(); mx::Shape strides(src.ndim(), 1); - if (nb::isinstance(obj)) { + if (is_index_scalar(obj)) { if (src.ndim() < 1) { std::ostringstream msg; msg << "Too many indices for array with " << src.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - auto idx = nb::cast(obj); + auto idx = safe_to_int32(obj); idx = idx < 0 ? idx + stops[0] : idx; starts[0] = idx; stops[0] = idx + 1; @@ -872,8 +881,8 @@ auto mlx_slice_update( src.shape(ax)); ax--; upd_ax--; - } else if (nb::isinstance(pyidx)) { - int st = nb::cast(pyidx); + } else if (is_index_scalar(pyidx)) { + int st = safe_to_int32(pyidx); st = (st < 0) ? st + src.shape(i) : st; starts[ax] = st; stops[ax] = st + 1; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4efed9dac9..c34a0b71bb 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -964,6 +964,11 @@ def test_indexing(self): a_sliced_mlx = a_mlx[-1] self.assertTrue(np.array_equal(a_sliced_mlx, a_npy[-1])) + # NumPy integer scalar indexing + a_sliced_mlx = a_mlx[np.int64(5)] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[np.int64(5)])) + # Basic content check, empty index a_sliced_mlx = a_mlx[()] a_sliced_npy = np.asarray(a_sliced_mlx) @@ -1125,6 +1130,9 @@ def test_setitem(self): a[-1] = 2 self.assertEqual(a.tolist(), [2, 2, 2]) + a[np.int64(1)] = 9 + self.assertEqual(a.tolist(), [2, 9, 2]) + a[0] = mx.array([[[1]]]) self.assertEqual(a.tolist(), [1, 2, 2]) From 74134703aee116c29e13974118d9f3b212399acb Mon Sep 17 00:00:00 2001 From: AN Long Date: Wed, 18 Mar 2026 21:27:45 +0800 Subject: [PATCH 2/4] Apply clang-format to fix the lint issue --- python/src/indexing.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 49bd2c94d8..47e2a51427 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -70,8 +70,8 @@ mx::array get_int_index(nb::object idx, int axis_size) { bool is_valid_index_type(const nb::object& obj) { return nb::isinstance(obj) || is_index_scalar(obj) || - nb::isinstance(obj) || obj.is_none() || nb::ellipsis().is(obj) || - nb::isinstance(obj); + nb::isinstance(obj) || obj.is_none() || + nb::ellipsis().is(obj) || nb::isinstance(obj); } mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) { @@ -783,9 +783,9 @@ auto mlx_slice_update( const nb::object& obj, const ScalarOrArray& v) { // Can't route to slice update if not slice, tuple, or int - if (src.ndim() == 0 || (!nb::isinstance(obj) && - !nb::isinstance(obj) && - !is_index_scalar(obj))) { + if (src.ndim() == 0 || + (!nb::isinstance(obj) && !nb::isinstance(obj) && + !is_index_scalar(obj))) { return std::make_pair(false, src); } if (nb::isinstance(obj)) { From 1a58c8b1d023ac65e40f2e1d30c42a50d2a6c388 Mon Sep 17 00:00:00 2001 From: AN Long Date: Wed, 18 Mar 2026 22:17:55 +0800 Subject: [PATCH 3/4] Fix error introduced in conflict resolving --- python/src/indexing.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index a952c363a2..ac41e6223e 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -978,7 +978,7 @@ mx::array mlx_add_item( auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { - return scatterAdd(src, indices, updates, axes); + return scatter_add(src, indices, updates, axes); } else { return src + updates; } @@ -996,7 +996,7 @@ mx::array mlx_subtract_item( auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { - return scatterAdd(src, indices, -updates, axes); + return scatter_add(src, indices, -updates, axes); } else { return src - updates; } @@ -1014,7 +1014,7 @@ mx::array mlx_multiply_item( auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { - return scatterProd(src, indices, updates, axes); + return scatter_prod(src, indices, updates, axes); } else { return src * updates; } @@ -1032,7 +1032,7 @@ mx::array mlx_divide_item( auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { - return scatterProd(src, indices, reciprocal(updates), axes); + return scatter_prod(src, indices, reciprocal(updates), axes); } else { return src / updates; } @@ -1050,7 +1050,7 @@ mx::array mlx_maximum_item( auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { - return scatterMax(src, indices, updates, axes); + return scatter_max(src, indices, updates, axes); } else { return maximum(src, updates); } @@ -1068,7 +1068,7 @@ mx::array mlx_minimum_item( auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { - return scatterMin(src, indices, updates, axes); + return scatter_min(src, indices, updates, axes); } else { return minimum(src, updates); } From daa2f60ab5533b2fa23c90d35d729702c01f54b9 Mon Sep 17 00:00:00 2001 From: AN Long Date: Thu, 19 Mar 2026 00:59:50 +0800 Subject: [PATCH 4/4] Fix tests --- python/src/indexing.cpp | 15 ++++++++++++++- python/tests/test_array.py | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index ac41e6223e..7a7e8b1102 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -16,7 +16,20 @@ bool is_none_slice(const nb::slice& in_slice) { } bool is_index_scalar(const nb::object& obj) { - return !nb::isinstance(obj) && PyIndex_Check(obj.ptr()); + if (nb::isinstance(obj)) { + return false; + } + if (!PyIndex_Check(obj.ptr())) { + return false; + } + // Exclude multi-dimensional arrays (mx.array, np.ndarray) by checking ndim + if (nb::hasattr(obj, "ndim")) { + auto ndim = nb::getattr(obj, "ndim"); + if (nb::isinstance(ndim) && nb::cast(ndim) > 0) { + return false; + } + } + return true; } int safe_to_int32(nb::object obj) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 042a0cfeea..bc53ada034 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1130,7 +1130,7 @@ def test_setitem(self): self.assertEqual(a.tolist(), [2, 9, 2]) a[0] = mx.array([[[1]]]) - self.assertEqual(a.tolist(), [1, 2, 2]) + self.assertEqual(a.tolist(), [1, 9, 2]) a[:] = 0 self.assertEqual(a.tolist(), [0, 0, 0])