Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions benchmarks/python/slice_update_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_slice_update_mlx(dst_shape, slice_shape, slice_range, dtype, iters=10):
    """Time an MLX slice accumulate (``dst[slice] += updates``).

    Returns a ``(runtime, bandwidth_gb_s)`` tuple where the bandwidth is the
    effective number of bytes moved per second.
    """

    def slice_update(arguments):
        # Queue `iters` lazy slice-adds, then force evaluation of the tree.
        for _ in range(iters):
            dst = arguments["dst"]
            arguments["dst"] = dst.at[slice_range].add(arguments["updates"])
        mx.eval(arguments)

    mlx_dtype = getattr(mx, dtype)
    arguments = {
        "dst": mx.random.normal(dst_shape).astype(mlx_dtype),
        "updates": mx.random.normal(slice_shape).astype(mlx_dtype),
    }

    runtime = measure_runtime(slice_update, arguments=arguments)
    # Each iteration reads and writes the destination slice and reads the
    # update, hence the factor of two on the slice bytes.
    slice_bytes = arguments["dst"][slice_range].nbytes
    total_bytes = iters * (slice_bytes * 2 + arguments["updates"].nbytes)
    bandwidth_gb_s = total_bytes / runtime / 1e6
    return runtime, bandwidth_gb_s


def benchmark_slice_update_torch(
    dst_shape, slice_shape, slice_range, device, dtype, iters=10
):
    """Time the equivalent slice accumulate in PyTorch.

    Returns a ``(runtime, bandwidth_gb_s)`` tuple matching the MLX benchmark.
    """

    def slice_update(dst, updates, slice_range):
        for _ in range(iters):
            dst[slice_range] = dst[slice_range] + updates
        # MPS executes asynchronously; wait for queued work before the
        # timer stops so the measurement includes the actual kernels.
        if device == torch.device("mps"):
            torch.mps.synchronize()

    torch_dtype = getattr(torch, dtype)
    updates = torch.randn(slice_shape, dtype=torch_dtype).to(device)
    dst = torch.randn(dst_shape, dtype=torch_dtype).to(device)

    runtime = measure_runtime(
        slice_update, dst=dst, updates=updates, slice_range=slice_range
    )
    # Read + write of the destination slice plus a read of the update.
    total_bytes = iters * (dst[slice_range].nbytes * 2 + updates.nbytes)
    bandwidth_gb_s = total_bytes / runtime / 1e6
    return runtime, bandwidth_gb_s


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Slice update benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    # Keep MLX and torch on comparable devices so the comparison is fair.
    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    elif torch.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        # Fixed: raise with an actionable message instead of a bare ValueError().
        raise ValueError(
            "No MPS or CUDA device available; pass --cpu to benchmark on the CPU."
        )

    dtypes = ["float32", "bfloat16"]

    # (dst_shape, slice_range, update_shape) triples covering 1D-3D
    # destinations, large and small slices, and multi-axis ranges.
    test_cases = [
        ((10_000_000,), slice(0, 1_000_000), (1_000_000,)),
        ((100_000,), slice(10_000, 20_000), (10_000,)),
        ((1000, 64), slice(100, 200), (100, 64)),
        ((100, 100, 64), slice(20, 40), (20, 100, 64)),
        (
            (2048, 2048, 128),
            (slice(500, 1500), slice(200, 1200), slice(32, 96)),
            (1000, 1000, 64),
        ),
        (
            (2048, 2048, 128),
            (slice(1800, 1850), slice(100, 200), slice(64, 128)),
            (50, 100, 64),
        ),
        (
            (2048, 2048, 128),
            (slice(1000, 1010), slice(1000, 1010), slice(64, 128)),
            (10, 10, 64),
        ),
    ]

    print(
        f"{'Dtype':<12} {'Dst Shape':<25} {'Update Shape':<20} "
        f"{'MLX (ms)':<12} {'MLX GB/s':<12} {'Torch (ms)':<12} {'Torch GB/s':<12}"
    )
    print("-" * 110)

    for dtype in dtypes:
        for dst_shape, slice_range, update_shape in test_cases:
            mlx_time, mlx_bw = benchmark_slice_update_mlx(
                dst_shape, update_shape, slice_range, dtype
            )
            torch_time, torch_bw = benchmark_slice_update_torch(
                dst_shape, update_shape, slice_range, device, dtype
            )
            print(
                f"{dtype:<12} {str(dst_shape):<25} {str(update_shape):<20} "
                f"{mlx_time:<12.3f} {mlx_bw:<12.2f} {torch_time:<12.3f} {torch_bw:<12.2f}"
            )
33 changes: 33 additions & 0 deletions mlx/backend/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,39 @@ struct ContiguousIterator {
loc += strides_[i];
}

// Advance the iterator by `s` positions in row-major order, updating both
// the multi-dimensional position `pos_` and the flat offset `loc`.
// Equivalent to `s` calls to the scalar `step()`, but steps through each
// remaining contiguous run of a dimension in a single jump.
void step(int64_t s) {
  int dims = shape_.size();
  if (dims == 0) {
    return;
  }
  int i = dims - 1;
  while (s > 0) {
    if (shape_[i] - pos_[i] > 1) {
      // Fast path: consume as many of the remaining steps as fit inside
      // dimension i without wrapping.
      int steps = static_cast<int>(
          std::min(static_cast<int64_t>(shape_[i] - pos_[i] - 1), s));
      pos_[i] += steps;
      loc += strides_[i] * steps;
      s -= steps;
    } else {
      // Dimension i is exhausted: carry into the next outer dimension,
      // rewinding `loc` by the full extent of every wrapped dimension.
      while (pos_[i] == (shape_[i] - 1) && i > 0) {
        pos_[i] = 0;
        loc -= (shape_[i] - 1) * strides_[i];
        i--;
      }
      pos_[i]++;
      loc += strides_[i];
      s--;
      // Fixed: after a carry the inner dimensions were reset to zero, so
      // any remaining steps must resume from the innermost dimension.
      // Previously `i` stayed at the carried-into dimension, which made
      // the remaining steps advance along the wrong (outer) axis whenever
      // `s` was not exhausted exactly at the carry.
      i = dims - 1;
    }
  }
}

int64_t contiguous_suffix() {
if (shape_.size() == 0) {
return 0;
}
return (strides_.back() == 1) ? shape_.back() : 0;
}

void seek(int64_t n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
Expand Down
133 changes: 130 additions & 3 deletions mlx/backend/cpu/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
#include <cmath>

#include "mlx/allocator.h"
#include "mlx/primitives.h"

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/slicing.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

Expand Down Expand Up @@ -788,7 +791,7 @@ void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& mask = inputs[1];
auto& src = inputs[2];

// Copy src into out (copy allocates memory for out)
// Copy dst into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
Expand Down Expand Up @@ -851,4 +854,128 @@ void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
});
}

// Accumulate `upd` into the strided slice of `out` described by
// `data_offset` and `out_strides`, elementwise, using Op (add, multiply,
// maximum or minimum). Dispatches to vectorized kernels when the innermost
// run of the slice is contiguous and long enough.
template <typename T, typename Op>
void slice_update_impl(
    array& out,
    const array& upd,
    int64_t data_offset,
    const Strides& out_strides) {
  // Walk the output slice with the update's shape but the slice's strides.
  ContiguousIterator out_it(upd.shape(), out_strides, upd.ndim());
  ContiguousIterator upd_it(upd);
  Op op;

  // Minimum contiguous run length worth dispatching to the vector kernels.
  constexpr int SIMD_START = 32;

  T* out_ptr = out.data<T>() + data_offset;
  const T* upd_ptr = upd.data<T>();
  int64_t size = upd.size();
  // Contiguous innermost run of the output slice; 0 when the innermost
  // stride is not 1.
  int64_t suffix = out_it.contiguous_suffix();

  if (upd.data_size() == 1) {
    // Broadcast case: a single scalar update applied to every element.
    if (suffix >= SIMD_START) {
      // Vectorized: one kernel call per contiguous output run.
      for (int64_t i = 0; i < size; i += suffix) {
        VectorScalar<Op>{}(
            out_ptr + out_it.loc, upd_ptr, out_ptr + out_it.loc, suffix);
        out_it.step(suffix);
      }
    } else {
      T update = upd_ptr[0];
      for (int64_t i = 0; i < size; i++) {
        out_ptr[out_it.loc] = op(out_ptr[out_it.loc], update);
        out_it.step();
      }
    }
  } else if (suffix == upd_it.contiguous_suffix() && suffix >= SIMD_START) {
    // Both sides share the same contiguous innermost run length: process
    // one full run per iteration with the vector-vector kernel.
    for (int64_t i = 0; i < size; i += suffix) {
      VectorVector<Op>{}(
          out_ptr + out_it.loc,
          upd_ptr + upd_it.loc,
          out_ptr + out_it.loc,
          suffix);
      out_it.step(suffix);
      upd_it.step(suffix);
    }
  } else {
    // General fallback: elementwise scalar accumulation.
    for (int64_t i = 0; i < size; i++) {
      out_ptr[out_it.loc] = op(out_ptr[out_it.loc], upd_ptr[upd_it.loc]);
      out_it.step();
      upd_it.step();
    }
  }
}

// CPU evaluation of SliceUpdate: writes (or accumulates, depending on
// reduce_type_) the update `inputs[1]` into the slice of `inputs[0]`
// selected by start_indices_/strides_, producing `out`.
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  // An empty update leaves the input unchanged; just share its buffer.
  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  // Materialize the input into out (copy allocates memory for out).
  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

  // Calculate out strides and the initial offset of the slice within out
  auto [data_offset, out_strides] =
      prepare_slice(out, start_indices_, strides_);

  // Plain overwrite: copy the update directly into the slice region.
  if (reduce_type_ == SliceUpdate::None) {
    copy_cpu_inplace(
        /* const array& src = */ upd,
        /* array& dst = */ out,
        /* const std::vector<int>& data_shape = */ upd.shape(),
        /* const std::vector<stride_t>& i_strides = */ upd.strides(),
        /* const std::vector<stride_t>& o_strides = */ out_strides,
        /* int64_t i_offset = */ 0,
        /* int64_t o_offset = */ data_offset,
        /* CopyType ctype = */ CopyType::GeneralGeneral,
        stream());
    return;
  }

  // Accumulating update (Sum/Prod/Max/Min): dispatch on dtype and run the
  // kernel on the stream's command encoder. Weak copies keep the lambda
  // from extending the arrays' lifetimes.
  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(upd);
  encoder.set_output_array(out);
  encoder.dispatch([upd = array::unsafe_weak_copy(upd),
                    out = array::unsafe_weak_copy(out),
                    data_offset = data_offset,
                    out_strides = std::move(out_strides),
                    reduce_type = reduce_type_]() mutable {
    dispatch_all_types(out.dtype(), [&](auto type_tag) {
      using T = MLX_GET_TYPE(type_tag);
      switch (reduce_type) {
        case SliceUpdate::Sum:
          slice_update_impl<T, detail::Add>(out, upd, data_offset, out_strides);
          break;
        case SliceUpdate::Prod:
          slice_update_impl<T, detail::Multiply>(
              out, upd, data_offset, out_strides);
          break;
        case SliceUpdate::Max:
          slice_update_impl<T, detail::Maximum>(
              out, upd, data_offset, out_strides);
          break;
        case SliceUpdate::Min:
          slice_update_impl<T, detail::Minimum>(
              out, upd, data_offset, out_strides);
          break;
        case SliceUpdate::None:
          // Handled by the early-return copy path above; unreachable here.
          break;
      }
    });
  });
}

} // namespace mlx::core
38 changes: 0 additions & 38 deletions mlx/backend/cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,44 +398,6 @@ void DynamicSliceUpdate::eval_cpu(
}
}

// NOTE(review): this is the superseded implementation being removed by the
// diff — the CPU SliceUpdate now lives in mlx/backend/cpu/indexing.cpp,
// which additionally supports accumulating reductions. It only performed
// the plain overwrite (copy update into the slice) path.
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  // An empty update leaves the input unchanged; share its buffer.
  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  // Check if materialization is needed
  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

  // Calculate out strides, initial offset and if copy needs to be made
  auto [data_offset, out_strides] =
      prepare_slice(out, start_indices_, strides_);

  // Do copy
  copy_cpu_inplace(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const std::vector<int>& data_shape = */ upd.shape(),
      /* const std::vector<stride_t>& i_strides = */ upd.strides(),
      /* const std::vector<stride_t>& o_strides = */ out_strides,
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ data_offset,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      stream());
}

void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
Expand Down
Loading
Loading