Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include <mutex>

#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
Expand All @@ -9,6 +10,31 @@

namespace mlx::core::gpu {

namespace {
// Thread-safe deferred error from Metal completion handlers.
// Completion handlers run on GCD threads where C++ exceptions
// hit std::terminate. Instead, we store the error and re-throw
// at the next eval() or synchronize() call.
std::mutex deferred_error_mutex;
std::string deferred_error_message;

void set_deferred_error(const std::string& msg) {
std::lock_guard<std::mutex> lock(deferred_error_mutex);
if (deferred_error_message.empty()) {
deferred_error_message = msg;
}
}

void check_deferred_error() {
std::lock_guard<std::mutex> lock(deferred_error_mutex);
if (!deferred_error_message.empty()) {
std::string msg = std::move(deferred_error_message);
deferred_error_message.clear();
throw std::runtime_error(msg);
}
}
} // namespace

void init() {}

void new_stream(Stream stream) {
Expand All @@ -26,7 +52,20 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
}
}

// Safe version for Metal completion handlers (GCD callbacks).
// Cannot throw — stores error for deferred propagation.
inline void check_error_deferred(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
set_deferred_error(msg.str());
}
}

void eval(array& arr) {
// Re-throw any deferred error from a prior completion handler
check_deferred_error();
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
Expand Down Expand Up @@ -62,13 +101,13 @@ void eval(array& arr) {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
check_error_deferred(cbuf);
});
d.commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
check_error_deferred(cbuf);
});
}
}
Expand All @@ -78,11 +117,13 @@ void finalize(Stream s) {
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
cb->addCompletedHandler(
[](MTL::CommandBuffer* cbuf) { check_error_deferred(cbuf); });
d.commit_command_buffer(s.index);
}

void synchronize(Stream s) {
check_deferred_error();
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
Expand Down