diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp
index 5790a21685..fbb9dfa360 100644
--- a/mlx/backend/metal/eval.cpp
+++ b/mlx/backend/metal/eval.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <memory>
+#include <mutex>
 
 #include "mlx/backend/gpu/eval.h"
 #include "mlx/backend/metal/device.h"
@@ -9,6 +10,36 @@
 
 namespace mlx::core::gpu {
 
+namespace {
+// Thread-safe deferred error from Metal completion handlers.
+// Completion handlers run on GCD threads where C++ exceptions
+// hit std::terminate. Instead, we store the error and re-throw
+// at the next eval() or synchronize() call.
+std::mutex deferred_error_mutex;
+std::string deferred_error_message;
+
+// Record `msg` as the pending error. Only the first message is kept:
+// if one is already pending, later ones are dropped until it is thrown.
+void set_deferred_error(const std::string& msg) {
+  std::lock_guard<std::mutex> lock(deferred_error_mutex);
+  if (deferred_error_message.empty()) {
+    deferred_error_message = msg;
+  }
+}
+
+// Throw the pending error, if any, as std::runtime_error and clear it
+// so that it is reported exactly once.
+void check_deferred_error() {
+  std::lock_guard<std::mutex> lock(deferred_error_mutex);
+  if (!deferred_error_message.empty()) {
+    std::string msg = std::move(deferred_error_message);
+    // A moved-from string is only "valid but unspecified"; reset it explicitly.
+    deferred_error_message.clear();
+    throw std::runtime_error(msg);
+  }
+}
+} // namespace
+
 void init() {}
 
 void new_stream(Stream stream) {
@@ -26,7 +57,21 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
   }
 }
 
+// Safe version for Metal completion handlers (GCD callbacks).
+// Never throws for a failed command buffer: the error is stored for
+// deferred propagation instead.
+inline void check_error_deferred(MTL::CommandBuffer* cbuf) {
+  if (cbuf->status() == MTL::CommandBufferStatusError) {
+    std::ostringstream msg;
+    msg << "[METAL] Command buffer execution failed: "
+        << cbuf->error()->localizedDescription()->utf8String();
+    set_deferred_error(msg.str());
+  }
+}
+
 void eval(array& arr) {
+  // Re-throw any deferred error from a prior completion handler
+  check_deferred_error();
   auto pool = metal::new_scoped_memory_pool();
   auto s = arr.primitive().stream();
   auto& d = metal::device(s.device);
@@ -62,13 +107,15 @@ void eval(array& arr) {
     command_buffer->addCompletedHandler(
         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
+          // Store the error before announcing completion, so the failure is
+          // already recorded by the time the task is reported as done.
+          check_error_deferred(cbuf);
           scheduler::notify_task_completion(s);
-          check_error(cbuf);
         });
     d.commit_command_buffer(s.index);
   } else {
     command_buffer->addCompletedHandler(
         [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
-          check_error(cbuf);
+          check_error_deferred(cbuf);
         });
   }
 }
@@ -78,11 +125,14 @@ void finalize(Stream s) {
   auto& d = metal::device(s.device);
   auto cb = d.get_command_buffer(s.index);
   d.end_encoding(s.index);
-  cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
+  cb->addCompletedHandler(
+      [](MTL::CommandBuffer* cbuf) { check_error_deferred(cbuf); });
   d.commit_command_buffer(s.index);
 }
 
 void synchronize(Stream s) {
+  // Re-throw any deferred error from a prior completion handler
+  check_deferred_error();
   auto pool = metal::new_scoped_memory_pool();
   auto& d = metal::device(s.device);
   auto cb = d.get_command_buffer(s.index);