Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions python/src/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@

#include "mlx/ops.h"

// Matches any object implementing __index__ (PEP 357):
// Python int, numpy integer scalars, etc.
bool is_int_like(const nb::object& obj) {
return PyIndex_Check(obj.ptr());
}

nb::int_ to_nb_int(const nb::object& obj) {
return nb::steal<nb::int_>(PyNumber_Index(obj.ptr()));
}

bool is_none_slice(const nb::slice& in_slice) {
return (
nb::getattr(in_slice, "start").is_none() &&
Expand All @@ -16,7 +26,8 @@ bool is_none_slice(const nb::slice& in_slice) {
}

int safe_to_int32(nb::object obj) {
auto val = nb::cast<int64_t>(nb::cast<nb::int_>(obj));
auto as_int = is_int_like(obj) ? to_nb_int(obj) : nb::cast<nb::int_>(obj);
auto val = nb::cast<int64_t>(as_int);
if (val > INT32_MAX || val < INT32_MIN) {
throw std::invalid_argument("Slice indices must be 32-bit integers.");
}
Expand All @@ -25,10 +36,10 @@ int safe_to_int32(nb::object obj) {

int get_slice_int(nb::object obj, int default_val) {
if (!obj.is_none()) {
if (!nb::isinstance<nb::int_>(obj)) {
if (!is_int_like(obj)) {
throw std::invalid_argument("Slice indices must be integers or None.");
}
return safe_to_int32(obj);
return safe_to_int32(to_nb_int(obj));
}
return default_val;
}
Expand Down Expand Up @@ -60,7 +71,7 @@ mx::array get_int_index(nb::object idx, int axis_size) {
}

bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
return nb::isinstance<nb::slice>(obj) || is_int_like(obj) ||
nb::isinstance<mx::array>(obj) || obj.is_none() ||
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
}
Expand Down Expand Up @@ -139,7 +150,7 @@ mx::array mlx_gather_nd(
gather_indices.push_back(arange(start, end, stride, mx::uint32));
num_slices++;
is_slice[i] = true;
} else if (nb::isinstance<nb::int_>(idx)) {
} else if (is_int_like(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (nb::isinstance<mx::array>(idx)) {
auto arr = nb::cast<mx::array>(idx);
Expand Down Expand Up @@ -289,7 +300,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
bool have_non_array = false;
bool gather_first = false;
for (auto& idx : indices) {
if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
if (nb::isinstance<mx::array>(idx) || (is_int_like(idx))) {
if (have_array && have_non_array) {
gather_first = true;
break;
Expand All @@ -312,7 +323,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (nb::isinstance<mx::array>(idx) || is_int_like(idx)) {
break;
}
}
Expand Down Expand Up @@ -348,7 +359,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
} else {
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (nb::isinstance<mx::array>(idx) || is_int_like(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
Expand Down Expand Up @@ -385,8 +396,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
int axis = 0;
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
if (!have_array && nb::isinstance<nb::int_>(idx)) {
int st = nb::cast<int>(idx);
if (!have_array && is_int_like(idx)) {
int st = nb::cast<int>(to_nb_int(idx));
st = (st < 0) ? st + src.shape(axis) : st;

starts[axis] = st;
Expand Down Expand Up @@ -419,7 +430,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
} else if (squeeze_needed && is_int_like(idx)) {
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}
Expand All @@ -439,8 +450,8 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (nb::isinstance<mx::array>(obj)) {
return mlx_get_item_array(src, nb::cast<mx::array>(obj));
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (is_int_like(obj)) {
return mlx_get_item_int(src, to_nb_int(obj));
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
} else if (nb::isinstance<nb::ellipsis>(obj)) {
Expand Down Expand Up @@ -685,7 +696,7 @@ mlx_scatter_args_nd(
// Add the shape to the update
update_shape[ax - 1] = 1;
}
} else if (nb::isinstance<nb::int_>(pyidx)) {
} else if (is_int_like(pyidx)) {
// Add index to arrays
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
// Add the shape to the update
Expand Down Expand Up @@ -755,8 +766,8 @@ mlx_compute_scatter_args(
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (nb::isinstance<mx::array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (is_int_like(obj)) {
return mlx_scatter_args_int(src, to_nb_int(obj), vals);
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) {
Expand All @@ -776,7 +787,7 @@ auto mlx_slice_update(
// Can't route to slice update if not slice, tuple, or int
if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
!nb::isinstance<nb::int_>(obj))) {
!is_int_like(obj))) {
return std::make_pair(false, src);
}
if (nb::isinstance<nb::tuple>(obj)) {
Expand Down Expand Up @@ -806,13 +817,13 @@ auto mlx_slice_update(
mx::Shape starts(src.ndim(), 0);
mx::Shape stops = src.shape();
mx::Shape strides(src.ndim(), 1);
if (nb::isinstance<nb::int_>(obj)) {
if (is_int_like(obj)) {
if (src.ndim() < 1) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto idx = nb::cast<int>(obj);
auto idx = nb::cast<int>(to_nb_int(obj));
idx = idx < 0 ? idx + stops[0] : idx;
starts[0] = idx;
stops[0] = idx + 1;
Expand Down Expand Up @@ -872,8 +883,8 @@ auto mlx_slice_update(
src.shape(ax));
ax--;
upd_ax--;
} else if (nb::isinstance<nb::int_>(pyidx)) {
int st = nb::cast<int>(pyidx);
} else if (is_int_like(pyidx)) {
int st = nb::cast<int>(to_nb_int(pyidx));
st = (st < 0) ? st + src.shape(i) : st;
starts[ax] = st;
stops[ax] = st + 1;
Expand Down
13 changes: 13 additions & 0 deletions python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,19 @@ def check_slices(arr_np, *idx_np):
a_mlx = mx.array(a_np)
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))

# Numpy scalar indexing (issue #2710)
a = mx.array([10, 20, 30, 40, 50])
for np_int in [np.int8, np.int16, np.int32, np.int64, np.intp]:
self.assertEqual(a[np_int(1)].item(), 20)
self.assertEqual(a[np_int(-1)].item(), 50)
# Numpy scalar slicing
self.assertTrue(
np.array_equal(np.array(a[np.int64(1) : np.int64(4)]), [20, 30, 40])
)
# Numpy scalar in tuple index
b = mx.array(np.arange(12).reshape(3, 4))
self.assertEqual(b[np.int64(1), np.int64(2)].item(), 6)

def test_indexing_grad(self):
x = mx.array([[1, 2], [3, 4]]).astype(mx.float32)
ind = mx.array([0, 1, 0]).astype(mx.float32)
Expand Down