diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 5699e0e8a1..7a7e8b1102 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -15,8 +15,30 @@ bool is_none_slice(const nb::slice& in_slice) { nb::getattr(in_slice, "step").is_none()); } +bool is_index_scalar(const nb::object& obj) { + if (nb::isinstance(obj)) { + return false; + } + if (!PyIndex_Check(obj.ptr())) { + return false; + } + // Exclude multi-dimensional arrays (mx.array, np.ndarray) by checking ndim + if (nb::hasattr(obj, "ndim")) { + auto ndim = nb::getattr(obj, "ndim"); + if (nb::isinstance(ndim) && nb::cast(ndim) > 0) { + return false; + } + } + return true; +} + int safe_to_int32(nb::object obj) { - auto val = nb::cast(nb::cast(obj)); + auto idx = nb::steal(PyNumber_Index(obj.ptr())); + if (!idx.is_valid()) { + throw nb::python_error(); + } + + auto val = nb::cast(nb::cast(idx)); if (val > INT32_MAX || val < INT32_MIN) { throw std::invalid_argument("Slice indices must be 32-bit integers."); } @@ -25,7 +47,7 @@ int safe_to_int32(nb::object obj) { int get_slice_int(nb::object obj, int default_val) { if (!obj.is_none()) { - if (!nb::isinstance(obj)) { + if (!is_index_scalar(obj)) { throw std::invalid_argument("Slice indices must be integers or None."); } return safe_to_int32(obj); @@ -60,7 +82,7 @@ mx::array get_int_index(nb::object idx, int axis_size) { } bool is_valid_index_type(const nb::object& obj) { - return nb::isinstance(obj) || nb::isinstance(obj) || + return nb::isinstance(obj) || is_index_scalar(obj) || nb::isinstance(obj) || obj.is_none() || nb::ellipsis().is(obj) || nb::isinstance(obj); } @@ -102,7 +124,7 @@ mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) { return take(src, indices, 0); } -mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) { +mx::array mlx_get_item_int(const mx::array& src, const nb::object& idx) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -139,7 +161,7 @@ mx::array mlx_gather_nd( gather_indices.push_back(arange(start, end, stride, mx::uint32)); num_slices++; is_slice[i] = true; - } else if (nb::isinstance(idx)) { + } else if (is_index_scalar(idx)) { gather_indices.push_back(get_int_index(idx, src.shape(i))); } else if (nb::isinstance(idx)) { auto arr = nb::cast(idx); @@ -289,7 +311,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { bool have_non_array = false; bool gather_first = false; for (auto& idx : indices) { - if (nb::isinstance(idx) || (nb::isinstance(idx))) { + if (nb::isinstance(idx) || is_index_scalar(idx)) { if (have_array && have_non_array) { gather_first = true; break; @@ -312,7 +334,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { // Then find the last array for (last_array = indices.size() - 1; last_array >= 0; last_array--) { auto& idx = indices[last_array]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || is_index_scalar(idx)) { break; } } @@ -348,7 +370,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { } else { for (int i = 0; i < indices.size(); i++) { auto& idx = indices[i]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || is_index_scalar(idx)) { break; } else if (idx.is_none()) { remaining_indices.push_back(idx); @@ -385,8 +407,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { int axis = 0; for (auto& idx : remaining_indices) { if (!idx.is_none()) { - if (!have_array && nb::isinstance(idx)) { - int st = nb::cast(idx); + if (!have_array && is_index_scalar(idx)) { + int st = safe_to_int32(idx); st = (st < 0) ? st + src.shape(axis) : st; starts[axis] = st; @@ -419,7 +441,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { auto& idx = remaining_indices[axis]; if (unsqueeze_needed && idx.is_none()) { unsqueeze_axes.push_back(axis - squeeze_axes.size()); - } else if (squeeze_needed && nb::isinstance(idx)) { + } else if (squeeze_needed && is_index_scalar(idx)) { squeeze_axes.push_back(axis - unsqueeze_axes.size()); } } @@ -439,8 +461,8 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { return mlx_get_item_slice(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { return mlx_get_item_array(src, nb::cast(obj)); - } else if (nb::isinstance(obj)) { - return mlx_get_item_int(src, nb::cast(obj)); + } else if (is_index_scalar(obj)) { + return mlx_get_item_int(src, obj); } else if (nb::isinstance(obj)) { return mlx_get_item_nd(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { @@ -457,7 +479,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { std::tuple, mx::array, std::vector> mlx_scatter_args_int( const mx::array& src, - const nb::int_& idx, + const nb::object& idx, const mx::array& update) { if (src.ndim() == 0) { throw std::invalid_argument( @@ -685,7 +707,7 @@ mlx_scatter_args_nd( // Add the shape to the update update_shape[ax - 1] = 1; } - } else if (nb::isinstance(pyidx)) { + } else if (is_index_scalar(pyidx)) { // Add index to arrays arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); // Add the shape to the update @@ -755,8 +777,8 @@ mlx_compute_scatter_args( return mlx_scatter_args_slice(src, nb::cast(obj), vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_array(src, nb::cast(obj), vals); - } else if (nb::isinstance(obj)) { - return mlx_scatter_args_int(src, nb::cast(obj), vals); + } else if (is_index_scalar(obj)) { + return mlx_scatter_args_int(src, obj, vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_nd(src, nb::cast(obj), vals); } else if (obj.is_none()) { @@ -780,9 +802,9 @@ mlx_compute_slice_update_args( mx::Shape strides(src.ndim(), 1); // Can't route to slice update if not slice, tuple, or int - if (src.ndim() == 0 || nb::isinstance(obj) || + if (src.ndim() == 0 || (!nb::isinstance(obj) && !nb::isinstance(obj) && - !nb::isinstance(obj))) { + !is_index_scalar(obj))) { return std::make_tuple( std::nullopt, std::move(starts), std::move(stops), std::move(strides)); } @@ -816,13 +838,13 @@ mlx_compute_slice_update_args( update = mx::squeeze(update, squeeze_axes); // Single int then make it a slice of size 1 - if (nb::isinstance(obj)) { + if (is_index_scalar(obj)) { if (src.ndim() < 1) { std::ostringstream msg; msg << "Too many indices for array with " << src.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - auto idx = nb::cast(obj); + auto idx = safe_to_int32(obj); idx = idx < 0 ? idx + stops[0] : idx; starts[0] = idx; stops[0] = idx + 1; @@ -884,8 +906,8 @@ mlx_compute_slice_update_args( src.shape(ax)); ax--; upd_ax--; - } else if (nb::isinstance(pyidx)) { - int st = nb::cast(pyidx); + } else if (is_index_scalar(pyidx)) { + int st = safe_to_int32(pyidx); st = (st < 0) ? st + src.shape(i) : st; starts[ax] = st; stops[ax] = st + 1; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index f3dee5a0f3..bc53ada034 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -960,6 +960,11 @@ def test_indexing(self): a_sliced_mlx = a_mlx[-1] self.assertTrue(np.array_equal(a_sliced_mlx, a_npy[-1])) + # NumPy integer scalar indexing + a_sliced_mlx = a_mlx[np.int64(5)] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[np.int64(5)])) + # Basic content check, empty index a_sliced_mlx = a_mlx[()] a_sliced_npy = np.asarray(a_sliced_mlx) @@ -1121,8 +1126,11 @@ def test_setitem(self): a[-1] = 2 self.assertEqual(a.tolist(), [2, 2, 2]) + a[np.int64(1)] = 9 + self.assertEqual(a.tolist(), [2, 9, 2]) + a[0] = mx.array([[[1]]]) - self.assertEqual(a.tolist(), [1, 2, 2]) + self.assertEqual(a.tolist(), [1, 9, 2]) a[:] = 0 self.assertEqual(a.tolist(), [0, 0, 0])