Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void array::copy_shared_buffer(
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
int64_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides;
array_desc_->flags = flags;
Expand Down
2 changes: 1 addition & 1 deletion mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class array {
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
int64_t offset = 0);

void copy_shared_buffer(const array& other);

Expand Down
29 changes: 16 additions & 13 deletions mlx/backend/common/slicing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}

void shared_buffer_slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
int64_t data_offset,
size_t data_size,
array& out) {
// Compute row/col contiguity
Expand All @@ -51,17 +47,24 @@ void slice(

// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];

// Get the location of the end based on the inp strides and out.shape()
int64_t low_idx = 0;
int64_t high_idx = 0;
for (int i = 0; i < inp_strides.size(); ++i) {
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
}
}
if (data_end < 0) {
data_end += in.data_size();
int64_t data_size = (high_idx - low_idx) + 1;
if (data_size < 0) {
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
}
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/gpu/slicing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void slice_gpu(
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
const Stream&) {
slice(in, out, start_indices, strides);
}

Expand Down
5 changes: 5 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3058,6 +3058,11 @@ def test_slice_with_negative_stride(self):
out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))

a = mx.arange(8)
for _ in range(4):
a = a[::-1]
self.assertTrue(mx.array_equal(a, mx.arange(8)))

def test_complex_ops(self):
x = mx.array(
[
Expand Down
16 changes: 15 additions & 1 deletion tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ TEST_CASE("test slice") {

out = slice(x, {0}, {4}, {2});
eval(out);
CHECK_EQ(out.data_size(), 4);
CHECK_EQ(out.data_size(), 3);

x = ones({4, 4});
out = slice(x, {0, 0}, {2, 4});
Expand Down Expand Up @@ -325,6 +325,20 @@ TEST_CASE("test slice") {
out = slice(x, {2, 2, 2}, {3, 4, 3});
eval(out);
CHECK_EQ(out.data_size(), 5);

x = ones({8});
out = slice(x, {7}, {-9}, {-1});
eval(out);
CHECK_EQ(out.data_size(), 8);

out = slice(x, {7}, {-9}, {-1});
eval(out);
CHECK_EQ(out.data_size(), 8);

x = ones({4, 2});
out = slice(x, {3, 0}, {-5, 2}, {-1, 1});
eval(out);
CHECK_EQ(out.data_size(), 8);
}

TEST_CASE("test slice update") {
Expand Down