From fd740f1bc5f890b90130e97e65701ac5216a223d Mon Sep 17 00:00:00 2001 From: mm65x Date: Sun, 15 Mar 2026 16:12:56 +0000 Subject: [PATCH] support numpy scalar indexing --- python/src/indexing.cpp | 53 +++++++++++++++++++++++--------------- python/tests/test_array.py | 13 ++++++++++ 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 564c4cb45b..cedcc12334 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -8,6 +8,16 @@ #include "mlx/ops.h" +// Matches any object implementing __index__ (PEP 357): +// Python int, numpy integer scalars, etc. +bool is_int_like(const nb::object& obj) { + return PyIndex_Check(obj.ptr()); +} + +nb::int_ to_nb_int(const nb::object& obj) { + return nb::steal(PyNumber_Index(obj.ptr())); +} + bool is_none_slice(const nb::slice& in_slice) { return ( nb::getattr(in_slice, "start").is_none() && @@ -16,7 +26,8 @@ bool is_none_slice(const nb::slice& in_slice) { } int safe_to_int32(nb::object obj) { - auto val = nb::cast(nb::cast(obj)); + auto as_int = is_int_like(obj) ? to_nb_int(obj) : nb::cast(obj); + auto val = nb::cast(as_int); if (val > INT32_MAX || val < INT32_MIN) { throw std::invalid_argument("Slice indices must be 32-bit integers."); } @@ -25,10 +36,10 @@ int safe_to_int32(nb::object obj) { int get_slice_int(nb::object obj, int default_val) { if (!obj.is_none()) { - if (!nb::isinstance(obj)) { + if (!is_int_like(obj)) { throw std::invalid_argument("Slice indices must be integers or None."); } - return safe_to_int32(obj); + return safe_to_int32(to_nb_int(obj)); } return default_val; } @@ -60,7 +71,7 @@ mx::array get_int_index(nb::object idx, int axis_size) { } bool is_valid_index_type(const nb::object& obj) { - return nb::isinstance(obj) || nb::isinstance(obj) || + return nb::isinstance(obj) || is_int_like(obj) || nb::isinstance(obj) || obj.is_none() || nb::ellipsis().is(obj) || nb::isinstance(obj); } @@ -139,7 +150,7 @@ mx::array mlx_gather_nd( gather_indices.push_back(arange(start, end, stride, mx::uint32)); num_slices++; is_slice[i] = true; - } else if (nb::isinstance(idx)) { + } else if (is_int_like(idx)) { gather_indices.push_back(get_int_index(idx, src.shape(i))); } else if (nb::isinstance(idx)) { auto arr = nb::cast(idx); @@ -289,7 +300,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { bool have_non_array = false; bool gather_first = false; for (auto& idx : indices) { - if (nb::isinstance(idx) || (nb::isinstance(idx))) { + if (nb::isinstance(idx) || (is_int_like(idx))) { if (have_array && have_non_array) { gather_first = true; break; @@ -312,7 +323,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { // Then find the last array for (last_array = indices.size() - 1; last_array >= 0; last_array--) { auto& idx = indices[last_array]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || is_int_like(idx)) { break; } } @@ -348,7 +359,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { } else { for (int i = 0; i < indices.size(); i++) { auto& idx = indices[i]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || is_int_like(idx)) { break; } else if (idx.is_none()) { remaining_indices.push_back(idx); @@ -385,8 +396,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { int axis = 0; for (auto& idx : remaining_indices) { if (!idx.is_none()) { - if (!have_array && nb::isinstance(idx)) { - int st = nb::cast(idx); + if (!have_array && is_int_like(idx)) { + int st = nb::cast(to_nb_int(idx)); st = (st < 0) ? st + src.shape(axis) : st; starts[axis] = st; @@ -419,7 +430,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { auto& idx = remaining_indices[axis]; if (unsqueeze_needed && idx.is_none()) { unsqueeze_axes.push_back(axis - squeeze_axes.size()); - } else if (squeeze_needed && nb::isinstance(idx)) { + } else if (squeeze_needed && is_int_like(idx)) { squeeze_axes.push_back(axis - unsqueeze_axes.size()); } } @@ -439,8 +450,8 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { return mlx_get_item_slice(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { return mlx_get_item_array(src, nb::cast(obj)); - } else if (nb::isinstance(obj)) { - return mlx_get_item_int(src, nb::cast(obj)); + } else if (is_int_like(obj)) { + return mlx_get_item_int(src, to_nb_int(obj)); } else if (nb::isinstance(obj)) { return mlx_get_item_nd(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { @@ -685,7 +696,7 @@ mlx_scatter_args_nd( // Add the shape to the update update_shape[ax - 1] = 1; } - } else if (nb::isinstance(pyidx)) { + } else if (is_int_like(pyidx)) { // Add index to arrays arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); // Add the shape to the update @@ -755,8 +766,8 @@ mlx_compute_scatter_args( return mlx_scatter_args_slice(src, nb::cast(obj), vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_array(src, nb::cast(obj), vals); - } else if (nb::isinstance(obj)) { - return mlx_scatter_args_int(src, nb::cast(obj), vals); + } else if (is_int_like(obj)) { + return mlx_scatter_args_int(src, to_nb_int(obj), vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_nd(src, nb::cast(obj), vals); } else if (obj.is_none()) { @@ -776,7 +787,7 @@ auto mlx_slice_update( // Can't route to slice update if not slice, tuple, or int if (src.ndim() == 0 || nb::isinstance(obj) || (!nb::isinstance(obj) && !nb::isinstance(obj) && - !nb::isinstance(obj))) { + !is_int_like(obj))) { return std::make_pair(false, src); } if (nb::isinstance(obj)) { @@ -806,13 +817,13 @@ auto mlx_slice_update( mx::Shape starts(src.ndim(), 0); mx::Shape stops = src.shape(); mx::Shape strides(src.ndim(), 1); - if (nb::isinstance(obj)) { + if (is_int_like(obj)) { if (src.ndim() < 1) { std::ostringstream msg; msg << "Too many indices for array with " << src.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - auto idx = nb::cast(obj); + auto idx = nb::cast(to_nb_int(obj)); idx = idx < 0 ? idx + stops[0] : idx; starts[0] = idx; stops[0] = idx + 1; @@ -872,8 +883,8 @@ auto mlx_slice_update( src.shape(ax)); ax--; upd_ax--; - } else if (nb::isinstance(pyidx)) { - int st = nb::cast(pyidx); + } else if (is_int_like(pyidx)) { + int st = nb::cast(to_nb_int(pyidx)); st = (st < 0) ? st + src.shape(i) : st; starts[ax] = st; stops[ax] = st + 1; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4efed9dac9..6fb70ed0e5 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1100,6 +1100,19 @@ def check_slices(arr_np, *idx_np): a_mlx = mx.array(a_np) self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0]))) + # Numpy scalar indexing (issue #2710) + a = mx.array([10, 20, 30, 40, 50]) + for np_int in [np.int8, np.int16, np.int32, np.int64, np.intp]: + self.assertEqual(a[np_int(1)].item(), 20) + self.assertEqual(a[np_int(-1)].item(), 50) + # Numpy scalar slicing + self.assertTrue( + np.array_equal(np.array(a[np.int64(1) : np.int64(4)]), [20, 30, 40]) + ) + # Numpy scalar in tuple index + b = mx.array(np.arange(12).reshape(3, 4)) + self.assertEqual(b[np.int64(1), np.int64(2)].item(), 6) + def test_indexing_grad(self): x = mx.array([[1, 2], [3, 4]]).astype(mx.float32) ind = mx.array([0, 1, 0]).astype(mx.float32)