diff --git a/benchmarks/python/slice_update_bench.py b/benchmarks/python/slice_update_bench.py new file mode 100644 index 0000000000..60f12b9814 --- /dev/null +++ b/benchmarks/python/slice_update_bench.py @@ -0,0 +1,109 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse + +import mlx.core as mx +import torch +from time_utils import measure_runtime + + +def benchmark_slice_update_mlx(dst_shape, slice_shape, slice_range, dtype, iters=10): + def slice_update(arguments): + for i in range(iters): + arguments["dst"] = ( + arguments["dst"].at[slice_range].add(arguments["updates"]) + ) + mx.eval(arguments) + + dtype = getattr(mx, dtype) + arguments = { + "dst": mx.random.normal(dst_shape).astype(dtype), + "updates": mx.random.normal(slice_shape).astype(dtype), + } + + runtime = measure_runtime(slice_update, arguments=arguments) + bytes_processed = ( + arguments["dst"][slice_range].nbytes * 2 + arguments["updates"].nbytes + ) * iters + bandwidth_gb_s = bytes_processed / runtime / 1e6 + return runtime, bandwidth_gb_s + + +def benchmark_slice_update_torch( + dst_shape, slice_shape, slice_range, device, dtype, iters=10 +): + def slice_update(dst, updates, slice_range): + for i in range(iters): + dst[slice_range] = dst[slice_range] + updates + if device == torch.device("mps"): + torch.mps.synchronize() + + dtype = getattr(torch, dtype) + updates = torch.randn(slice_shape, dtype=dtype).to(device) + dst = torch.randn(dst_shape, dtype=dtype).to(device) + + runtime = measure_runtime( + slice_update, dst=dst, updates=updates, slice_range=slice_range + ) + bytes_processed = (dst[slice_range].nbytes * 2 + updates.nbytes) * iters + bandwidth_gb_s = bytes_processed / runtime / 1e6 + return runtime, bandwidth_gb_s + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Slice update benchmarks.") + parser.add_argument("--cpu", action="store_true", help="Use the CPU.") + args = parser.parse_args() + + if args.cpu: + mx.set_default_device(mx.cpu) + device = torch.device("cpu") + elif torch.mps.is_available(): + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + raise ValueError() + + dtypes = ["float32", "bfloat16"] + + test_cases = [ + ((10_000_000,), slice(0, 1_000_000), (1_000_000,)), + ((100_000,), slice(10_000, 20_000), (10_000,)), + ((1000, 64), slice(100, 200), (100, 64)), + ((100, 100, 64), slice(20, 40), (20, 100, 64)), + ( + (2048, 2048, 128), + (slice(500, 1500), slice(200, 1200), slice(32, 96)), + (1000, 1000, 64), + ), + ( + (2048, 2048, 128), + (slice(1800, 1850), slice(100, 200), slice(64, 128)), + (50, 100, 64), + ), + ( + (2048, 2048, 128), + (slice(1000, 1010), slice(1000, 1010), slice(64, 128)), + (10, 10, 64), + ), + ] + + print( + f"{'Dtype':<12} {'Dst Shape':<25} {'Update Shape':<20} " + f"{'MLX (ms)':<12} {'MLX GB/s':<12} {'Torch (ms)':<12} {'Torch GB/s':<12}" + ) + print("-" * 110) + + for dtype in dtypes: + for dst_shape, slice_range, update_shape in test_cases: + mlx_time, mlx_bw = benchmark_slice_update_mlx( + dst_shape, update_shape, slice_range, dtype + ) + torch_time, torch_bw = benchmark_slice_update_torch( + dst_shape, update_shape, slice_range, device, dtype + ) + print( + f"{dtype:<12} {str(dst_shape):<25} {str(update_shape):<20} " + f"{mlx_time:<12.3f} {mlx_bw:<12.2f} {torch_time:<12.3f} {torch_bw:<12.2f}" + ) diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 1b6902ff33..c6d7820619 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -116,6 +116,39 @@ struct ContiguousIterator { loc += strides_[i]; } + void step(int64_t s) { + int dims = shape_.size(); + if (dims == 0) { + return; + } + int i = dims - 1; + while (s > 0) { + if (shape_[i] - pos_[i] > 1) { + int steps = static_cast( + std::min(static_cast(shape_[i] - pos_[i] - 1), s)); + pos_[i] += steps; + loc += strides_[i] * steps; + s -= steps; + } else { + while (pos_[i] == (shape_[i] - 1) && i > 0) { + pos_[i] = 0; + loc -= (shape_[i] - 1) * strides_[i]; + i--; + } + pos_[i]++; + loc += strides_[i]; + s--; + } + } + } + + int64_t contiguous_suffix() { + if (shape_.size() == 0) { + return 0; + } + return (strides_.back() == 1) ? shape_.back() : 0; + } + void seek(int64_t n) { loc = 0; for (int i = shape_.size() - 1; i >= 0; --i) { diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index a1b412c7dc..ec4090172f 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -4,11 +4,14 @@ #include #include "mlx/allocator.h" -#include "mlx/primitives.h" - #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/binary.h" +#include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/slicing.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" namespace mlx::core { @@ -788,7 +791,7 @@ void MaskedScatter::eval_cpu(const std::vector& inputs, array& out) { auto& mask = inputs[1]; auto& src = inputs[2]; - // Copy src into out (copy allocates memory for out) + // Copy dst into out (copy allocates memory for out) auto ctype = dst.flags().row_contiguous ? CopyType::Vector : CopyType::General; copy_cpu(dst, out, ctype, stream()); @@ -851,4 +854,128 @@ void MaskedScatter::eval_cpu(const std::vector& inputs, array& out) { }); } +template +void slice_update_impl( + array& out, + const array& upd, + int64_t data_offset, + const Strides& out_strides) { + ContiguousIterator out_it(upd.shape(), out_strides, upd.ndim()); + ContiguousIterator upd_it(upd); + Op op; + + constexpr int SIMD_START = 32; + + T* out_ptr = out.data() + data_offset; + const T* upd_ptr = upd.data(); + int64_t size = upd.size(); + int64_t suffix = out_it.contiguous_suffix(); + + if (upd.data_size() == 1) { + if (suffix >= SIMD_START) { + for (int64_t i = 0; i < size; i += suffix) { + VectorScalar{}( + out_ptr + out_it.loc, upd_ptr, out_ptr + out_it.loc, suffix); + out_it.step(suffix); + } + } else { + T update = upd_ptr[0]; + for (int64_t i = 0; i < size; i++) { + out_ptr[out_it.loc] = op(out_ptr[out_it.loc], update); + out_it.step(); + } + } + } else if (suffix == upd_it.contiguous_suffix() && suffix >= SIMD_START) { + for (int64_t i = 0; i < size; i += suffix) { + VectorVector{}( + out_ptr + out_it.loc, + upd_ptr + upd_it.loc, + out_ptr + out_it.loc, + suffix); + out_it.step(suffix); + upd_it.step(suffix); + } + } else { + for (int64_t i = 0; i < size; i++) { + out_ptr[out_it.loc] = op(out_ptr[out_it.loc], upd_ptr[upd_it.loc]); + out_it.step(); + upd_it.step(); + } + } +} + +void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(allocator::malloc(0)); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + // Check if materialization is needed + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + + // Calculate out strides, initial offset and if copy needs to be made + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy + if (reduce_type_ == SliceUpdate::None) { + copy_cpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const std::vector& data_shape = */ upd.shape(), + /* const std::vector& i_strides = */ upd.strides(), + /* const std::vector& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + stream()); + return; + } + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(upd); + encoder.set_output_array(out); + encoder.dispatch([upd = array::unsafe_weak_copy(upd), + out = array::unsafe_weak_copy(out), + data_offset = data_offset, + out_strides = std::move(out_strides), + reduce_type = reduce_type_]() mutable { + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using T = MLX_GET_TYPE(type_tag); + switch (reduce_type) { + case SliceUpdate::Sum: + slice_update_impl(out, upd, data_offset, out_strides); + break; + case SliceUpdate::Prod: + slice_update_impl( + out, upd, data_offset, out_strides); + break; + case SliceUpdate::Max: + slice_update_impl( + out, upd, data_offset, out_strides); + break; + case SliceUpdate::Min: + slice_update_impl( + out, upd, data_offset, out_strides); + break; + case SliceUpdate::None: + // Should never be here + break; + } + }); + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 4e59b1ebec..f1d83dd306 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -398,44 +398,6 @@ void DynamicSliceUpdate::eval_cpu( } } -void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - if (out.size() == 0) { - out.set_data(allocator::malloc(0)); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - // Check if materialization is needed - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); - - // Calculate out strides, initial offset and if copy needs to be made - auto [data_offset, out_strides] = - prepare_slice(out, start_indices_, strides_); - - // Do copy - copy_cpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const std::vector& data_shape = */ upd.shape(), - /* const std::vector& i_strides = */ upd.strides(), - /* const std::vector& o_strides = */ out_strides, - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ data_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral, - stream()); -} - void View::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; diff --git a/mlx/backend/cuda/device/slice_update.cuh b/mlx/backend/cuda/device/slice_update.cuh new file mode 100644 index 0000000000..1ca63807c4 --- /dev/null +++ b/mlx/backend/cuda/device/slice_update.cuh @@ -0,0 +1,75 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + bool OUT_ROW_CONTIG, + bool UPD_ROW_CONTIG, + bool UPD_SCALAR, + int NWORK> +__global__ void slice_update_op( + const T* updates, + T* out, + int64_t update_size, + const __grid_constant__ Shape update_shape, + const __grid_constant__ Strides update_strides, + int32_t update_ndim, + const __grid_constant__ Strides output_strides, + int64_t output_offset) { + Op op; + + IdxT idx = cg::this_grid().thread_rank() * NWORK; + IdxT out_idx; + IdxT update_idx; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else { + out_idx = elem_to_loc( + idx, update_shape.data(), output_strides.data(), update_ndim); + } + + if constexpr (!UPD_SCALAR) { + if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else { + update_idx = elem_to_loc( + idx, update_shape.data(), update_strides.data(), update_ndim); + } + } else { + update_idx = 0; + } + + out += output_offset; + + for (int j = 0; j < NWORK && idx < update_size; j++) { + out[out_idx] = op(out[out_idx], updates[update_idx]); + idx++; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else { + out_idx += output_strides[update_ndim - 1]; + } + + if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else if constexpr (!UPD_SCALAR) { + update_idx += update_strides[update_ndim - 1]; + } + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index a84de113da..8cbd443b4d 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -1,11 +1,13 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/slicing.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/scan.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -24,6 +26,8 @@ namespace mlx::core { namespace { constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; +constexpr const char* g_slice_ops[] = + {"Maximum", "Minimum", "Add", "Multiply", ""}; void append_indices_arg( cu::KernelArgs& args, @@ -562,4 +566,120 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dims, {}, 0, args.args()); } +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("SliceUpdate::eval_gpu"); + assert(inputs.size() == 2); + if (out.size() == 0) { + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + + // Calculate out strides, initial offset and if copy needs to be made + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy for None reduce type + if (reduce_type_ == SliceUpdate::None) { + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); + return; + } + + auto [shape, strides] = + collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides}); + int nwork = 1; + if (shape.back() % 4 == 0) { + nwork = 4; + } else if (shape.back() % 2 == 0) { + nwork = 2; + } + + const char* op_name = g_slice_ops[reduce_type_]; + auto [ds, rc, cc] = check_contiguity(shape, strides[1]); + bool upd_contiguous = upd.flags().row_contiguous; + bool upd_scalar = upd.data_size() == 1; + bool out_contiguous = rc; + bool large = upd.size() > INT32_MAX; + std::string module_name = + fmt::format("slice_update_{}_{}", op_name, dtype_to_string(out.dtype())); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int out_c = 0; out_c <= 1; ++out_c) { + for (int upd_c = 0; upd_c <= 1; ++upd_c) { + for (int upd_s = 0; upd_s <= 1; ++upd_s) { + for (int large = 0; large <= 1; ++large) { + for (int nwork = 1; nwork <= 16; nwork *= 2) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::slice_update_op<{}, {}, mlx::core::cu::{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + large ? "int64_t" : "int32_t", + op_name, + out_c ? "true" : "false", + upd_c ? "true" : "false", + upd_s ? "true" : "false", + nwork)); + } + } + } + } + } + return std::make_tuple( + false, jit_source_slice_update, std::move(kernel_names)); + }); + + cu::KernelArgs args; + args.append(upd); + args.append(out); + args.append(upd.size()); + args.append_ndim(shape); + args.append_ndim(strides[0]); + args.append(shape.size()); + args.append_ndim(strides[1]); + args.append(data_offset); + + encoder.set_input_array(upd); + encoder.set_output_array(out); + + std::string kernel_name; + kernel_name = fmt::format( + "mlx::core::cu::slice_update_op<{}, {}, mlx::core::cu::{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + large ? "int64_t" : "int32_t", + op_name, + out_contiguous, + upd_contiguous, + upd_scalar, + nwork); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(upd, large, nwork); + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, 0, args.args()); +} + } // namespace mlx::core diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 0138928c06..268d6290bf 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -217,41 +217,6 @@ void Slice::eval_gpu(const std::vector& inputs, array& out) { slice_gpu(in, out, start_indices_, strides_, stream()); } -void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - if (out.size() == 0) { - out.set_data(allocator::malloc(0)); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); - auto [data_offset, out_strides] = - prepare_slice(out, start_indices_, strides_); - - // Do copy - copy_gpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const Shape& data_shape = */ upd.shape(), - /* const Strides& i_strides = */ upd.strides(), - /* const Strides& o_strides = */ out_strides, - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ data_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ stream()); -} - void Squeeze::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Squeeze::eval_gpu"); eval(inputs, out); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index e0ebb790f4..9c8a36392d 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -3,9 +3,11 @@ #include #include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/scan.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" @@ -36,6 +38,22 @@ std::pair make_index_args( return {idx_args.str(), idx_arr.str()}; } +template +inline std::string make_op(typename T::ReduceType r, const std::string& dt) { + switch (r) { + case T::None: + return "None"; + case T::Sum: + return fmt::format("Sum<{0}>", dt); + case T::Prod: + return fmt::format("Prod<{0}>", dt); + case T::Max: + return fmt::format("Max<{0}>", dt); + case T::Min: + return fmt::format("Min<{0}>", dt); + } +} + void Gather::eval_gpu(const std::vector& inputs, array& out) { auto& src = inputs[0]; int nidx = inputs.size() - 1; @@ -307,27 +325,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = nidx ? get_type_string(inputs[1].dtype()) : "bool"; - std::string op_type; - switch (reduce_type_) { - case Scatter::None: - op_type = "None"; - break; - case Scatter::Sum: - op_type = "Sum<{0}>"; - break; - case Scatter::Prod: - op_type = "Prod<{0}>"; - break; - case Scatter::Max: - op_type = "Max<{0}>"; - break; - case Scatter::Min: - op_type = "Min<{0}>"; - break; - } - if (reduce_type_ != Scatter::None) { - op_type = fmt::format(fmt::runtime(op_type), out_type_str); - } + std::string op_type = make_op(reduce_type_, out_type_str); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); kernel_source += fmt::format( @@ -724,4 +722,151 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(allocator::malloc(0)); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy + if (reduce_type_ == SliceUpdate::None) { + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); + return; + } + + std::string op_name; + switch (reduce_type_) { + case SliceUpdate::None: + op_name = "none"; + break; + case SliceUpdate::Sum: + op_name = "sum"; + break; + case SliceUpdate::Prod: + op_name = "prod"; + break; + case SliceUpdate::Max: + op_name = "max"; + break; + case SliceUpdate::Min: + op_name = "min"; + break; + } + + bool upd_contiguous = upd.flags().row_contiguous; + bool upd_scalar = upd.data_size() == 1; + + Shape shape; + std::vector strides; + if (upd_scalar) { + std::tie(shape, strides) = + collapse_contiguous_dims(upd.shape(), {out_strides, out_strides}); + } else { + std::tie(shape, strides) = + collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides}); + } + + int ndim_constant = shape.size(); + if (ndim_constant > 3) { + ndim_constant = 0; + } + + int nwork = 1; + if (shape.back() % 4 == 0) { + nwork = 4; + } else if (shape.back() % 2 == 0) { + nwork = 2; + } + + auto [ds, rc, cc] = check_contiguity(shape, strides[1]); + bool out_contiguous = rc; + bool large = upd.size() > INT32_MAX; + std::string kernel_name = fmt::format( + "slice_update_{0}_{1}{2}_{3}_{4}_{5}_nw{6}_nd{7}", + op_name, + type_to_name(out), + large ? "int64_t" : "int", + out_contiguous ? "oc_true" : "oc_false", + upd_contiguous ? "updc_true" : "updc_false", + upd_scalar ? "upds_true" : "upds_false", + nwork, + ndim_constant); + + auto& s = stream(); + auto& d = metal::device(s.device); + + auto lib = d.get_library(kernel_name, [&]() { + std::string kernel_source = metal::utils(); + concatenate(kernel_source, metal::reduce_utils(), metal::scatter()); + + std::string out_type = get_type_string(out.dtype()); + std::string op_type = make_op(reduce_type_, out_type); + + kernel_source += fmt::format( + slice_update_op_kernel, + kernel_name, + out_type, + large ? "int64_t" : "int", + op_type, + out_contiguous, + upd_contiguous, + upd_scalar, + nwork, + ndim_constant); + + return kernel_source; + }); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kernel_name, lib); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set all the buffers + int ndim = shape.size(); + int64_t size = upd.size(); + compute_encoder.set_input_array(upd, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_vector_bytes(shape, 2); + compute_encoder.set_vector_bytes(strides[0], 3); + compute_encoder.set_bytes(ndim, 4); + compute_encoder.set_bytes(size, 5); + compute_encoder.set_vector_bytes(strides[1], 6); + compute_encoder.set_bytes(data_offset, 7); + + // Launch grid + int64_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + int64_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + int64_t rest = size / (dim0 * dim1); + dim0 /= nwork; + + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims(dim0, dim1, rest); + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h index fa141fccf5..b36093b857 100644 --- a/mlx/backend/metal/jit/indexing.h +++ b/mlx/backend/metal/jit/indexing.h @@ -74,3 +74,9 @@ constexpr std::string_view scatter_kernels = R"( constexpr std::string_view masked_assign_kernel = R"( template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>; )"; + +constexpr std::string_view slice_update_op_kernel = R"( +template [[host_name("{0}")]] +[[kernel]] decltype(slice_update_op_impl<{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>) +slice_update_op_impl<{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>; +)"; diff --git a/mlx/backend/metal/kernels/indexing/scatter.h b/mlx/backend/metal/kernels/indexing/scatter.h index f0217b3369..b7f304c9ad 100644 --- a/mlx/backend/metal/kernels/indexing/scatter.h +++ b/mlx/backend/metal/kernels/indexing/scatter.h @@ -57,3 +57,81 @@ METAL_FUNC void scatter_impl( op.atomic_update(out, updates[upd_idx], out_idx); } } + +template < + typename T, + typename IdxT, + typename Op, + bool OUT_ROW_CONTIG, + bool UPD_ROW_CONTIG, + bool UPD_SCALAR, + int NWORK, + int NDIM> +[[kernel]] void slice_update_op_impl( + const device T* updates [[buffer(0)]], + device T* out [[buffer(1)]], + const constant int* update_shape [[buffer(2)]], + const constant int64_t* update_strides [[buffer(3)]], + const constant int& update_ndim [[buffer(4)]], + const constant int64_t& update_size [[buffer(5)]], + const constant int64_t* output_strides [[buffer(6)]], + const constant int64_t& output_offset [[buffer(7)]], + uint3 gid [[thread_position_in_grid]], + uint3 gsize [[threads_per_grid]]) { + Op op; + + IdxT idx = IdxT(gid.z) * gsize.y + gid.y * gsize.x + gid.x * NWORK; + IdxT out_idx; + IdxT update_idx; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else if constexpr (NDIM == 1) { + out_idx = NWORK * gid.x * output_strides[0]; + } else if constexpr (NDIM == 2) { + out_idx = gid.y * output_strides[0] + NWORK * gid.x * output_strides[1]; + } else if constexpr (NDIM == 3) { + out_idx = gid.z * output_strides[0] + gid.y * output_strides[1] + + NWORK * gid.x * output_strides[2]; + } else { + out_idx = elem_to_loc(idx, update_shape, output_strides, update_ndim); + } + + if constexpr (UPD_SCALAR) { + update_idx = 0; + } else if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else if constexpr (NDIM == 1) { + update_idx = NWORK * gid.x * update_strides[0]; + } else if constexpr (NDIM == 2) { + update_idx = gid.y * update_strides[0] + NWORK * gid.x * update_strides[1]; + } else if constexpr (NDIM == 3) { + update_idx = gid.z * update_strides[0] + gid.y * update_strides[1] + + NWORK * gid.x * update_strides[2]; + } else { + update_idx = + elem_to_loc(idx, update_shape, update_strides, update_ndim); + } + + out += output_offset; + + if constexpr (OUT_ROW_CONTIG && (UPD_ROW_CONTIG || UPD_SCALAR)) { + for (int j = 0; j < NWORK; j++) { + out[out_idx] = op(out[out_idx], updates[update_idx]); + out_idx++; + if constexpr (!UPD_SCALAR) { + update_idx++; + } + } + } else { + auto out_stride = output_strides[update_ndim - 1]; + auto update_stride = update_strides[update_ndim - 1]; + for (int j = 0; j < NWORK; j++) { + out[out_idx] = op(out[out_idx], updates[update_idx]); + out_idx += out_stride; + if constexpr (!UPD_SCALAR) { + update_idx += update_stride; + } + } + } +} diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e6554c2c4f..935bbd6efe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -850,7 +850,11 @@ array slice_update( src.shape(), src.dtype(), std::make_shared( - to_stream(s), std::move(start), std::move(stop), std::move(strides)), + to_stream(s), + SliceUpdate::None, + std::move(start), + std::move(stop), + std::move(strides)), {src, upd}); } @@ -895,6 +899,162 @@ array slice_update( {src, upd, start}); } +array slice_update( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + SliceUpdate::ReduceType mode, + StreamOrDevice s) { + if (start.size() != src.ndim() || stop.size() != src.ndim() || + strides.size() != src.ndim()) { + std::ostringstream msg; + msg << "[slice_update] Invalid number of indices or strides for " + << "array with dimension " << src.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + auto [has_neg_strides, upd_shape] = + normalize_slice(src.shape(), start, stop, strides); + + auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s); + + if (!has_neg_strides && upd_shape == src.shape()) { + switch (mode) { + case SliceUpdate::None: + return upd; + case SliceUpdate::Sum: + return add(src, upd, s); + case SliceUpdate::Prod: + return multiply(src, upd, s); + case SliceUpdate::Max: + return maximum(src, upd, s); + case SliceUpdate::Min: + return minimum(src, upd, s); + } + } + + return array( + src.shape(), + src.dtype(), + std::make_shared( + to_stream(s), + mode, + std::move(start), + std::move(stop), + std::move(strides)), + {src, upd}); +} + +array slice_update_add( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s /*= {}*/) { + return slice_update( + src, + update, + std::move(start), + std::move(stop), + std::move(strides), + SliceUpdate::Sum, + s); +} + +array slice_update_add( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s /*= {}*/) { + return slice_update_add( + src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); +} + +array slice_update_prod( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s /*= {}*/) { + return slice_update( + src, + update, + std::move(start), + std::move(stop), + std::move(strides), + SliceUpdate::Prod, + s); +} + +array slice_update_prod( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s /*= {}*/) { + return slice_update_prod( + src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); +} + +array slice_update_max( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s /*= {}*/) { + return slice_update( + src, + update, + std::move(start), + std::move(stop), + std::move(strides), + SliceUpdate::Max, + s); +} + +array slice_update_max( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s /*= {}*/) { + return slice_update_max( + src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); +} + +array slice_update_min( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s /*= {}*/) { + return slice_update( + src, + update, + std::move(start), + std::move(stop), + std::move(strides), + SliceUpdate::Min, + s); +} + +array slice_update_min( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s /*= {}*/) { + return slice_update_min( + src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); +} + std::vector split( const array& a, const Shape& indices, diff --git a/mlx/ops.h b/mlx/ops.h index 74032c01e0..0295a8a843 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -224,6 +224,78 @@ MLX_API array slice_update( std::vector axes, StreamOrDevice s = {}); +/** Slice update and add updates to given slice. */ +MLX_API array slice_update_add( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); + +/** Slice update and add updates to given slice with stride 1 in each dimension. + */ +MLX_API array slice_update_add( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s = {}); + +/** Slice update and prod updates to given slice. */ +MLX_API array slice_update_prod( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); + +/** Slice update and prod updates to given slice with stride 1 in each + * dimension. */ +MLX_API array slice_update_prod( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s = {}); + +/** Slice update and max updates to given slice. */ +MLX_API array slice_update_max( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); + +/** Slice update and max updates to given slice with stride 1 in each dimension. + */ +MLX_API array slice_update_max( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s = {}); + +/** Slice update and min updates to given slice. */ +MLX_API array slice_update_min( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); + +/** Slice update and min updates to given slice with stride 1 in each dimension. + */ +MLX_API array slice_update_min( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s = {}); + /** Split an array into sub-arrays along a given axis. */ MLX_API std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 92e54f9991..7afb5ac8e1 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4793,10 +4793,17 @@ std::pair, std::vector> SliceUpdate::vmap( // No vmapping needed if (src_ax == -1 && upd_ax == -1) { - return {{slice_update(src, upd, start, stop, strides, stream())}, {-1}}; + return { + {array( + src.shape(), + src.dtype(), + std::make_shared( + stream(), reduce_type_, start, stop, strides), + {src, upd})}, + {-1}}; } - // Broadcast src + // Broadcast Src if (src_ax == -1) { src = expand_dims(src, upd_ax, stream()); auto shape = src.shape(); @@ -4819,37 +4826,99 @@ std::pair, std::vector> SliceUpdate::vmap( stop.insert(stop.begin() + src_ax, src.shape(src_ax)); strides.insert(strides.begin() + src_ax, 1); - return {{slice_update(src, upd, start, stop, strides, stream())}, {src_ax}}; + return { + {array( + src.shape(), + src.dtype(), + std::make_shared( + stream(), reduce_type_, start, stop, strides), + {src, upd})}, + {src_ax}}; } std::vector SliceUpdate::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, - const std::vector&) { + const std::vector& outputs) { // Check inputs assert(primals.size() == 2); - auto& cotan = cotangents[0]; - auto& upd = primals[1]; + const array& result = outputs[0]; + const array& values = primals[0]; + const array& updates = primals.back(); + const array& cotan = cotangents[0]; std::vector vjps; for (int num : argnums) { // Vjp for source if (num == 0) { - vjps.push_back(slice_update( - cotan, - zeros_like(upd, stream()), - start_indices_, - end_indices_, - strides_, - stream())); + switch (reduce_type_) { + case SliceUpdate::None: + vjps.push_back(array( + cotan.shape(), + cotan.dtype(), + std::make_shared( + stream(), + reduce_type_, + start_indices_, + end_indices_, + strides_), + {cotan, zeros_like(updates, stream())})); + break; + case SliceUpdate::Sum: + vjps.push_back(cotan); + break; + case SliceUpdate::Max: + case SliceUpdate::Min: + vjps.push_back(where( + equal(result, values, stream()), + cotan, + array(0, cotan.dtype()), + stream())); + break; + case SliceUpdate::Prod: + vjps.push_back(array( + cotan.shape(), + cotan.dtype(), + std::make_shared( + stream(), + reduce_type_, + start_indices_, + end_indices_, + strides_), + {cotan, updates})); + break; + } } // Vjp fpr updates else { - vjps.push_back( - slice(cotan, start_indices_, end_indices_, strides_, stream())); + auto sliced_cotan = + slice(cotan, start_indices_, end_indices_, strides_, stream()); + switch (reduce_type_) { + case SliceUpdate::None: + case SliceUpdate::Sum: + vjps.emplace_back(std::move(sliced_cotan)); + break; + case SliceUpdate::Max: + case SliceUpdate::Min: { + auto sliced_result = + slice(result, start_indices_, end_indices_, strides_, stream()); + vjps.push_back(where( + equal(sliced_result, updates, stream()), + sliced_cotan, + array(0, cotan.dtype()), + stream())); + break; + } + case SliceUpdate::Prod: { + auto sliced_values = + slice(values, start_indices_, end_indices_, strides_, stream()); + vjps.push_back(multiply(sliced_cotan, sliced_values, stream())); + break; + } + } } } @@ -4862,18 +4931,45 @@ std::vector SliceUpdate::jvp( const std::vector& argnums) { // Check inputs assert(primals.size() == 2); - return {slice_update( - tangents[0], - tangents[1], - start_indices_, - end_indices_, - strides_, - stream())}; + + if (argnums.size() != 2) { + throw std::runtime_error( + "[SliceUpdate] JVP for one argument not implemented yet."); + } + + auto result_tan = tangents[0]; + + switch (reduce_type_) { + case SliceUpdate::None: + return {array( + result_tan.shape(), + result_tan.dtype(), + std::make_shared( + stream(), reduce_type_, start_indices_, end_indices_, strides_), + {result_tan, tangents[1]})}; + case SliceUpdate::Sum: + return {array( + result_tan.shape(), + result_tan.dtype(), + std::make_shared( + stream(), reduce_type_, start_indices_, end_indices_, strides_), + {result_tan, tangents[1]})}; + case SliceUpdate::Prod: + case SliceUpdate::Max: + case SliceUpdate::Min: { + throw std::runtime_error( + "[SliceUpdate] JVP for product, minimum and maximum not implemented."); + } + } + + // Appease gcc (although no path reaches here). + return {}; } bool SliceUpdate::is_equivalent(const Primitive& other) const { const auto& s_other = static_cast(other); return ( + reduce_type_ == s_other.reduce_type_ && start_indices_ == s_other.start_indices_ && end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_); } diff --git a/mlx/primitives.h b/mlx/primitives.h index 4091aafcfb..a3ae63e571 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2055,12 +2055,16 @@ class Slice : public UnaryPrimitive { class SliceUpdate : public UnaryPrimitive { public: + enum ReduceType { Max, Min, Sum, Prod, None }; + explicit SliceUpdate( Stream stream, + ReduceType reduce_type, const Shape& start_indices, const Shape& end_indices, const Shape& strides) : UnaryPrimitive(stream), + reduce_type_(reduce_type), start_indices_(start_indices), end_indices_(end_indices), strides_(strides) {} @@ -2070,14 +2074,32 @@ class SliceUpdate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_NAME(SliceUpdate) + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "SliceUpdate Sum"; + case Prod: + return "SliceUpdate Prod"; + case Min: + return "SliceUpdate Min"; + case Max: + return "SliceUpdate Max"; + case None: + return "SliceUpdate"; + } + return ""; + } + bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { - return std::make_tuple(start_indices_, end_indices_, strides_); + return std::make_tuple( + reduce_type_, start_indices_, end_indices_, strides_); } private: + ReduceType reduce_type_; Shape start_indices_; Shape end_indices_; Shape strides_; diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 564c4cb45b..5699e0e8a1 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -769,43 +769,53 @@ mlx_compute_scatter_args( throw std::invalid_argument("Cannot index mlx array using the given type."); } -auto mlx_slice_update( +std::tuple, mx::Shape, mx::Shape, mx::Shape> +mlx_compute_slice_update_args( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + // Build the slice params + mx::Shape starts(src.ndim(), 0); + mx::Shape stops = src.shape(); + mx::Shape strides(src.ndim(), 1); + // Can't route to slice update if not slice, tuple, or int if (src.ndim() == 0 || nb::isinstance(obj) || (!nb::isinstance(obj) && !nb::isinstance(obj) && !nb::isinstance(obj))) { - return std::make_pair(false, src); + return std::make_tuple( + std::nullopt, std::move(starts), std::move(stops), std::move(strides)); } if (nb::isinstance(obj)) { // Can't route to slice update if any arrays are present for (auto idx : nb::cast(obj)) { if (nb::isinstance(idx) || nb::isinstance(idx)) { - return std::make_pair(false, src); + return std::make_tuple( + std::nullopt, + std::move(starts), + std::move(stops), + std::move(strides)); } } } - // Should be able to route to slice update - // Pre process tuple - auto upd = to_array(v, src.dtype()); + // Should be able to route to slice update just extract the update value and + // and the slice arguments. + + // Cast v to an array and ensure it is the right type + auto update = to_array(v, src.dtype()); // Remove extra leading singletons dimensions from the update int s = 0; - for (; s < static_cast(upd.ndim()) - 1 && upd.shape(s) == 1 && - (upd.ndim() - s) > src.ndim(); + for (; s < static_cast(update.ndim()) - 1 && update.shape(s) == 1 && + (update.ndim() - s) > src.ndim(); s++) { }; auto squeeze_axes = std::vector(s); std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0); - auto up = mx::squeeze(upd, squeeze_axes); + update = mx::squeeze(update, squeeze_axes); - // Build slice update params - mx::Shape starts(src.ndim(), 0); - mx::Shape stops = src.shape(); - mx::Shape strides(src.ndim(), 1); + // Single int then make it a slice of size 1 if (nb::isinstance(obj)) { if (src.ndim() < 1) { std::ostringstream msg; @@ -816,12 +826,11 @@ auto mlx_slice_update( idx = idx < 0 ? idx + stops[0] : idx; starts[0] = idx; stops[0] = idx + 1; - auto out = slice_update( - src, up, std::move(starts), std::move(stops), std::move(strides)); - return std::make_pair(true, out); + return std::make_tuple( + update, std::move(starts), std::move(stops), std::move(strides)); } - // If it's just a simple slice, just do a slice update and return + // Simple slice, just extract it into the first dim if (nb::isinstance(obj)) { // Read slice arguments get_slice_params( @@ -830,16 +839,14 @@ auto mlx_slice_update( strides[0], nb::cast(obj), src.shape(0)); - - // Do slice update - auto out = slice_update(src, up, starts, stops, strides); - return std::make_pair(true, out); + return std::make_tuple( + update, std::move(starts), std::move(stops), std::move(strides)); } // It must be a tuple auto entries = nb::cast(obj); - // Expand ellipses into a series of ':' slices + // Expand ellipsis into a series of ':' slices auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); // Dimension check @@ -851,15 +858,20 @@ auto mlx_slice_update( // If no non-None indices return the broadcasted update if (non_none_indices == 0) { - return std::make_pair(true, broadcast_to(up, src.shape())); + return std::make_tuple( + broadcast_to(update, src.shape()), + std::move(starts), + std::move(stops), + std::move(strides)); } + // Parse the update slice int unspecified = src.ndim() - non_none_indices; std::vector squeeze_dims; std::vector expand_dims; for (int i = indices.size() - 1, ax = non_none_indices - 1, - upd_ax = upd.ndim() - unspecified - 1; + upd_ax = update.ndim() - unspecified - 1; i >= 0; --i) { auto& pyidx = indices[i]; @@ -887,11 +899,11 @@ auto mlx_slice_update( } } } + update = mx::squeeze( + mx::expand_dims(update, std::move(expand_dims)), std::move(squeeze_dims)); - up = mx::squeeze( - mx::expand_dims(up, std::move(expand_dims)), std::move(squeeze_dims)); - auto out = slice_update(src, up, starts, stops, strides); - return std::make_pair(true, out); + return std::make_tuple( + update, std::move(starts), std::move(stops), std::move(strides)); } std::optional extract_boolean_mask(const nb::object& obj) { @@ -921,9 +933,11 @@ void mlx_set_item( mx::array& src, const nb::object& obj, const ScalarOrArray& v) { - auto [success, out] = mlx_slice_update(src, obj, v); - if (success) { - src.overwrite_descriptor(out); + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + src.overwrite_descriptor( + slice_update(src, *update, starts, stops, strides)); return; } @@ -947,6 +961,12 @@ mx::array mlx_add_item( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + return slice_update_add(src, *update, starts, stops, strides); + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { return scatter_add(src, indices, updates, axes); @@ -959,6 +979,12 @@ mx::array mlx_subtract_item( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + return slice_update_add(src, -(*update), starts, stops, strides); + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { return scatter_add(src, indices, -updates, axes); @@ -971,6 +997,12 @@ mx::array mlx_multiply_item( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + return slice_update_prod(src, *update, starts, stops, strides); + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { return scatter_prod(src, indices, updates, axes); @@ -983,6 +1015,12 @@ mx::array mlx_divide_item( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + return slice_update_prod(src, reciprocal(*update), starts, stops, strides); + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { return scatter_prod(src, indices, reciprocal(updates), axes); @@ -995,6 +1033,12 @@ mx::array mlx_maximum_item( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + return slice_update_max(src, *update, starts, stops, strides); + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { return scatter_max(src, indices, updates, axes); @@ -1007,6 +1051,12 @@ mx::array mlx_minimum_item( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [update, starts, stops, strides] = + mlx_compute_slice_update_args(src, obj, v); + if (update) { + return slice_update_min(src, *update, starts, stops, strides); + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { return scatter_min(src, indices, updates, axes); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4efed9dac9..4cd6b6829b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1423,6 +1423,106 @@ def test_array_at(self): src = src.at[0:1].add(update) self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]]))) + # Test all array.at ops with slice-only indices + a = mx.random.uniform(shape=(10, 5, 2)) + update = mx.ones((2, 5)) + a[1:3, :, 0] = 0 + a = a.at[1:3, :, 0].add(update) + self.assertEqualArray(a[1:3, :, 0], update) + a = a.at[1:3, :, 0].subtract(update) + self.assertEqualArray(a[1:3, :, 0], mx.zeros_like(update)) + a = a.at[1:3, :, 0].add(2 * update) + self.assertEqualArray(a[1:3, :, 0], 2 * update) + a = a.at[1:3, :, 0].multiply(2 * update) + self.assertEqualArray(a[1:3, :, 0], 4 * update) + a = a.at[1:3, :, 0].divide(3 * update) + self.assertEqualArray(a[1:3, :, 0], (4 / 3) * update) + a[1:3, :, 0] = 5 + update = mx.arange(10).reshape(2, 5) + a = a.at[1:3, :, 0].maximum(update) + self.assertEqualArray(a[1:3, :, 0], mx.maximum(a[1:3, :, 0], update)) + a[1:3, :, 0] = 5 + a = a.at[1:3, :, 0].minimum(update) + self.assertEqualArray(a[1:3, :, 0], mx.minimum(a[1:3, :, 0], update)) + + def test_array_at_slice_update_extensive(self): + # Test with transposed inputs + a = mx.zeros((4, 5)) + update = mx.ones((5, 2)).T # Shape (2, 5) + a = a.at[1:3, :].add(update) + self.assertEqualArray(a[1:3, :], update) + + # Test with transposed updates on transposed slice + a = mx.zeros((5, 4)) + update = mx.ones((2, 5)) + a = a.at[:, 1:3].add(update.T) + self.assertEqualArray(a[:, 1:3], update.T) + + # Test with slice of another array as update + source = mx.arange(20, dtype=mx.float32).reshape(4, 5) + a = mx.zeros((4, 5)) + update = source[1:3, :] # Shape (2, 5) + a = a.at[0:2, :].add(update) + self.assertEqualArray(a[0:2, :], source[1:3, :]) + + # Test with both input and update being slices + source = mx.arange(30, dtype=mx.float32).reshape(5, 6) + a = mx.zeros((5, 6)) + a = a.at[1:4, 1:5].add(source[0:3, 0:4]) + self.assertEqualArray(a[1:4, 1:5], source[0:3, 0:4]) + + # Test with transposed slice of another array + source = mx.arange(20, dtype=mx.float32).reshape(4, 5) + a = mx.zeros((5, 4)) + update = source[1:3, :].T # Shape (5, 2) + a = a.at[:, 1:3].add(update) + self.assertEqualArray(a[:, 1:3], update) + + # Test with negative indexing in slices + a = mx.zeros((5, 5)) + update = mx.ones((2, 5)) + a = a.at[-3:-1, :].add(update) + self.assertEqualArray(a[-3:-1, :], update) + + # Test with strided slices + a = mx.zeros((6, 6)) + update = mx.ones((2, 3)) + a = a.at[1:5:2, 0:6:2].add(update) + self.assertEqualArray(a[1:5:2, 0:6:2], update) + + # Test with slice of transposed array + source = mx.arange(20, dtype=mx.float32).reshape(4, 5) + a = mx.zeros((5, 4)) + update = source.T[:, 1:3] # Shape (5, 2) + a = a.at[:, 1:3].add(update) + self.assertEqualArray(a[:, 1:3], update) + + # Test with 3D arrays and transposed updates + a = mx.zeros((3, 4, 5)) + update = mx.ones((4, 3, 5)).transpose(1, 0, 2) # Shape (3, 4, 5) + a = a.at[:, :, :].add(update) + self.assertEqualArray(a, update) + + # Test with slice of 3D array + source = mx.arange(60, dtype=mx.float32).reshape(3, 4, 5) + a = mx.zeros((3, 4, 5)) + update = source[0:2, :, :] + a = a.at[1:3, :, :].add(update) + self.assertEqualArray(a[1:3, :, :], source[0:2, :, :]) + + # Test with mixed slice and index + a = mx.zeros((4, 5, 6)) + update = mx.ones((2, 6)) + a = a.at[1:3, 2, :].add(update) + self.assertEqualArray(a[1:3, 2, :], update) + + # Test with update from strided slice + source = mx.arange(60, dtype=mx.float32).reshape(3, 4, 5) + a = mx.zeros((3, 2, 5)) + update = source[:, ::2, :] # Shape (3, 2, 5) + a = a.at[:, :, :].add(update) + self.assertEqualArray(a, update) + def test_slice_negative_step(self): a_np = np.arange(20) a_mx = mx.array(a_np) diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index c37161a4d5..97330a2866 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -282,65 +282,138 @@ def fun(x, idx): x[idx] = 2.0 return x.sum() - dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1])) - self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0]))) + dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0, 4.0]), mx.array([1, 3])) + self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0, 0.0]))) self.assertEqual(dfdx.dtype, mx.float32) - y = mx.array([0.0, 1.0, 2.0]) + y = mx.array([0.0, 1.0, 2.0, 3.0]) def fun(x, idx): y[idx] = x return y.sum() - dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1])) - self.assertTrue(mx.array_equal(dfdx, mx.array([1.0]))) + dfdx = mx.grad(fun)(mx.array([2.0, 3.0]), mx.array([1, 3])) + self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 1.0]))) self.assertEqual(dfdx.dtype, mx.float32) + def test_scatter_add_vjp(self): + def fun(src, updates): + x = src.at[mx.array([1, 3])].add(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([1.0, 2.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([5.0, 7.0]))) + def test_scatter_max_vjp(self): def fun(src, updates): - x = src.at[1].maximum(updates) + x = src.at[mx.array([1, 3])].maximum(updates) return x - cotan = mx.array([4.0, 5.0, 6.0]) - _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan]) + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([1.0, 2.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) mx.eval(vjps) - # Update larger than value - self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0]))) - self.assertTrue(mx.allclose(vjps[1], mx.array([5.0]))) + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([0.0, 0.0]))) - cotan = mx.array([[4.0], [5.0], [6.0]]) - _, vjps = mx.vjp( - fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan] - ) + updates = mx.array([5.0, 6.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) mx.eval(vjps) - # Update and value are equal - self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) - self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0, 0.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([5.0, 7.0]))) def test_scatter_min_vjp(self): def fun(src, updates): - x = src.at[1].minimum(updates) + x = src.at[mx.array([1, 3])].minimum(updates) return x - cotan = mx.array([4.0, 5.0, 6.0]) - _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan]) + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([5.0, 6.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) mx.eval(vjps) - # Update larger than value - self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0]))) - self.assertTrue(mx.allclose(vjps[1], mx.array([0.0]))) + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([0.0, 0.0]))) - cotan = mx.array([[4.0], [5.0], [6.0]]) - _, vjps = mx.vjp( - fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan] - ) + updates = mx.array([1.0, 1.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0, 0.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([5.0, 7.0]))) + + def test_slice_update_max_vjp(self): + def fun(src, updates): + x = src.at[1:3].maximum(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([[1.0, 2.0]]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[0.0, 0.0]]))) + + updates = mx.array([[5.0, 6.0]]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 0.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[5.0, 6.0]]))) + + def test_slice_update_min_vjp(self): + def fun(src, updates): + x = src.at[1:3].minimum(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([[5.0, 6.0]]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[0.0, 0.0]]))) + + updates = mx.array([[1.0, 1.0]]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 0.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[5.0, 6.0]]))) + + def test_slice_update_add_vjp(self): + def fun(src, updates): + x = src.at[1:3].add(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([[1.0, 2.0]]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[5.0, 6.0]]))) + + def test_slice_update_multiply_vjp(self): + def fun(src, updates): + x = src.at[1:3].multiply(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0, 7.0]) + updates = mx.array([[2.0, 3.0]]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan]) mx.eval(vjps) - # Update and value are equal - self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) - self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 10.0, 18.0, 7.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[10.0, 18.0]]))) def test_split_against_slice(self): def f_split(x): diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 51df7dcd29..1d05cb0c15 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -367,6 +367,104 @@ TEST_CASE("test slice update") { CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item()); } +TEST_CASE("test slice update add") { + // Basic slice update add + auto x = zeros({8}, float32); + auto y = ones({4}, float32); + auto out = slice_update_add(x, y, {2}, {6}, {1}); + auto expected = array({0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + CHECK(array_equal(out, expected).item()); + + // Overlapping slice update add + x = zeros({8}, float32); + y = ones({4}, float32); + out = slice_update_add(x, y, {2}, {6}, {1}); + out = slice_update_add(out, y, {4}, {8}, {1}); + expected = array({0.0f, 0.0f, 1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f}); + CHECK(array_equal(out, expected).item()); + + // Slice update add with stride + x = zeros({10}, float32); + y = ones({3}, float32); + out = slice_update_add(x, y, {1}, {7}, {2}); + expected = + array({0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + CHECK(array_equal(out, expected).item()); + + // 2D slice update add + x = zeros({4, 4}, float32); + y = ones({2, 2}, float32); + out = slice_update_add(x, y, {1, 1}, {3, 3}, {1, 1}); + expected = reshape( + array( + {0.0f, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + 1.0f, + 1.0f, + 0.0f, + 0.0f, + 1.0f, + 1.0f, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + 0.0f}, + {4, 4}), + {4, 4}); + CHECK(array_equal(out, expected).item()); + + // Overlapping 2D slice update add + x = zeros({4, 4}, float32); + y = ones({2, 2}, float32); + out = slice_update_add(x, y, {0, 0}, {2, 2}, {1, 1}); + out = slice_update_add(out, y, {1, 1}, {3, 3}, {1, 1}); + expected = reshape( + array( + {1.0f, + 1.0f, + 0.0f, + 0.0f, + 1.0f, + 2.0f, + 1.0f, + 0.0f, + 0.0f, + 1.0f, + 1.0f, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + 0.0f}, + {4, 4}), + {4, 4}); + CHECK(array_equal(out, expected).item()); + + // Slice update add with different dtypes + x = zeros({4}, int32); + y = ones({2}, int32); + out = slice_update_add(x, y, {1}, {3}, {1}); + expected = array({0, 1, 1, 0}); + CHECK(array_equal(out, expected).item()); + + // Empty slice update add + x = arange(4, float32); + y = array({}); + out = slice_update_add(x, y, {0}, {0}, {1}); + CHECK(array_equal(out, x).item()); + + // Full array slice update add + x = ones({4}, float32); + y = full({4}, 2.0f, float32); + out = slice_update_add(x, y, {0}, {4}, {1}); + expected = array({3.0f, 3.0f, 3.0f, 3.0f}); + CHECK(array_equal(out, expected).item()); +} + TEST_CASE("test dynamic slice") { auto src = reshape(arange(6), {2, 3}); CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));