Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 151 additions & 135 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"

namespace std {

// Hash an NS::SharedPtr by the raw pointer it manages, delegating to
// std::hash<T*>. Required so SharedPtr values can be used as keys in
// unordered associative containers (e.g. the unordered_set of fences
// kept alive by command-buffer completion handlers).
template <class T>
struct hash<NS::SharedPtr<T>> {
  size_t operator()(const NS::SharedPtr<T>& ptr) const {
    T* raw = ptr.get();
    return hash<T*>{}(raw);
  }
};

} // namespace std

namespace mlx::core::metal {

namespace {
Expand Down Expand Up @@ -233,14 +245,22 @@ MTL::Library* load_library(

} // namespace

CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc_->retain();
}

CommandEncoder::~CommandEncoder() {
enc_->endEncoding();
enc_->release();
// Construct a per-stream command encoder: creates a dedicated Metal
// command queue for stream `index`, optionally attaches the device-wide
// residency set, and opens the first command buffer.
//
// Throws std::runtime_error if the Metal queue cannot be created.
CommandEncoder::CommandEncoder(
    Device& d,
    int index,
    const MTL::ResidencySet* residency_set)
    : device_(d) {
  // Scoped autorelease pool so Metal objects autoreleased during setup
  // are drained when construction finishes.
  auto pool = new_scoped_memory_pool();
  queue_ = NS::TransferPtr(device_.mtl_device()->newCommandQueue());
  if (!queue_) {
    throw std::runtime_error(
        "[metal::CommandEncoder] Failed to make new command queue.");
  }
  if (residency_set) {
    // Keep the residency set's allocations resident for work on this queue.
    queue_->addResidencySet(residency_set);
  }
  // Label the queue with the stream index to aid GPU debugging/profiling.
  debug_set_stream_queue_label(queue_.get(), index);
  // Unretained references: the buffer does not retain the resources it
  // uses — presumably lifetimes are managed by MLX itself (see the
  // temporaries/completion-handler machinery in end_encoding).
  buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());
}

void CommandEncoder::set_buffer(
Expand All @@ -251,21 +271,21 @@ void CommandEncoder::set_buffer(
// buffers
all_inputs_.insert((void*)buf);
all_outputs_.insert((void*)buf);
enc_->setBuffer(buf, offset, idx);
get_command_encoder()->setBuffer(buf, offset, idx);
}

void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
if (all_inputs_.insert(a.buffer().ptr()).second) {
stream_.buffer_sizes += a.data_size();
buffer_sizes_ += a.data_size();
}
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
enc_->setBuffer(a_buf, a.offset() + offset, idx);
get_command_encoder()->setBuffer(a_buf, a.offset() + offset, idx);
}

void CommandEncoder::set_output_array(
Expand All @@ -288,9 +308,20 @@ void CommandEncoder::register_output_array(const array& a) {
}
}

// Hold a reference to `arr` so it stays alive until this encoder's
// command buffer completes (released by the completion handler that
// end_encoding installs).
void CommandEncoder::add_temporary(array arr) {
  temporaries_.emplace_back(std::move(arr));
}

void CommandEncoder::add_temporaries(std::vector<array> arrays) {
temporaries_.insert(
temporaries_.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}

void CommandEncoder::maybeInsertBarrier() {
if (needs_barrier_) {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers);
needs_barrier_ = false;
prev_outputs_ = std::move(next_outputs_);
} else {
Expand All @@ -303,20 +334,110 @@ void CommandEncoder::dispatch_threadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
stream_.buffer_ops++;
enc_->dispatchThreadgroups(grid_dims, group_dims);
buffer_ops_++;
get_command_encoder()->dispatchThreadgroups(grid_dims, group_dims);
}

void CommandEncoder::dispatch_threads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
stream_.buffer_ops++;
enc_->dispatchThreads(grid_dims, group_dims);
buffer_ops_++;
get_command_encoder()->dispatchThreads(grid_dims, group_dims);
}

void CommandEncoder::barrier() {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers);
}

// Close the currently open compute encoder (if any) and set up
// cross-encoder synchronization via fences.
void CommandEncoder::end_encoding() {
  // Each command encoder has a unique fence. We also store a map of
  // all previous outputs of command encoders to their corresponding fence.
  // - The command encoder records its inputs and outputs.
  // - Wait on a fence if any inputs in the encoder are outputs of a previous
  //   encoder.
  // - Update the map of outputs to include this command encoder's outputs.
  // - Always signal this command encoders fence.
  // - Add a completion handler for this command encoder that removes outputs
  //   from the map to limit the growth of the map and avoid unnecessary waits
  // - Temporaries are a special case as they do not cross command encoder
  //   boundaries. These can be removed early from the encoders inputs and
  //   outputs since they don't need synchronization.
  if (!encoder_) {
    // Nothing was encoded since the last end_encoding/commit.
    return;
  }

  // Remove temporaries from inputs and outputs.
  for (auto& t : temporaries_) {
    all_outputs_.erase(t.buffer().ptr());
    all_inputs_.erase(t.buffer().ptr());
  }

  // Keep references to the fences we waited on and put them in the completion
  // handler so they are not prematurely released.
  std::unordered_set<NS::SharedPtr<MTL::Fence>> waiting_on;
  {
    // prev_ce_outputs_ is shared with completion handlers, which may run
    // concurrently — hence the lock.
    std::lock_guard lk(outputs_mtx_);
    for (auto& in : all_inputs_) {
      if (auto it = prev_ce_outputs_.find(in); it != prev_ce_outputs_.end()) {
        // If we've already waited on a fence, don't wait on it again.
        if (waiting_on.find(it->second) == waiting_on.end()) {
          encoder_->waitForFence(it->second.get());
          waiting_on.insert(it->second);
        }
      }
    }
    // Record this encoder's outputs so later encoders wait on our fence.
    for (auto& out : all_outputs_) {
      prev_ce_outputs_[out] = fence_;
    }
  }

  // Signal our fence once all of this encoder's work is done.
  encoder_->updateFence(fence_.get());
  // The handler owns the fence, the fences we waited on, and the
  // temporaries, keeping them (and their buffers) alive until the GPU
  // is finished with this command buffer.
  buffer_->addCompletedHandler([this,
                                fence = std::move(fence_),
                                temporaries = std::move(temporaries_),
                                all_outputs = std::move(all_outputs_),
                                waiting_on = std::move(waiting_on)](
                                   MTL::CommandBuffer*) mutable {
    std::lock_guard lk(outputs_mtx_);
    for (auto& o : all_outputs) {
      if (auto it = prev_ce_outputs_.find(o); it != prev_ce_outputs_.end()) {
        // Only erase if the entry still refers to *our* fence; a newer
        // encoder may have overwritten it in the meantime.
        if (it->second == fence) {
          prev_ce_outputs_.erase(it);
        }
      }
    }
  });

  encoder_->endEncoding();
  // Reset per-encoder state; get_command_encoder() lazily re-creates the
  // encoder and fence on next use.
  encoder_.reset();
  needs_barrier_ = false;
  concurrent_ = false;
  prev_outputs_.clear();
  next_outputs_.clear();
  concurrent_outputs_.clear();
  all_inputs_.clear();
}

// True when the active command buffer has accumulated enough work —
// either too many dispatches or too many megabytes of input data —
// that it should be committed to the GPU.
bool CommandEncoder::needs_commit() const {
  auto [op_limit, mb_limit] = device_.get_max_ops_mb_per_buffer();
  if (buffer_ops_ > op_limit) {
    return true;
  }
  // buffer_sizes_ is tracked in bytes; compare in MB.
  return (buffer_sizes_ >> 20) > mb_limit;
}

// Submit the current command buffer to the GPU and start a fresh one on
// the same queue, resetting the per-buffer work accounting.
void CommandEncoder::commit() {
  buffer_->commit();
  buffer_ops_ = 0;
  buffer_sizes_ = 0;
  buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());
}

// Return the active compute encoder, lazily opening one (plus the fence
// that end_encoding will signal for it) on first use after a reset.
MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() {
  if (encoder_) {
    return encoder_.get();
  }
  auto* raw = buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent);
  encoder_ = NS::RetainPtr(raw);
  fence_ = NS::TransferPtr(device_.mtl_device()->newFence());
  return encoder_.get();
}

Device::Device() {
Expand Down Expand Up @@ -371,145 +492,40 @@ Device::~Device() {
k->release();
}
}
stream_map_.clear();
encoders_.clear();
device_->release();
}

void Device::new_queue(int index) {
auto thread_pool = metal::new_scoped_memory_pool();
auto q = device_->newCommandQueue();
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}

MTL::CommandQueue* Device::get_queue(Stream stream) {
return get_stream_(stream.index).queue;
}

bool Device::command_buffer_needs_commit(int index) {
auto& stream = get_stream_(index);
return (stream.buffer_ops > max_ops_per_buffer_) ||
((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
return get_command_encoder(index).needs_commit();
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto& stream = get_stream_(index);
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!stream.buffer) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
stream.buffer->retain();
}
return stream.buffer;
return get_command_encoder(index).get_command_buffer();
}

void Device::commit_command_buffer(int index) {
auto& stream = get_stream_(index);
stream.buffer->commit();
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
stream.buffer_sizes = 0;
get_command_encoder(index).commit();
}

void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
get_command_encoder(index).add_temporary(std::move(arr));
}

void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
get_command_encoder(index).add_temporaries(std::move(arrays));
}

void Device::end_encoding(int index) {
auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoders fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}

// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
}
}
enc.update_fence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
}
});
}
stream.encoder = nullptr;
get_command_encoder(index).end_encoding();
}

CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence());
auto it = encoders_.find(index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(index, *this, index, residency_set_).first;
}
return *stream.encoder;
return it->second;
}

MTL::Library* Device::get_library(
Expand Down Expand Up @@ -808,8 +824,8 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
for (auto& [_, encoder] : encoders_) {
encoder.get_command_queue()->addResidencySet(residency_set_);
}
}

Expand Down
Loading
Loading