From 2daec6ca84bb868876864e8874f7943b6ed13fc5 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 20 Mar 2026 16:14:20 +0900 Subject: [PATCH] Decouple CommandEncoder from Device --- docs/src/dev/extensions.rst | 4 +- examples/extensions/axpby/axpby.cpp | 2 +- mlx/backend/cuda/quantized/quantized.cpp | 3 +- mlx/backend/metal/allocator.cpp | 13 ++-- mlx/backend/metal/allocator.h | 5 +- mlx/backend/metal/binary.cpp | 2 +- mlx/backend/metal/compiled.cpp | 2 +- mlx/backend/metal/conv.cpp | 26 +++---- mlx/backend/metal/copy.cpp | 4 +- mlx/backend/metal/custom_kernel.cpp | 4 +- mlx/backend/metal/device.cpp | 68 +++++-------------- mlx/backend/metal/device.h | 21 ++---- mlx/backend/metal/eval.cpp | 28 ++++---- mlx/backend/metal/event.cpp | 12 ++-- mlx/backend/metal/fence.cpp | 20 +++--- mlx/backend/metal/fft.cpp | 9 ++- mlx/backend/metal/hadamard.cpp | 4 +- mlx/backend/metal/indexing.cpp | 20 +++--- mlx/backend/metal/logsumexp.cpp | 6 +- mlx/backend/metal/matmul.cpp | 66 ++++++++++-------- mlx/backend/metal/normalization.cpp | 16 ++--- mlx/backend/metal/primitives.cpp | 6 +- mlx/backend/metal/quantized.cpp | 39 +++++------ mlx/backend/metal/reduce.cpp | 10 +-- mlx/backend/metal/resident.cpp | 1 + mlx/backend/metal/resident.h | 4 +- mlx/backend/metal/rope.cpp | 2 +- .../metal/scaled_dot_product_attention.cpp | 16 ++--- mlx/backend/metal/scan.cpp | 4 +- mlx/backend/metal/slicing.cpp | 6 +- mlx/backend/metal/softmax.cpp | 2 +- mlx/backend/metal/sort.cpp | 6 +- mlx/backend/metal/ternary.cpp | 2 +- mlx/backend/metal/unary.cpp | 2 +- 34 files changed, 198 insertions(+), 237 deletions(-) diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 392c7cbf94..f98f096a0d 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -404,7 +404,7 @@ below. 
auto kernel = d.get_kernel(kname, lib); // Prepare to encode kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = mx::metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); // Kernel parameters are registered with buffer indices corresponding to @@ -448,7 +448,7 @@ We can now call the :meth:`axpby` operation on both the CPU and the GPU! A few things to note about MLX and Metal before moving on. MLX keeps track of the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is -associated. We rely on :meth:`d.get_command_encoder` to give us the active +associated. We rely on :meth:`metal::get_command_encoder` to give us the active metal compute command encoder instead of building a new one and calling :meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute pipelines) to the active command buffer until some specified limit is hit or diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 31badbbda9..9ade5c4830 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -192,7 +192,7 @@ void Axpby::eval_gpu( auto kernel = d.get_kernel(kname, lib); // Prepare to encode kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = mx::metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); // Kernel parameters are registered with buffer indices corresponding to diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index d7252ec196..608a51700c 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -109,8 +109,7 @@ void fast::Quantize::eval_gpu( std::vector& outputs) { nvtx3::scoped_range r("Quantize::eval_gpu"); auto& s = stream(); - auto& d = cu::device(s.device); - auto& enc = d.get_command_encoder(s); + auto& enc = cu::get_command_encoder(s); if (dequantize_) 
{ auto wq = ensure_row_contiguous(inputs[0], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s); diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 3f9d9b197d..222c6fd9fa 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -31,8 +31,9 @@ void* Buffer::raw_ptr() { namespace metal { -MetalAllocator::MetalAllocator() - : device_(device(mlx::core::Device::gpu).mtl_device()), +MetalAllocator::MetalAllocator(Device& d) + : device_(d.mtl_device()), + residency_set_(d.residency_set()), buffer_cache_( vm_page_size, [](MTL::Buffer* buf) { return buf->length(); }, @@ -42,8 +43,7 @@ MetalAllocator::MetalAllocator() } auto pool = metal::new_scoped_memory_pool(); buf->release(); - }), - residency_set_(device_) { + }) { const auto& info = gpu::device_info(0); auto memsize = std::get(info.at("memory_size")); auto max_rec_size = @@ -52,8 +52,6 @@ MetalAllocator::MetalAllocator() block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize); gc_limit_ = std::min(static_cast(0.95 * max_rec_size), block_limit_); max_pool_size_ = block_limit_; - device(mlx::core::Device::gpu) - .set_residency_set(residency_set_.mtl_residency_set()); bool is_vm = std::get(info.at("device_name")) == "Apple Paravirtual device"; if (is_vm) { @@ -226,7 +224,8 @@ MetalAllocator& allocator() { // By creating the |allocator_| on heap, the destructor of MetalAllocator // will not be called on exit and buffers in the cache will be leaked. This // can save some time at program exit. 
- static MetalAllocator* allocator_ = new MetalAllocator; + static MetalAllocator* allocator_ = + new MetalAllocator(device(mlx::core::Device::gpu)); return *allocator_; } diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 49e09d7f6b..4cbbfb0adc 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -9,7 +9,6 @@ #include "mlx/allocator.h" #include "mlx/backend/common/buffer_cache.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/resident.h" namespace mlx::core::metal { @@ -52,13 +51,13 @@ class MetalAllocator : public allocator::Allocator { static constexpr int small_size_ = 256; static constexpr int heap_size_ = 1 << 20; - MetalAllocator(); + MetalAllocator(Device& d); ~MetalAllocator(); friend MetalAllocator& allocator(); NS::SharedPtr heap_; - ResidencySet residency_set_; + ResidencySet& residency_set_; // Caching allocator BufferCache buffer_cache_; diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 8c0e8c3332..51771e9b1b 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -106,7 +106,7 @@ void binary_op_gpu_inplace( auto kernel = outputs.size() == 2 ? 
get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op) : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int arg_idx = 0; diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index eb51ab750e..cdb0a471be 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -389,7 +389,7 @@ void Compiled::eval_gpu( kernel_name += "_large"; } auto kernel = d.get_kernel(kernel_name, lib); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); // Put the inputs in diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index fac135fe18..5d032779d3 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -26,7 +26,7 @@ ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { return x; } auto result = contiguous_copy_gpu(x, s); - d.add_temporary(result, s.index); + metal::get_command_encoder(s).add_temporary(result); return result; } @@ -52,7 +52,7 @@ void explicit_gemm_conv_ND_gpu( std::string kname; kname.reserve(32); concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); @@ -132,7 +132,7 @@ void explicit_gemm_conv_group_ND_gpu( kname.reserve(32); concatenate( kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); @@ -286,7 +286,7 @@ void 
implicit_gemm_conv_2D_gpu( small_filter ? 's' : 'l'); // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_conv_kernel( d, kname, @@ -469,7 +469,7 @@ void implicit_gemm_conv_2D_general_gpu( }; // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_conv_general_kernel( d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn); compute_encoder.set_compute_pipeline_state(kernel); @@ -595,7 +595,7 @@ void implicit_gemm_conv_3D_gpu( small_filter ? 's' : 'l'); // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_conv_3d_kernel(d, kname, out, bm, bn, bk, wm, wn, small_filter); compute_encoder.set_compute_pipeline_state(kernel); @@ -644,7 +644,7 @@ void pad_and_slice_conv_3D_gpu( array x_copy(xshape, x.dtype(), nullptr, {}); array zero(0, x.dtype()); pad_gpu(x, zero, x_copy, {0, -1}, {0, 0}, s); - d.add_temporary(x_copy, s.index); + metal::get_command_encoder(s).add_temporary(x_copy); return x_copy; }; @@ -804,7 +804,7 @@ void winograd_conv_2D_gpu( type_to_name(out), "_bc", bc); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); @@ -837,7 +837,7 @@ void winograd_conv_2D_gpu( type_to_name(out), "_bc", bc); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); @@ -889,7 +889,7 @@ void winograd_conv_2D_gpu( type_to_name(out), "_bo", bc); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = 
metal::get_command_encoder(s); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); @@ -950,7 +950,7 @@ void depthwise_conv_2D_gpu( "_tgp_w_", tw, "_do_flip_", do_flip ? 't' : 'n'); // clang-format on - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); @@ -1044,7 +1044,7 @@ void depthwise_conv_1D_gpu( type_to_name(out), large ? "_large" : ""); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(base_name); compute_encoder.set_compute_pipeline_state(kernel); @@ -1348,7 +1348,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { // Record copies if (!copies.empty()) { - d.add_temporaries(std::move(copies), s.index); + metal::get_command_encoder(s).add_temporaries(std::move(copies)); } } diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 24c1506c22..97c5fccbca 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -107,7 +107,7 @@ void copy_gpu_inplace( auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) : get_copy_kernel(d, kernel_name, in, out); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); inp_offset *= size_of(in.dtype()); @@ -190,7 +190,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) { std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? 
"sn" : "s"); concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, val, out); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(val, 0); diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index eec6645bdd..6d33ff5007 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -378,7 +378,7 @@ void CustomKernel::eval_gpu( auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int index = 0; for (int i = 0; i < checked_inputs.size(); i++) { @@ -424,7 +424,7 @@ void CustomKernel::eval_gpu( MTL::Size grid_dims = MTL::Size(gx, gy, gz); compute_encoder.dispatch_threads(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } } // namespace mlx::core::fast diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index bdf2df3936..8da66df85f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -46,6 +46,7 @@ auto get_metal_version() { } NS::SharedPtr load_device() { + auto pool = new_scoped_memory_pool(); auto devices = NS::TransferPtr(MTL::CopyAllDevices()); auto device = NS::RetainPtr(static_cast(devices->object(0))) ?: NS::TransferPtr(MTL::CreateSystemDefaultDevice()); @@ -249,7 +250,7 @@ MTL::Library* load_library( CommandEncoder::CommandEncoder( Device& d, int index, - const MTL::ResidencySet* residency_set) + ResidencySet& residency_set) : device_(d) { auto pool = new_scoped_memory_pool(); queue_ = 
NS::TransferPtr(device_.mtl_device()->newCommandQueue()); @@ -257,8 +258,8 @@ CommandEncoder::CommandEncoder( throw std::runtime_error( "[metal::CommandEncoder] Failed to make new command queue."); } - if (residency_set) { - queue_->addResidencySet(residency_set); + if (residency_set.mtl_residency_set()) { + queue_->addResidencySet(residency_set.mtl_residency_set()); } debug_set_stream_queue_label(queue_.get(), index); buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); @@ -441,9 +442,8 @@ MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { return encoder_.get(); } -Device::Device() { +Device::Device() : device_(load_device()), residency_set_(device_.get()) { auto pool = new_scoped_memory_pool(); - device_ = load_device(); default_library_ = NS::TransferPtr(load_default_library(device_.get())); arch_ = env::metal_gpu_arch(); if (arch_.empty()) { @@ -487,38 +487,6 @@ Device::Device() { Device::~Device() = default; -bool Device::command_buffer_needs_commit(int index) { - return get_command_encoder(index).needs_commit(); -} - -MTL::CommandBuffer* Device::get_command_buffer(int index) { - return get_command_encoder(index).get_command_buffer(); -} - -void Device::commit_command_buffer(int index) { - get_command_encoder(index).commit(); -} - -void Device::add_temporary(array arr, int index) { - get_command_encoder(index).add_temporary(std::move(arr)); -} - -void Device::add_temporaries(std::vector arrays, int index) { - get_command_encoder(index).add_temporaries(std::move(arrays)); -} - -void Device::end_encoding(int index) { - get_command_encoder(index).end_encoding(); -} - -CommandEncoder& Device::get_command_encoder(int index) { - auto it = encoders_.find(index); - if (it == encoders_.end()) { - it = encoders_.try_emplace(index, *this, index, residency_set_).first; - } - return it->second; -} - MTL::Library* Device::get_library( const std::string& name, const std::string& path /* = "" */) { @@ -793,21 +761,6 @@ 
MTL::ComputePipelineState* Device::get_kernel( linked_functions); } -void Device::set_residency_set(const MTL::ResidencySet* residency_set) { - if (residency_set_ != nullptr) { - throw std::runtime_error( - "[Device::set_residency_set] Can only be set once."); - } - if (residency_set == nullptr) { - return; - } - residency_set_ = residency_set; - // Attach residency set to existing command queues - for (auto& [_, encoder] : encoders_) { - encoder.get_command_queue()->addResidencySet(residency_set_); - } -} - Device& device(mlx::core::Device) { // Leak singleton device intentionally, to avoid cases where a compute kernel // returns and tries to access the object after it has been freed by the main @@ -816,6 +769,17 @@ Device& device(mlx::core::Device) { return *metal_device; } +CommandEncoder& get_command_encoder(Stream s) { + // Leak the command encoders for the same reason as the device. + static auto* encoders = new std::unordered_map; + auto it = encoders->find(s.index); + if (it == encoders->end()) { + auto& d = device(s.device); + it = encoders->try_emplace(s.index, d, s.index, d.residency_set()).first; + } + return it->second; +} + NS::SharedPtr new_scoped_memory_pool() { return NS::TransferPtr(NS::AutoreleasePool::alloc()->init()); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 4889887e2e..9e58b92f0c 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -11,6 +11,7 @@ #include #include "mlx/array.h" +#include "mlx/backend/metal/resident.h" #include "mlx/device.h" namespace mlx::core::metal { @@ -22,7 +23,7 @@ class Device; class MLX_API CommandEncoder { public: - CommandEncoder(Device& d, int index, const MTL::ResidencySet* residency_set); + CommandEncoder(Device& d, int index, ResidencySet& residency_set); CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; @@ -148,12 +149,6 @@ class MLX_API Device { return std::make_tuple(max_ops_per_buffer_, 
max_mb_per_buffer_); } - MTL::CommandBuffer* get_command_buffer(int index); - bool command_buffer_needs_commit(int index); - void commit_command_buffer(int index); - CommandEncoder& get_command_encoder(int index); - void end_encoding(int index); - MTL::Library* get_library( const std::string& name, const std::string& path = ""); @@ -177,11 +172,9 @@ class MLX_API Device { const MTLFCList& func_consts = {}, const std::vector& linked_functions = {}); - // Record temporary arrays for the given stream index - void add_temporary(array arr, int index); - void add_temporaries(std::vector arrays, int index); - - void set_residency_set(const MTL::ResidencySet* residency_set); + ResidencySet& residency_set() { + return residency_set_; + } private: NS::SharedPtr build_library_(const std::string& source_string); @@ -214,7 +207,7 @@ class MLX_API Device { const std::vector& linked_functions = {}); NS::SharedPtr device_; - std::unordered_map encoders_; + ResidencySet residency_set_; std::shared_mutex kernel_mtx_; std::shared_mutex library_mtx_; @@ -224,7 +217,6 @@ class MLX_API Device { MTL::Library*, std::unordered_map>> library_kernels_; - const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; int arch_gen_; int max_ops_per_buffer_; @@ -232,6 +224,7 @@ class MLX_API Device { }; MLX_API Device& device(mlx::core::Device); +MLX_API CommandEncoder& get_command_encoder(Stream s); NS::SharedPtr new_scoped_memory_pool(); diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 5790a21685..00ed754e5f 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -13,7 +13,7 @@ void init() {} void new_stream(Stream stream) { if (stream.device == mlx::core::Device::gpu) { - metal::device(stream.device).get_command_encoder(stream.index); + metal::get_command_encoder(stream); } } @@ -29,8 +29,8 @@ inline void check_error(MTL::CommandBuffer* cbuf) { void eval(array& arr) { auto pool = metal::new_scoped_memory_pool(); auto s = 
arr.primitive().stream(); - auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); + auto& encoder = metal::get_command_encoder(s); + auto* command_buffer = encoder.get_command_buffer(); auto outputs = arr.outputs(); { @@ -56,15 +56,15 @@ void eval(array& arr) { buffers.erase(it); } - if (d.command_buffer_needs_commit(s.index)) { - d.end_encoding(s.index); + if (encoder.needs_commit()) { + encoder.end_encoding(); scheduler::notify_new_task(s); command_buffer->addCompletedHandler( [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { scheduler::notify_task_completion(s); check_error(cbuf); }); - d.commit_command_buffer(s.index); + encoder.commit(); } else { command_buffer->addCompletedHandler( [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { @@ -75,20 +75,20 @@ void eval(array& arr) { void finalize(Stream s) { auto pool = metal::new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - d.end_encoding(s.index); + auto& encoder = metal::get_command_encoder(s); + auto* cb = encoder.get_command_buffer(); + encoder.end_encoding(); cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); - d.commit_command_buffer(s.index); + encoder.commit(); } void synchronize(Stream s) { auto pool = metal::new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); + auto& encoder = metal::get_command_encoder(s); + auto* cb = encoder.get_command_buffer(); cb->retain(); - d.end_encoding(s.index); - d.commit_command_buffer(s.index); + encoder.end_encoding(); + encoder.commit(); cb->waitUntilCompleted(); check_error(cb); cb->release(); diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 39a49230d2..78ed4fafe2 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -31,9 +31,9 @@ void Event::wait(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, 
[*this]() mutable { wait(); }); } else { - auto& d = metal::device(stream.device); - d.end_encoding(stream.index); - auto command_buffer = d.get_command_buffer(stream.index); + auto& encoder = metal::get_command_encoder(stream); + encoder.end_encoding(); + auto* command_buffer = encoder.get_command_buffer(); command_buffer->encodeWait(static_cast(event_.get()), value()); command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {}); } @@ -45,9 +45,9 @@ void Event::signal(Stream stream) { static_cast(event_.get())->setSignaledValue(value()); }); } else { - auto& d = metal::device(stream.device); - d.end_encoding(stream.index); - auto command_buffer = d.get_command_buffer(stream.index); + auto& encoder = metal::get_command_encoder(stream); + encoder.end_encoding(); + auto* command_buffer = encoder.get_command_buffer(); command_buffer->encodeSignalEvent( static_cast(event_.get()), value()); command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {}); diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index da73d5d913..0ff7e7f3b4 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -69,19 +69,17 @@ void Fence::wait(Stream stream, const array& x) { } auto& d = metal::device(stream.device); - auto idx = stream.index; + auto& compute_encoder = metal::get_command_encoder(stream); if (!f.use_fast) { - d.end_encoding(idx); - auto command_buffer = d.get_command_buffer(idx); + compute_encoder.end_encoding(); + auto* command_buffer = compute_encoder.get_command_buffer(); command_buffer->encodeWait(static_cast(f.fence), f.count); command_buffer->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); return; } - auto& compute_encoder = d.get_command_encoder(idx); - // Register outputs to ensure that no kernels which depends on the // output starts before this one is done compute_encoder.register_output_array(x); @@ -95,7 +93,7 @@ void Fence::wait(Stream stream, const array& x) { 
compute_encoder.set_bytes(f.count, 1); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); - d.get_command_buffer(idx)->addCompletedHandler( + compute_encoder.get_command_buffer()->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } @@ -117,10 +115,11 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { } auto& d = metal::device(stream.device); - auto idx = stream.index; + auto& compute_encoder = metal::get_command_encoder(stream); + if (!f.use_fast) { - d.end_encoding(idx); - auto command_buffer = d.get_command_buffer(idx); + compute_encoder.end_encoding(); + auto* command_buffer = compute_encoder.get_command_buffer(); command_buffer->encodeSignalEvent( static_cast(f.fence), f.count); command_buffer->addCompletedHandler( @@ -129,7 +128,6 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { } // Launch input visibility kernels - auto& compute_encoder = d.get_command_encoder(idx); if (cross_device) { auto kernel = d.get_kernel("input_coherent"); uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / @@ -155,7 +153,7 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { compute_encoder.set_bytes(f.count, 1); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); - d.get_command_buffer(idx)->addCompletedHandler( + compute_encoder.get_command_buffer()->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index d99e1badb3..61eb02dac9 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -515,6 +515,7 @@ void fft_op( bool inplace, const Stream& s) { auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); size_t n = out.dtype() == float32 ? 
out.shape(axis) : in.shape(axis); if (n == 1) { @@ -581,7 +582,7 @@ void fft_op( auto plan = plan_fft(n); if (plan.four_step) { four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); return; } @@ -654,7 +655,6 @@ void fft_op( // We can perform 2 RFFTs at once so the batch size is halved. batch_size = (batch_size + 2 - 1) / 2; } - auto& compute_encoder = d.get_command_encoder(s.index); auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? "float" : "float2"; // Only required by four step @@ -745,7 +745,7 @@ void fft_op( compute_encoder.dispatch_threads(grid_dims, group_dims); } - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } void fft_op( @@ -789,8 +789,7 @@ void nd_fft_op( fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); } - auto& d = metal::device(s.device); - d.add_temporaries(std::move(temp_arrs), s.index); + metal::get_command_encoder(s).add_temporaries(std::move(temp_arrs)); } void FFT::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 65a8771513..223c6f3f5f 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -132,7 +132,7 @@ void hadamard_mn_contiguous( // Launch the strided transform for n1 if (n1 > 1) { - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel("n1" + kname, lib); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); @@ -142,7 +142,7 @@ void hadamard_mn_contiguous( } // Launch the transform for n2 - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel("n2" + kname, lib); 
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(n1 > 1 ? y : x, 0); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 9c8a36392d..cdb4a03bb1 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -111,7 +111,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { return kernel_source; }); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); @@ -164,7 +164,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { return kernel_source; }); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); @@ -343,7 +343,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { return kernel_source; }); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kernel_name, lib); size_t nthreads = upd.size(); @@ -482,7 +482,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { return kernel_source; }); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); @@ -598,7 +598,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { return kernel_source; }); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); @@ -649,6 +649,7 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); 
auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); const size_t total = mask.size(); const CopyType ct = (total == 1) @@ -661,18 +662,18 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { array mask_flat = flatten_in_eval(mask, 1, -1, s); if (mask_flat.data() != mask.data()) { - d.add_temporary(mask_flat, s.index); + compute_encoder.add_temporary(mask_flat); } if (!mask_flat.flags().row_contiguous) { mask_flat = contiguous_copy_gpu(mask_flat, s); - d.add_temporary(mask_flat, s.index); + compute_encoder.add_temporary(mask_flat); } // Prefix (exclusive) of mask → scatter_offsets array scatter_offsets(mask_flat.shape(), uint32, nullptr, {}); scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes())); - d.add_temporary(scatter_offsets, s.index); + compute_encoder.add_temporary(scatter_offsets); scan_gpu_inplace( mask_flat, @@ -704,7 +705,6 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { // Binding int bind_idx = 0; const int ndim = static_cast(src.ndim()); - auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(mask_flat, bind_idx++); compute_encoder.set_input_array(scatter_offsets, bind_idx++); @@ -842,7 +842,7 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { return kernel_source; }); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp index 2cfdcdc8a0..8f7cbe3aff 100644 --- a/mlx/backend/metal/logsumexp.cpp +++ b/mlx/backend/metal/logsumexp.cpp @@ -19,14 +19,15 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { } auto& s = stream(); auto& d = metal::device(s.device); + auto& compute_encoder = 
metal::get_command_encoder(s); // Make sure that the last dimension is contiguous - auto ensure_contiguous = [&s, &d](const array& x) { + auto ensure_contiguous = [&](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { return x; } else { array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); + compute_encoder.add_temporary(x_copy); return x_copy; } }; @@ -66,7 +67,6 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { kernel_name += type_to_name(out); auto kernel = get_logsumexp_kernel(d, kernel_name, out); - auto& compute_encoder = d.get_command_encoder(s.index); { MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 84b6ee06da..df0065be55 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -44,7 +44,7 @@ inline array ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { if (!x.flags().row_contiguous) { array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); + metal::get_command_encoder(s).add_temporary(x_copy); return x_copy; } else { return x; @@ -75,7 +75,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { } array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); + metal::get_command_encoder(s).add_temporary(x_copy); return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); } @@ -254,7 +254,7 @@ void steel_matmul_regular_axpby_nax( std::string hash_name = kname.str(); // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_gemm_fused_nax_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, @@ -334,7 +334,7 @@ void steel_matmul_regular_axpby_nax( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies 
- d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } template @@ -442,7 +442,7 @@ void steel_matmul_regular_axpby( std::string hash_name = kname.str(); // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_gemm_fused_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, @@ -519,7 +519,7 @@ void steel_matmul_regular_axpby( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } /////////////////////////////////////////////////////////////////////////////// @@ -587,7 +587,7 @@ void steel_gemm_splitk_axpby( << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_gemm_splitk_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ kname.str(), @@ -676,7 +676,7 @@ void steel_gemm_splitk_axpby( compute_encoder.dispatch_threads(grid_dims, group_dims); } - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } /////////////////////////////////////////////////////////////////////////////// @@ -748,7 +748,7 @@ void steel_gemm_splitk_axpby_nax( std::string hash_name = kname.str(); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_gemm_splitk_nax_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, @@ -848,7 +848,7 @@ void steel_gemm_splitk_axpby_nax( compute_encoder.dispatch_threads(grid_dims, group_dims); } - d.add_temporaries(std::move(copies), s.index); + 
compute_encoder.add_temporaries(std::move(copies)); } /////////////////////////////////////////////////////////////////////////////// @@ -1131,7 +1131,7 @@ void gemv_axbpy( << "_axpby" << do_axpby; // clang-format on // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); @@ -1166,7 +1166,7 @@ void gemv_axbpy( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } inline void gemv( @@ -1219,6 +1219,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { } auto& s = stream(); auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; @@ -1226,7 +1227,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); - d.add_temporary(std::move(zero), s.index); + compute_encoder.add_temporary(std::move(zero)); return; } @@ -1331,6 +1332,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); // Handle empty matrix case (K=0) if (inputs[0].shape(-1) == 0) { @@ -1344,7 +1346,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { } else { array beta_scalar = array(beta_, c.dtype()); binary_op_gpu({c, beta_scalar}, out, "Multiply", s); - d.add_temporary(std::move(beta_scalar), s.index); + compute_encoder.add_temporary(std::move(beta_scalar)); } return; } @@ -1464,6 +1466,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { } auto& s = stream(); auto& d = metal::device(s.device); + auto& compute_encoder = 
metal::get_command_encoder(s); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; @@ -1471,7 +1474,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); - d.add_temporary(std::move(zero), s.index); + compute_encoder.add_temporary(std::move(zero)); return; } @@ -1642,7 +1645,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { tn, contiguous_kernel); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; @@ -1724,7 +1727,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); return; } @@ -1747,7 +1750,6 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_masked_kernel( d, kname.str(), @@ -1834,7 +1836,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } /////////////////////////////////////////////////////////////////////////////// @@ -1925,7 +1927,7 @@ void gather_mm_rhs( align_K ? 't' : 'n'); // Get and set the kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_gemm_gather_kernel( d, base_name, @@ -2068,7 +2070,7 @@ void gather_mm_rhs_nax( align_K ? 
't' : 'n'); // Get and set the kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_gemm_gather_nax_kernel( d, base_name, @@ -2128,12 +2130,14 @@ void gather_mv( bool is_mv, metal::Device& d, const Stream& s) { + auto& compute_encoder = metal::get_command_encoder(s); + // Copy if needed std::vector copies; auto [transpose_mat, mat_cols, mat] = check_transpose(copies, s, mat_, N == 1); auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); // If we are doing vector matrix instead of matrix vector we need to flip the // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated @@ -2200,7 +2204,6 @@ void gather_mv( << tm << "_tn" << tn; // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); @@ -2245,11 +2248,13 @@ void gather_mm( int K, metal::Device& d, const Stream& s) { + auto& compute_encoder = metal::get_command_encoder(s); + // Copy if needed std::vector copies; auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; @@ -2313,7 +2318,6 @@ void gather_mm( align_K ? 
't' : 'n'); // Get and set the kernel - auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_gather_kernel( d, base_name, @@ -2375,6 +2379,7 @@ void gather_mm( void GatherMM::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); auto& a = inputs[0]; auto& b = inputs[1]; @@ -2385,7 +2390,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { if (a.size() == 0 || b.size() == 0) { array zero = array(0, a.dtype()); fill_gpu(zero, out, s); - d.add_temporary(std::move(zero), s.index); + compute_encoder.add_temporary(std::move(zero)); return; } @@ -2431,7 +2436,9 @@ void segmented_mm( int K, metal::Device& d, const Stream& s) { - auto check_segments_layout = [&d, &s](const array& x) { + auto& compute_encoder = metal::get_command_encoder(s); + + auto check_segments_layout = [&](const array& x) { // Contiguous so return early if (x.flags().row_contiguous) { return std::make_tuple(true, x); @@ -2452,7 +2459,7 @@ void segmented_mm( } array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); + compute_encoder.add_temporary(x_copy); return std::make_tuple(true, x_copy); }; @@ -2461,7 +2468,7 @@ void segmented_mm( auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); auto [segments_contiguous, segments] = check_segments_layout(segments_); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; @@ -2517,7 +2524,6 @@ void segmented_mm( align_N ? 
't' : 'n'); // Get and set the kernel - auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_segmented_kernel( d, base_name, diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 1dc0338f52..9a222cdd6c 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -60,7 +60,7 @@ void RMSNorm::eval_gpu( op_name += "_looped"; } op_name += type_to_name(out); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); { auto kernel = d.get_kernel(op_name); @@ -97,6 +97,7 @@ void RMSNormVJP::eval_gpu( std::vector& outputs) { auto& s = stream(); auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the @@ -130,7 +131,7 @@ void RMSNormVJP::eval_gpu( gx.set_data(allocator::malloc(gx.nbytes())); } if (g_copied && !g_in_gx) { - d.add_temporary(g, s.index); + compute_encoder.add_temporary(g); } auto axis_size = static_cast(x.shape().back()); @@ -145,7 +146,7 @@ void RMSNormVJP::eval_gpu( gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); - d.add_temporary(gw_temp, s.index); + compute_encoder.add_temporary(gw_temp); } } gw.set_data(allocator::malloc(gw.nbytes())); @@ -164,7 +165,6 @@ void RMSNormVJP::eval_gpu( {&has_w, MTL::DataType::DataTypeBool, 20}, }; - auto& compute_encoder = d.get_command_encoder(s.index); { auto kernel = d.get_kernel(op_name, hash_name, func_consts); @@ -257,7 +257,7 @@ void LayerNorm::eval_gpu( n_reads = 4; } op_name += type_to_name(out); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); { auto kernel = d.get_kernel(op_name); @@ -303,6 +303,7 @@ void LayerNormVJP::eval_gpu( std::vector& outputs) { auto& s 
= stream(); auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the @@ -339,7 +340,7 @@ void LayerNormVJP::eval_gpu( gx.set_data(allocator::malloc(gx.nbytes())); } if (g_copied && !g_in_gx) { - d.add_temporary(g, s.index); + compute_encoder.add_temporary(g); } auto axis_size = static_cast(x.shape().back()); @@ -354,14 +355,13 @@ void LayerNormVJP::eval_gpu( gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); - d.add_temporary(gw_temp, s.index); + compute_encoder.add_temporary(gw_temp); } } gw.set_data(allocator::malloc(gw.nbytes())); gb.set_data(allocator::malloc(gb.nbytes())); // Finish with the gradient for b in case we had a b - auto& compute_encoder = d.get_command_encoder(s.index); if (gb.ndim() == 1 && gb.size() == axis_size) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 5f6376c5e8..d5bbf797e4 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -37,7 +37,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); MTL::Size group_dims = MTL::Size( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); switch (out.dtype()) { @@ -116,7 +116,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { // ArgReduce int simd_size = 32; int n_reads = 4; - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); { auto kernel = d.get_kernel(op_name + type_to_name(in)); 
NS::UInteger thread_group_size = std::min( @@ -183,7 +183,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { // organize into grid nkeys x elem_per_key MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); auto group_dims = get_block_dims(num_keys, half_size + odd, 1); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(keys, 0); compute_encoder.set_output_array(out, 1); diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index e67001e066..c8d5a31cb4 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -54,7 +54,7 @@ inline array ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { if (!x.flags().row_contiguous) { array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); + metal::get_command_encoder(s).add_temporary(x_copy); return x_copy; } else { return x; @@ -77,7 +77,7 @@ inline array ensure_row_contiguous_matrix( } } array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); + metal::get_command_encoder(s).add_temporary(x_copy); return x_copy; } @@ -214,7 +214,7 @@ void qmv_quad( B > 1 ? 
"_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -277,7 +277,7 @@ void qmv( bits, B > 1); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -309,6 +309,8 @@ void qvm_split_k( metal::Device& d, const Stream& s, const std::string& mode) { + auto& compute_encoder = metal::get_command_encoder(s); + int split_k = K > 8192 ? 32 : 8; int split_D = (K + split_k - 1) / split_k; int B = out.size() / M / N; @@ -356,7 +358,7 @@ void qvm_split_k( temp_shape.insert(temp_shape.end() - 2, split_k); array intermediate(temp_shape, x.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); - d.add_temporary(intermediate, s.index); + compute_encoder.add_temporary(intermediate); std::string type_string = get_type_string(x.dtype()); std::string kname; @@ -376,7 +378,6 @@ void qvm_split_k( auto kernel = get_quantized_kernel_wrapped( d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k); - auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -451,7 +452,7 @@ void qvm( B > 1 ? 
"_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( d, kname, "qvm", mode, type_string, group_size, bits, B > 1); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -553,7 +554,7 @@ void qmm_nax( wm, wn); } - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -654,7 +655,7 @@ void gather_qmm_nax( wn); } - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -751,7 +752,7 @@ void qmm( kernel = get_quantized_kernel_wrapped( d, kname, "qmm_n", mode, type_string, group_size, bits, batched); } - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -808,6 +809,7 @@ void qmm_splitk( // Allocate intermediate buffer: insert split_k at the front so that // partition_stride = M * N matches the leading stride of the buffer. 
+ auto& compute_encoder = metal::get_command_encoder(s); auto temp_shape = out.shape(); if (temp_shape.size() == 1) { temp_shape.insert(temp_shape.begin(), 1); @@ -815,7 +817,7 @@ void qmm_splitk( temp_shape.insert(temp_shape.begin(), split_k); array intermediate(temp_shape, x.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); - d.add_temporary(intermediate, s.index); + compute_encoder.add_temporary(intermediate); // Grid: (N_tiles, M_tiles, split_k) MTL::Size group_dims(32, 2, 2); @@ -837,7 +839,6 @@ void qmm_splitk( auto kernel = get_quantized_kernel_wrapped( d, kname, "qmm_t_splitk", mode, type_string, group_size, bits, aligned); - auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -934,7 +935,7 @@ void gather_qmm( d, kname, "gather_qmm_n", mode, type_string, group_size, bits); } - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -1001,7 +1002,7 @@ void gather_qmv( group_size, bits); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -1059,7 +1060,7 @@ void gather_qvm( bits); auto kernel = get_quantized_kernel_wrapped( d, kname, "gather_qvm", mode, type_string, group_size, bits); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; @@ -1173,7 +1174,7 @@ void gather_qmm_rhs_nax( align_K ? 't' : 'n'); // Get and set the kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_gather_qmm_nax_kernel( d, kname, @@ -1323,7 +1324,7 @@ void gather_qmm_rhs( align_K ? 
't' : 'n'); // Get and set the kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_gather_qmm_kernel( d, kname, @@ -1568,7 +1569,7 @@ void quantize_dequantize( int bits, metal::Device& d, const Stream& s) { - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto w = ensure_row_contiguous(in, d, s); compute_encoder.set_input_array(w, 0); @@ -1662,7 +1663,7 @@ void fast::Quantize::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto w = ensure_row_contiguous(w_pre, d, s); if (dequantize_) { diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 504943d823..644af5d218 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -358,7 +358,7 @@ void all_reduce_dispatch( // Allocate an intermediate tensor to hold results if needed array intermediate({n_rows}, out_type, nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); - d.add_temporary(intermediate, s.index); + compute_encoder.add_temporary(intermediate); // 1st pass size_t row_size = (in_size + n_rows - 1) / n_rows; @@ -652,7 +652,7 @@ void strided_reduce_longcolumn( intermediate_shape.end(), out.shape().begin(), out.shape().end()); array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); - d.add_temporary(intermediate, s.index); + compute_encoder.add_temporary(intermediate); // Prepare the arguments for the kernel args.reduce_shape.push_back(args.reduction_size); @@ -823,7 +823,7 @@ void strided_reduce_2pass( intermediate_shape.end(), out.shape().begin(), out.shape().end()); array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); 
intermediate.set_data(allocator::malloc(intermediate.nbytes())); - d.add_temporary(intermediate, s.index); + compute_encoder.add_temporary(intermediate); // Prepare the arguments for the kernel args.reduce_shape.push_back(args.reduction_size); @@ -986,7 +986,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Initialize output auto& s = stream(); auto& d = metal::device(s.device); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); // Reduce if (in.size() > 0) { @@ -1000,7 +1000,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // stride. if (plan.type == GeneralReduce) { array in_copy = contiguous_copy_gpu(in, s); - d.add_temporary(in_copy, s.index); + compute_encoder.add_temporary(in_copy); in = in_copy; plan = get_reduction_plan(in, axes_); } diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 97187b05b9..80de4f5792 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. 
#include "mlx/backend/metal/resident.h" +#include "mlx/backend/metal/device.h" namespace mlx::core::metal { diff --git a/mlx/backend/metal/resident.h b/mlx/backend/metal/resident.h index 9961d722d7..50b1bd03d4 100644 --- a/mlx/backend/metal/resident.h +++ b/mlx/backend/metal/resident.h @@ -2,7 +2,9 @@ #pragma once -#include "mlx/backend/metal/device.h" +#include + +#include namespace mlx::core::metal { diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index ca0a662212..2190b7dc10 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -118,7 +118,7 @@ void RoPE::eval_gpu( {&head_seq_transpose, MTL::DataType::DataTypeBool, 3}}; auto kernel = d.get_kernel(kname, hash_name, func_consts); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); float base = std::log2(base_); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 37e554f183..c79cd51ff0 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -89,7 +89,7 @@ void sdpa_full_self_attention_nax( "_has_sinks_", (has_sinks ? 't' : 'n')); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_attention_nax_kernel( d, @@ -252,7 +252,7 @@ void sdpa_full_self_attention_metal( "_has_sinks_", (has_sinks ? 't' : 'n')); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = get_steel_attention_kernel( d, @@ -378,7 +378,7 @@ void sdpa_vector( hash_name += has_sinks ? 
"_sinks" : "_nosinks"; // Get the kernel - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); @@ -495,9 +495,10 @@ void sdpa_vector_2pass( intermediate.set_data(allocator::malloc(intermediate.nbytes())); sums.set_data(allocator::malloc(sums.nbytes())); maxs.set_data(allocator::malloc(maxs.nbytes())); - d.add_temporary(intermediate, s.index); - d.add_temporary(sums, s.index); - d.add_temporary(maxs, s.index); + auto& compute_encoder = metal::get_command_encoder(s); + compute_encoder.add_temporary(intermediate); + compute_encoder.add_temporary(sums); + compute_encoder.add_temporary(maxs); bool has_mask = mask.has_value(); bool bool_mask = has_mask && (*mask).dtype() == bool_; @@ -521,7 +522,6 @@ void sdpa_vector_2pass( hash_name += std::to_string(blocks); // Get the kernel - auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname, hash_name, func_consts); check_kernel_threadgroup_size(kernel, group_dims, hash_name); @@ -782,7 +782,7 @@ void ScaledDotProductAttention::eval_gpu( s, d, q, k, v, scale_, o, do_causal_, mask, sinks); } - d.add_temporaries(std::move(copies), s.index); + metal::get_command_encoder(s).add_temporaries(std::move(copies)); } bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index 5d26981334..ede0306c06 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -60,7 +60,7 @@ void scan_gpu_inplace( get_scan_kernel(d, kname, reverse, inclusive, reduce_type_str, in, out); if (contiguous) { - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); @@ 
-89,7 +89,7 @@ void scan_gpu_inplace( MTL::Size group_dims(thread_group_size, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 1e14c35c8a..e92aef43db 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -31,7 +31,7 @@ void concatenate_gpu( flags.col_contiguous = false; flags.contiguous = false; auto& d = metal::device(s.device); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); auto concurrent_ctx = compute_encoder.start_concurrent(); for (int i = 0; i < inputs.size(); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); @@ -48,6 +48,7 @@ array compute_dynamic_offset( const std::vector& axes, const Stream& s) { auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); // Kernel to compute offset here. 
array offset({1}, int64, nullptr, {}); @@ -58,7 +59,7 @@ array compute_dynamic_offset( } else { offset.set_data(allocator::malloc(offset.itemsize())); } - d.add_temporary(offset, s.index); + compute_encoder.add_temporary(offset); auto dtype = indices.dtype(); std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); @@ -83,7 +84,6 @@ array compute_dynamic_offset( }); auto kernel = d.get_kernel(lib_name, lib); - auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(indices, 0); compute_encoder.set_output_array(offset, 1); diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 0b1a1848df..f455cf2d1f 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -58,7 +58,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { kernel_name += type_to_name(out); auto kernel = get_softmax_kernel(d, kernel_name, precise_, out); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); { MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 3c84022f2c..ec965e2fb9 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -63,7 +63,7 @@ void single_block_sort( auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn); // Prepare command encoder - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); // Set inputs @@ -160,7 +160,7 @@ void multi_block_sort( dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions}; // Prepare command encoder - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); // Do blockwise sort { @@ -268,7 +268,7 @@ void multi_block_sort( (axis == in.ndim() - 
1) ? CopyType::Vector : CopyType::General, s); - d.add_temporaries(std::move(copies), s.index); + compute_encoder.add_temporaries(std::move(copies)); } void gpu_merge_sort( diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 252815aae5..9bf07f31b3 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -78,7 +78,7 @@ void ternary_op_gpu_inplace( auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 833b23f632..354f199c68 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -57,7 +57,7 @@ void unary_op_gpu_inplace( auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1);