From 0d67d8454980f8b5ab11fb830803098a53694c13 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 16 Mar 2026 11:22:29 +0900 Subject: [PATCH] Merge DeviceStream into CommandEncoder --- mlx/backend/metal/device.cpp | 286 ++++++++++++++++++----------------- mlx/backend/metal/device.h | 125 ++++++--------- mlx/backend/metal/eval.cpp | 4 +- 3 files changed, 201 insertions(+), 214 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e65bfb41ea..aad42f4471 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -13,6 +13,18 @@ #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" +namespace std { + +// Required for putting the pointer in unordered_set. +template +struct hash> { + size_t operator()(const NS::SharedPtr& p) const { + return std::hash{}(p.get()); + } +}; + +} // namespace std + namespace mlx::core::metal { namespace { @@ -233,14 +245,22 @@ MTL::Library* load_library( } // namespace -CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) { - enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent); - enc_->retain(); -} - -CommandEncoder::~CommandEncoder() { - enc_->endEncoding(); - enc_->release(); +CommandEncoder::CommandEncoder( + Device& d, + int index, + const MTL::ResidencySet* residency_set) + : device_(d) { + auto pool = new_scoped_memory_pool(); + queue_ = NS::TransferPtr(device_.mtl_device()->newCommandQueue()); + if (!queue_) { + throw std::runtime_error( + "[metal::CommandEncoder] Failed to make new command queue."); + } + if (residency_set) { + queue_->addResidencySet(residency_set); + } + debug_set_stream_queue_label(queue_.get(), index); + buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); } void CommandEncoder::set_buffer( @@ -251,7 +271,7 @@ void CommandEncoder::set_buffer( // buffers all_inputs_.insert((void*)buf); all_outputs_.insert((void*)buf); - enc_->setBuffer(buf, offset, idx); + get_command_encoder()->setBuffer(buf, offset, idx); } void CommandEncoder::set_input_array( @@ -259,13 +279,13 @@ void CommandEncoder::set_input_array( int idx, int64_t offset /* = 0 */) { if (all_inputs_.insert(a.buffer().ptr()).second) { - stream_.buffer_sizes += a.data_size(); + buffer_sizes_ += a.data_size(); } auto r_buf = static_cast(const_cast(a.buffer().ptr())); needs_barrier_ = needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); auto a_buf = static_cast(a.buffer().ptr()); - enc_->setBuffer(a_buf, a.offset() + offset, idx); + get_command_encoder()->setBuffer(a_buf, a.offset() + offset, idx); } void CommandEncoder::set_output_array( @@ -288,9 +308,20 @@ void CommandEncoder::register_output_array(const array& a) { } } +void CommandEncoder::add_temporary(array arr) { + temporaries_.push_back(std::move(arr)); +} + +void CommandEncoder::add_temporaries(std::vector arrays) { + temporaries_.insert( + temporaries_.end(), + std::make_move_iterator(arrays.begin()), + std::make_move_iterator(arrays.end())); +} + void CommandEncoder::maybeInsertBarrier() { if (needs_barrier_) { - enc_->memoryBarrier(MTL::BarrierScopeBuffers); + get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers); needs_barrier_ = false; prev_outputs_ = std::move(next_outputs_); } else { @@ -303,20 +334,110 @@ void CommandEncoder::dispatch_threadgroups( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); - stream_.buffer_ops++; - enc_->dispatchThreadgroups(grid_dims, group_dims); + buffer_ops_++; + get_command_encoder()->dispatchThreadgroups(grid_dims, group_dims); } void CommandEncoder::dispatch_threads( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); - stream_.buffer_ops++; - enc_->dispatchThreads(grid_dims, group_dims); + buffer_ops_++; + get_command_encoder()->dispatchThreads(grid_dims, group_dims); } void CommandEncoder::barrier() { - enc_->memoryBarrier(MTL::BarrierScopeBuffers); + get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers); +} + +void CommandEncoder::end_encoding() { + // Each command encoder has a unique fence. We also store a map of + // all previous outputs of command encoders to their corresponding fence. + // - The command encoder records its inputs and outputs. + // - Wait on a fence if any inputs in the encoder are outputs of a previous + // encoder. + // - Update the map of outputs to include this command encoder's outputs. + // - Always signal this command encoders fence. + // - Add a completion handler for this command encoder that removes outputs + // from the map to limit the growth of the map and avoid unnecessary waits + // - Temporaries are a special case as they do not cross command encoder + // boundaries. These can be removed early from the encoders inputs and + // outputs since they don't need synchronization. + if (!encoder_) { + return; + } + + // Remove temporaries from inputs and outputs. + for (auto& t : temporaries_) { + all_outputs_.erase(t.buffer().ptr()); + all_inputs_.erase(t.buffer().ptr()); + } + + // Keep references to the fences we waited on and put them in the completion + // handler so they are not prematurely released. + std::unordered_set> waiting_on; + { + std::lock_guard lk(outputs_mtx_); + for (auto& in : all_inputs_) { + if (auto it = prev_ce_outputs_.find(in); it != prev_ce_outputs_.end()) { + // If we've already waited on a fence, don't wait on it again. + if (waiting_on.find(it->second) == waiting_on.end()) { + encoder_->waitForFence(it->second.get()); + waiting_on.insert(it->second); + } + } + } + for (auto& out : all_outputs_) { + prev_ce_outputs_[out] = fence_; + } + } + + encoder_->updateFence(fence_.get()); + buffer_->addCompletedHandler([this, + fence = std::move(fence_), + temporaries = std::move(temporaries_), + all_outputs = std::move(all_outputs_), + waiting_on = std::move(waiting_on)]( + MTL::CommandBuffer*) mutable { + std::lock_guard lk(outputs_mtx_); + for (auto& o : all_outputs) { + if (auto it = prev_ce_outputs_.find(o); it != prev_ce_outputs_.end()) { + if (it->second == fence) { + prev_ce_outputs_.erase(it); + } + } + } + }); + + encoder_->endEncoding(); + encoder_.reset(); + needs_barrier_ = false; + concurrent_ = false; + prev_outputs_.clear(); + next_outputs_.clear(); + concurrent_outputs_.clear(); + all_inputs_.clear(); +} + +bool CommandEncoder::needs_commit() const { + auto [max_ops, max_mb] = device_.get_max_ops_mb_per_buffer(); + return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb); +} + +void CommandEncoder::commit() { + buffer_->commit(); + buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); + buffer_ops_ = 0; + buffer_sizes_ = 0; +} + +MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { + if (!encoder_) { + encoder_ = NS::RetainPtr( + buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent)); + fence_ = NS::TransferPtr(device_.mtl_device()->newFence()); + } + return encoder_.get(); } Device::Device() { @@ -371,145 +492,40 @@ Device::~Device() { k->release(); } } - stream_map_.clear(); + encoders_.clear(); device_->release(); } -void Device::new_queue(int index) { - auto thread_pool = metal::new_scoped_memory_pool(); - auto q = device_->newCommandQueue(); - debug_set_stream_queue_label(q, index); - if (!q) { - throw std::runtime_error( - "[metal::Device] Failed to make new command queue."); - } - stream_map_.emplace(index, q); - if (residency_set_ != nullptr) { - q->addResidencySet(residency_set_); - } -} - -MTL::CommandQueue* Device::get_queue(Stream stream) { - return get_stream_(stream.index).queue; -} - bool Device::command_buffer_needs_commit(int index) { - auto& stream = get_stream_(index); - return (stream.buffer_ops > max_ops_per_buffer_) || - ((stream.buffer_sizes >> 20) > max_mb_per_buffer_); + return get_command_encoder(index).needs_commit(); } MTL::CommandBuffer* Device::get_command_buffer(int index) { - auto& stream = get_stream_(index); - if (stream.buffer == nullptr) { - stream.buffer = stream.queue->commandBufferWithUnretainedReferences(); - if (!stream.buffer) { - throw std::runtime_error( - "[metal::Device] Unable to create new command buffer"); - } - // Increment ref count so the buffer is not garbage collected - stream.buffer->retain(); - } - return stream.buffer; + return get_command_encoder(index).get_command_buffer(); } void Device::commit_command_buffer(int index) { - auto& stream = get_stream_(index); - stream.buffer->commit(); - stream.buffer->release(); - stream.buffer = nullptr; - stream.buffer_ops = 0; - stream.buffer_sizes = 0; + get_command_encoder(index).commit(); } void Device::add_temporary(array arr, int index) { - get_stream_(index).temporaries.push_back(std::move(arr)); + get_command_encoder(index).add_temporary(std::move(arr)); } void Device::add_temporaries(std::vector arrays, int index) { - if (arrays.empty()) { - return; - } - auto& stream = get_stream_(index); - stream.temporaries.insert( - stream.temporaries.end(), - std::make_move_iterator(arrays.begin()), - std::make_move_iterator(arrays.end())); + get_command_encoder(index).add_temporaries(std::move(arrays)); } void Device::end_encoding(int index) { - auto& stream = get_stream_(index); - if (stream.encoder != nullptr) { - // Each command encoder has a unique fence. We also store a map of - // all previous outputs of command encoders to their corresponding fence. - // - The command encoder records its inputs and outputs. - // - Wait on a fence if any inputs in the encoder are outputs of a previous - // encoder. - // - Update the map of outputs to include this command encoder's outputs. - // - Always signal this command encoders fence. - // - Add a completion handler for this command encoder that removes outputs - // from the map to limit the growth of the map and avoid unnecessary waits - // - Temporaries are a special case as they do not cross command encoder - // boundaries. These can be removed early from the encoders inputs and - // outputs since they don't need synchronization. - auto& enc = *stream.encoder; - // Remove temporaries from inputs and outputs - for (auto& t : stream.temporaries) { - enc.outputs().erase(t.buffer().ptr()); - enc.inputs().erase(t.buffer().ptr()); - } - - // Keep references to the fences we waited on and put them - // in the completion handler so they are not prematurely released - std::unordered_set> waiting_on; - { - std::lock_guard lk(stream.fence_mtx); - for (auto in : enc.inputs()) { - if (auto it = stream.outputs.find(in); it != stream.outputs.end()) { - // If we've already waited on a fence, don't wait on it again. - if (waiting_on.find(it->second) == waiting_on.end()) { - enc.wait_for_fence(it->second->fence); - waiting_on.insert(it->second); - } - } - } - for (auto out : enc.outputs()) { - stream.outputs[out] = stream.fence; - } - } - enc.update_fence(stream.fence->fence); - stream.buffer->addCompletedHandler( - [&stream, - waiting_on = std::move(waiting_on), - fence = std::move(stream.fence), - outputs = std::move(enc.outputs()), - temporaries = - std::move(stream.temporaries)](MTL::CommandBuffer*) mutable { - temporaries.clear(); - std::lock_guard lk(stream.fence_mtx); - for (auto o : outputs) { - if (auto it = stream.outputs.find(o); it != stream.outputs.end()) { - if (it->second == fence) { - stream.outputs.erase(it); - } - } - } - }); - } - stream.encoder = nullptr; + get_command_encoder(index).end_encoding(); } CommandEncoder& Device::get_command_encoder(int index) { - auto& stream = get_stream_(index); - if (stream.encoder == nullptr) { - // Ensure there is an active command buffer - if (stream.buffer == nullptr) { - get_command_buffer(index); - } - stream.encoder = std::make_unique(stream); - stream.fence = std::make_shared(device_->newFence()); + auto it = encoders_.find(index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(index, *this, index, residency_set_).first; } - return *stream.encoder; + return it->second; } MTL::Library* Device::get_library( @@ -808,8 +824,8 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) { } residency_set_ = residency_set; // Attach residency set to existing command queues - for (auto& [_, stream] : stream_map_) { - stream.queue->addResidencySet(residency_set_); + for (auto& [_, encoder] : encoders_) { + encoder.get_command_queue()->addResidencySet(residency_set_); } } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index e6162d7da6..0e705a6560 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -18,10 +18,11 @@ namespace mlx::core::metal { using MTLFCList = std::vector>; -struct DeviceStream; +class Device; -struct MLX_API CommandEncoder { - explicit CommandEncoder(DeviceStream& stream); +class MLX_API CommandEncoder { + public: + CommandEncoder(Device& d, int index, const MTL::ResidencySet* residency_set); CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; @@ -40,29 +41,26 @@ struct MLX_API CommandEncoder { CommandEncoder& enc; }; + void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0); void set_input_array(const array& a, int idx, int64_t offset = 0); void set_output_array(array& a, int idx, int64_t offset = 0); void register_output_array(const array& a); + + void add_temporary(array arr); + void add_temporaries(std::vector arrays); + void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims); void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims); void maybeInsertBarrier(); - void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0); void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) { - enc_->setComputePipelineState(kernel); - } - - void wait_for_fence(MTL::Fence* fence) { - enc_->waitForFence(fence); - } - - void update_fence(MTL::Fence* fence) { - enc_->updateFence(fence); + get_command_encoder()->setComputePipelineState(kernel); } template >> void set_vector_bytes(const Vec& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx); + get_command_encoder()->setBytes( + vec.data(), nelems * sizeof(typename Vec::value_type), idx); } template >> void set_vector_bytes(const Vec& vec, int idx) { @@ -71,79 +69,62 @@ struct MLX_API CommandEncoder { template void set_bytes(const T* v, int n, int idx) { - return enc_->setBytes(v, n * sizeof(T), idx); + return get_command_encoder()->setBytes(v, n * sizeof(T), idx); } template void set_bytes(const T& v, int idx) { - return enc_->setBytes(&v, sizeof(T), idx); + return get_command_encoder()->setBytes(&v, sizeof(T), idx); } void set_threadgroup_memory_length(size_t length, int idx) { - enc_->setThreadgroupMemoryLength(length, idx); + get_command_encoder()->setThreadgroupMemoryLength(length, idx); } ConcurrentContext start_concurrent() { return ConcurrentContext(*this); } - ~CommandEncoder(); - - // Inputs to all kernels in the encoder including temporaries - std::unordered_set& inputs() { - return all_inputs_; - }; - - // Outputs of all kernels in the encoder including temporaries - std::unordered_set& outputs() { - return all_outputs_; - }; void barrier(); + void end_encoding(); + bool needs_commit() const; + void commit(); + + MTL::CommandQueue* get_command_queue() const { + return queue_.get(); + } + MTL::CommandBuffer* get_command_buffer() const { + return buffer_.get(); + } private: - DeviceStream& stream_; - MTL::ComputeCommandEncoder* enc_; + MTL::ComputeCommandEncoder* get_command_encoder(); + + Device& device_; + + // Buffer that stores encoded commands. + NS::SharedPtr queue_; + NS::SharedPtr buffer_; + int buffer_ops_{0}; + size_t buffer_sizes_{0}; + + // Encoder for issuing GPU commands. + // The members are used within a single ComputeCommandEncoder and will be + // reset after calling end_encoding(). + NS::SharedPtr encoder_; + NS::SharedPtr fence_; bool needs_barrier_{false}; bool concurrent_{false}; + std::vector temporaries_; std::unordered_set prev_outputs_; std::unordered_set next_outputs_; std::unordered_set concurrent_outputs_; std::unordered_set all_inputs_; std::unordered_set all_outputs_; -}; -struct Fence { - Fence(MTL::Fence* fence) : fence(fence) {} - ~Fence() { - fence->release(); - } - MTL::Fence* fence; -}; - -struct DeviceStream { - DeviceStream(MTL::CommandQueue* queue) : queue(queue) {}; - ~DeviceStream() { - queue->release(); - if (buffer != nullptr) { - buffer->release(); - } - }; - MTL::CommandQueue* queue; - // A map of prior command encoder outputs to their corresponding fence - std::unordered_map> outputs; - // Used to allow thread-safe access to the outputs map - std::mutex fence_mtx; - - // Data updated between command buffers - MTL::CommandBuffer* buffer{nullptr}; - int buffer_ops{0}; - size_t buffer_sizes{0}; - - // The command encoder, fence, and temporaries are updated between command - // encoders - std::unique_ptr encoder{nullptr}; - std::shared_ptr fence; - std::vector temporaries; + // A map of prior command encoder outputs to their corresponding fence. + std::unordered_map> prev_ce_outputs_; + std::mutex outputs_mtx_; }; class MLX_API Device { @@ -157,17 +138,15 @@ class MLX_API Device { return device_; }; - const std::string& get_architecture() { + const std::string& get_architecture() const { return arch_; } - int get_architecture_gen() const { return arch_gen_; } - - void new_queue(int index); - - MTL::CommandQueue* get_queue(Stream stream); + std::tuple get_max_ops_mb_per_buffer() const { + return std::make_tuple(max_ops_per_buffer_, max_mb_per_buffer_); + } MTL::CommandBuffer* get_command_buffer(int index); bool command_buffer_needs_commit(int index); @@ -198,9 +177,6 @@ class MLX_API Device { const MTLFCList& func_consts = {}, const std::vector& linked_functions = {}); - MTL::ArgumentEncoder* argument_encoder( - const std::vector& arg_descs) const; - // Record temporary arrays for the given stream index void add_temporary(array arr, int index); void add_temporaries(std::vector arrays, int index); @@ -208,9 +184,6 @@ class MLX_API Device { void set_residency_set(const MTL::ResidencySet* residency_set); private: - DeviceStream& get_stream_(int index) { - return stream_map_.find(index)->second; - } MTL::Library* get_library_cache_(const std::string& name); MTL::Library* get_library_(const std::string& name); @@ -244,7 +217,7 @@ class MLX_API Device { const std::vector& linked_functions = {}); MTL::Device* device_; - std::unordered_map stream_map_; + std::unordered_map encoders_; std::shared_mutex kernel_mtx_; std::shared_mutex library_mtx_; diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index bd58a691a9..123fd74057 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -11,7 +11,7 @@ namespace mlx::core::gpu { void new_stream(Stream stream) { if (stream.device == mlx::core::Device::gpu) { - metal::device(stream.device).new_queue(stream.index); + metal::device(stream.device).get_command_encoder(stream.index); } } @@ -63,7 +63,6 @@ void eval(array& arr) { check_error(cbuf); }); d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); } else { command_buffer->addCompletedHandler( [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { @@ -79,7 +78,6 @@ void finalize(Stream s) { d.end_encoding(s.index); cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); } void synchronize(Stream s) {