Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions python/src/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,30 @@ bool is_none_slice(const nb::slice& in_slice) {
nb::getattr(in_slice, "step").is_none());
}

bool is_index_scalar(const nb::object& obj) {
if (nb::isinstance<nb::bool_>(obj)) {
return false;
}
if (!PyIndex_Check(obj.ptr())) {
return false;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should simply be return PyIndex_Check(obj.ptr()) && !nb::isinstance<nb::bool_>(obj);

// Exclude multi-dimensional arrays (mx.array, np.ndarray) by checking ndim
if (nb::hasattr(obj, "ndim")) {
auto ndim = nb::getattr(obj, "ndim");
if (nb::isinstance<nb::int_>(ndim) && nb::cast<int>(ndim) > 0) {
return false;
}
}
return true;
}

int safe_to_int32(nb::object obj) {
auto val = nb::cast<int64_t>(nb::cast<nb::int_>(obj));
auto idx = nb::steal<nb::object>(PyNumber_Index(obj.ptr()));
if (!idx.is_valid()) {
throw nb::python_error();
}

auto val = nb::cast<int64_t>(nb::cast<nb::int_>(idx));
if (val > INT32_MAX || val < INT32_MIN) {
throw std::invalid_argument("Slice indices must be 32-bit integers.");
}
Expand All @@ -25,7 +47,7 @@ int safe_to_int32(nb::object obj) {

int get_slice_int(nb::object obj, int default_val) {
if (!obj.is_none()) {
if (!nb::isinstance<nb::int_>(obj)) {
if (!is_index_scalar(obj)) {
throw std::invalid_argument("Slice indices must be integers or None.");
}
return safe_to_int32(obj);
Expand Down Expand Up @@ -60,7 +82,7 @@ mx::array get_int_index(nb::object idx, int axis_size) {
}

bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
return nb::isinstance<nb::slice>(obj) || is_index_scalar(obj) ||
nb::isinstance<mx::array>(obj) || obj.is_none() ||
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
}
Expand Down Expand Up @@ -102,7 +124,7 @@ mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
return take(src, indices, 0);
}

mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) {
mx::array mlx_get_item_int(const mx::array& src, const nb::object& idx) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
Expand Down Expand Up @@ -139,7 +161,7 @@ mx::array mlx_gather_nd(
gather_indices.push_back(arange(start, end, stride, mx::uint32));
num_slices++;
is_slice[i] = true;
} else if (nb::isinstance<nb::int_>(idx)) {
} else if (is_index_scalar(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (nb::isinstance<mx::array>(idx)) {
auto arr = nb::cast<mx::array>(idx);
Expand Down Expand Up @@ -289,7 +311,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
bool have_non_array = false;
bool gather_first = false;
for (auto& idx : indices) {
if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {
if (have_array && have_non_array) {
gather_first = true;
break;
Expand All @@ -312,7 +334,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {
break;
}
}
Expand Down Expand Up @@ -348,7 +370,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
} else {
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
Expand Down Expand Up @@ -385,8 +407,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
int axis = 0;
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
if (!have_array && nb::isinstance<nb::int_>(idx)) {
int st = nb::cast<int>(idx);
if (!have_array && is_index_scalar(idx)) {
int st = safe_to_int32(idx);
st = (st < 0) ? st + src.shape(axis) : st;

starts[axis] = st;
Expand Down Expand Up @@ -419,7 +441,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
} else if (squeeze_needed && is_index_scalar(idx)) {
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}
Expand All @@ -439,8 +461,8 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (nb::isinstance<mx::array>(obj)) {
return mlx_get_item_array(src, nb::cast<mx::array>(obj));
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (is_index_scalar(obj)) {
return mlx_get_item_int(src, obj);
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
} else if (nb::isinstance<nb::ellipsis>(obj)) {
Expand All @@ -457,7 +479,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_int(
const mx::array& src,
const nb::int_& idx,
const nb::object& idx,
const mx::array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
Expand Down Expand Up @@ -685,7 +707,7 @@ mlx_scatter_args_nd(
// Add the shape to the update
update_shape[ax - 1] = 1;
}
} else if (nb::isinstance<nb::int_>(pyidx)) {
} else if (is_index_scalar(pyidx)) {
// Add index to arrays
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
// Add the shape to the update
Expand Down Expand Up @@ -755,8 +777,8 @@ mlx_compute_scatter_args(
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (nb::isinstance<mx::array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (is_index_scalar(obj)) {
return mlx_scatter_args_int(src, obj, vals);
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) {
Expand All @@ -780,9 +802,9 @@ mlx_compute_slice_update_args(
mx::Shape strides(src.ndim(), 1);

// Can't route to slice update if not slice, tuple, or int
if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
if (src.ndim() == 0 ||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
!nb::isinstance<nb::int_>(obj))) {
!is_index_scalar(obj))) {
return std::make_tuple(
std::nullopt, std::move(starts), std::move(stops), std::move(strides));
}
Expand Down Expand Up @@ -816,13 +838,13 @@ mlx_compute_slice_update_args(
update = mx::squeeze(update, squeeze_axes);

// Single int then make it a slice of size 1
if (nb::isinstance<nb::int_>(obj)) {
if (is_index_scalar(obj)) {
if (src.ndim() < 1) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto idx = nb::cast<int>(obj);
auto idx = safe_to_int32(obj);
idx = idx < 0 ? idx + stops[0] : idx;
starts[0] = idx;
stops[0] = idx + 1;
Expand Down Expand Up @@ -884,8 +906,8 @@ mlx_compute_slice_update_args(
src.shape(ax));
ax--;
upd_ax--;
} else if (nb::isinstance<nb::int_>(pyidx)) {
int st = nb::cast<int>(pyidx);
} else if (is_index_scalar(pyidx)) {
int st = safe_to_int32(pyidx);
st = (st < 0) ? st + src.shape(i) : st;
starts[ax] = st;
stops[ax] = st + 1;
Expand Down
10 changes: 9 additions & 1 deletion python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,11 @@ def test_indexing(self):
a_sliced_mlx = a_mlx[-1]
self.assertTrue(np.array_equal(a_sliced_mlx, a_npy[-1]))

# NumPy integer scalar indexing
a_sliced_mlx = a_mlx[np.int64(5)]
a_sliced_npy = np.asarray(a_sliced_mlx)
self.assertTrue(np.array_equal(a_sliced_npy, a_npy[np.int64(5)]))

# Basic content check, empty index
a_sliced_mlx = a_mlx[()]
a_sliced_npy = np.asarray(a_sliced_mlx)
Expand Down Expand Up @@ -1121,8 +1126,11 @@ def test_setitem(self):
a[-1] = 2
self.assertEqual(a.tolist(), [2, 2, 2])

a[np.int64(1)] = 9
self.assertEqual(a.tolist(), [2, 9, 2])

a[0] = mx.array([[[1]]])
self.assertEqual(a.tolist(), [1, 2, 2])
self.assertEqual(a.tolist(), [1, 9, 2])

a[:] = 0
self.assertEqual(a.tolist(), [0, 0, 0])
Expand Down
Loading