From adb6cdbe7fa4f7c845e697609b8f0992adae6fbb Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 20 Mar 2026 09:50:25 +0900 Subject: [PATCH] Make each thread have its own default stream --- mlx/CMakeLists.txt | 1 + mlx/backend/cuda/eval.cpp | 8 ++-- mlx/backend/gpu/eval.h | 1 + mlx/backend/metal/eval.cpp | 2 + mlx/backend/no_gpu/eval.cpp | 2 + mlx/device.cpp | 8 ---- mlx/device.h | 13 +++++-- mlx/scheduler.cpp | 59 +++++++++++----------------- mlx/scheduler.h | 58 ++++------------------------ mlx/stream.cpp | 76 +++++++++++++++++++++++++++++++++++++ mlx/stream.h | 21 ++++------ mlx/transforms.cpp | 12 +++--- python/src/device.cpp | 8 +++- tests/scheduler_tests.cpp | 29 ++++++++++++++ 14 files changed, 175 insertions(+), 123 deletions(-) create mode 100644 mlx/stream.cpp diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 9c0fd38899..06bf2a244c 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -14,6 +14,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index c8da425877..8b8bf598df 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -10,12 +10,14 @@ namespace mlx::core::gpu { -void new_stream(Stream s) { - // Force initalization of CUDA, so CUDA runtime get destroyed at last. +void init() { + // Force initialization of CUDA, so CUDA runtime gets destroyed last. cudaFree(nullptr); // Make sure CUDA event pool get destroyed after device and stream. cu::CudaEvent::init_pool(); - // Ensure the static stream objects get created. 
+} + +void new_stream(Stream s) { cu::get_command_encoder(s); } diff --git a/mlx/backend/gpu/eval.h b/mlx/backend/gpu/eval.h index f646c2ec9d..bdd3f81f1b 100644 --- a/mlx/backend/gpu/eval.h +++ b/mlx/backend/gpu/eval.h @@ -10,6 +10,7 @@ namespace mlx::core::gpu { +void init(); void new_stream(Stream stream); void eval(array& arr); void finalize(Stream s); diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 123fd74057..5790a21685 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -9,6 +9,8 @@ namespace mlx::core::gpu { +void init() {} + void new_stream(Stream stream) { if (stream.device == mlx::core::Device::gpu) { metal::device(stream.device).get_command_encoder(stream.index); diff --git a/mlx/backend/no_gpu/eval.cpp b/mlx/backend/no_gpu/eval.cpp index 5d88970a07..831cfb395a 100644 --- a/mlx/backend/no_gpu/eval.cpp +++ b/mlx/backend/no_gpu/eval.cpp @@ -7,6 +7,8 @@ namespace mlx::core::gpu { +void init() {} + void new_stream(Stream) {} void eval(array&) { diff --git a/mlx/device.cpp b/mlx/device.cpp index 18fd33459a..f0c868f21b 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -25,14 +25,6 @@ void set_default_device(const Device& d) { mutable_default_device() = d; } -bool operator==(const Device& lhs, const Device& rhs) { - return lhs.type == rhs.type && lhs.index == rhs.index; -} - -bool operator!=(const Device& lhs, const Device& rhs) { - return !(lhs == rhs); -} - bool is_available(const Device& d) { switch (d.type) { case Device::cpu: diff --git a/mlx/device.h b/mlx/device.h index f89ad189fb..9ddd1b3f5b 100644 --- a/mlx/device.h +++ b/mlx/device.h @@ -23,15 +23,22 @@ struct MLX_API Device { DeviceType type; int index; + + // TODO: Use default three-way comparison when it gets supported in XCode. 
+ bool operator==(const Device&) const = default; + bool operator<(const Device& rhs) const { + return type < rhs.type || (type == rhs.type && index < rhs.index); + } }; +inline bool operator==(const Device& device, Device::DeviceType type) { + return device.type == type; +} + MLX_API const Device& default_device(); MLX_API void set_default_device(const Device& d); -MLX_API bool operator==(const Device& lhs, const Device& rhs); -MLX_API bool operator!=(const Device& lhs, const Device& rhs); - MLX_API bool is_available(const Device& d); /** Get the number of available devices for the given device type. */ diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index 584e3bd382..4a8bda3adc 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -1,47 +1,10 @@ // Copyright © 2023 Apple Inc. #include "mlx/scheduler.h" -#include "mlx/backend/gpu/device_info.h" #include "mlx/backend/gpu/eval.h" namespace mlx::core { -Stream default_stream(Device d) { - if (!gpu::is_available() && d == Device::gpu) { - throw std::invalid_argument( - "[default_stream] Cannot get gpu stream without gpu backend."); - } - return scheduler::scheduler().get_default_stream(d); -} - -void set_default_stream(Stream s) { - if (!gpu::is_available() && s.device == Device::gpu) { - throw std::invalid_argument( - "[set_default_stream] Cannot set gpu stream without gpu backend."); - } - return scheduler::scheduler().set_default_stream(s); -} - -Stream get_stream(int index) { - return scheduler::scheduler().get_stream(index); -} - -std::vector get_streams() { - return scheduler::scheduler().get_streams(); -} - -Stream new_stream(Device d) { - if (!gpu::is_available() && d == Device::gpu) { - throw std::invalid_argument( - "[new_stream] Cannot make gpu stream without gpu backend."); - } - return scheduler::scheduler().new_stream(d); -} - -Stream new_stream() { - return scheduler::scheduler().new_stream(default_device()); -} - void synchronize(Stream s) { if (s.device == mlx::core::Device::cpu) { auto p = std::make_shared>(); @@ 
-59,6 +22,28 @@ void synchronize() { namespace scheduler { +Scheduler::Scheduler() { + gpu::init(); +} + +Scheduler::~Scheduler() { + for (auto& s : get_streams()) { + try { + synchronize(s); + } catch (const std::runtime_error&) { + // ignore errors if synch fails + } + } +} + +void Scheduler::new_thread(Device::DeviceType type) { + if (type == Device::gpu) { + threads_.push_back(nullptr); + } else { + threads_.push_back(std::make_unique()); + } +} + /** A singleton scheduler to manage devices, streams, and task execution. */ Scheduler& scheduler() { // Intentionally leaked to avoid the "static destruction order fiasco": diff --git a/mlx/scheduler.h b/mlx/scheduler.h index c94044a797..619c65f1c5 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -66,12 +66,8 @@ struct StreamThread { class Scheduler { public: - Scheduler() : n_active_tasks_(0) { - if (is_available(Device::gpu)) { - default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); - } - default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); - } + Scheduler(); + ~Scheduler(); // Not copyable or moveable Scheduler(const Scheduler&) = delete; @@ -79,34 +75,9 @@ class Scheduler { Scheduler& operator=(const Scheduler&) = delete; Scheduler& operator=(Scheduler&&) = delete; - Stream new_stream(const Device& d) { - streams_.emplace_back(streams_.size(), d); - if (d == Device::gpu) { - threads_.push_back(nullptr); - gpu::new_stream(streams_.back()); - } else { - threads_.push_back(new StreamThread{}); - } - return streams_.back(); - } - template void enqueue(const Stream& stream, F&& f); - Stream get_default_stream(const Device& d) const { - return default_streams_.at(d.type); - } - Stream get_stream(int index) const { - return streams_.at(index); - } - std::vector get_streams() const { - return streams_; - } - - void set_default_stream(const Stream& s) { - default_streams_.at(s.device.type) = s; - } - void notify_new_task(const Stream& stream) { { std::lock_guard lk(mtx); @@ -137,26 +108,13 @@ 
class Scheduler { } } - ~Scheduler() { - for (auto s : streams_) { - try { - synchronize(s); - } catch (const std::runtime_error&) { - // ignore errors if synch fails - } - } - for (auto t : threads_) { - if (t != nullptr) { - delete t; - } - } - } - private: - int n_active_tasks_; - std::vector threads_; - std::vector streams_; - std::unordered_map default_streams_; + friend Stream mlx::core::new_stream(Device d); + + void new_thread(Device::DeviceType type); + + int n_active_tasks_{0}; + std::vector> threads_; std::condition_variable completion_cv; std::mutex mtx; }; diff --git a/mlx/stream.cpp b/mlx/stream.cpp new file mode 100644 index 0000000000..ee1db01629 --- /dev/null +++ b/mlx/stream.cpp @@ -0,0 +1,76 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/stream.h" +#include "mlx/backend/cpu/device_info.h" +#include "mlx/backend/gpu/device_info.h" +#include "mlx/scheduler.h" + +#include +#include +#include + +namespace mlx::core { + +namespace { + +auto& default_stream_storage(Device d) { + // Each device has its own default stream in each thread. 
+ static thread_local auto default_streams = []() { + std::array>, 2> streams; + streams[static_cast(Device::cpu)].resize(cpu::device_count()); + streams[static_cast(Device::gpu)].resize(gpu::device_count()); + return streams; + }(); + return default_streams[static_cast(d.type)].at(d.index); +} + +auto& all_streams() { + static std::tuple, std::shared_mutex> streams_and_mtx; + return streams_and_mtx; +} + +} // namespace + +Stream default_stream(Device d) { + if (!gpu::is_available() && d.type == Device::gpu) { + throw std::invalid_argument( + "[default_stream] Cannot get gpu stream without gpu backend."); + } + auto& s = default_stream_storage(d); + if (!s.has_value()) { + s = new_stream(d); + } + return s.value(); +} + +void set_default_stream(Stream s) { + if (!gpu::is_available() && s.device == Device::gpu) { + throw std::invalid_argument( + "[set_default_stream] Cannot set gpu stream without gpu backend."); + } + default_stream_storage(s.device) = s; +} + +std::vector get_streams() { + auto& [streams, mtx] = all_streams(); + std::shared_lock lock(mtx); + return streams; +} + +Stream new_stream(Device d) { + if (!gpu::is_available() && d == Device::gpu) { + throw std::invalid_argument( + "[new_stream] Cannot make gpu stream without gpu backend."); + } + auto& [streams, mtx] = all_streams(); + std::unique_lock lock(mtx); + int index = streams.size(); + auto& s = streams.emplace_back(index, d); + scheduler::scheduler().new_thread(d.type); + if (d == Device::gpu) { + gpu::new_stream(s); + } + return s; +} + +} // namespace mlx::core diff --git a/mlx/stream.h b/mlx/stream.h index efe0ef1a74..54f7d82015 100644 --- a/mlx/stream.h +++ b/mlx/stream.h @@ -13,31 +13,26 @@ struct MLX_API Stream { int index; Device device; explicit Stream(int index, Device device) : index(index), device(device) {} + + // TODO: Use default three-way comparison when it gets supported in XCode. 
+ bool operator==(const Stream&) const = default; + bool operator<(const Stream& rhs) const { + return device < rhs.device || (device == rhs.device && index < rhs.index); + } }; -/** Get the default stream for the given device. */ +/** Get the default stream of current thread for the given device. */ MLX_API Stream default_stream(Device d); -/** Make the stream the default for its device. */ +/** Make the stream the default for its device on current thread. */ MLX_API void set_default_stream(Stream s); /** Make a new stream on the given device. */ MLX_API Stream new_stream(Device d); -/** Get the stream with the given index. */ -MLX_API Stream get_stream(int index); - /** Get all available streams. */ MLX_API std::vector get_streams(); -inline bool operator==(const Stream& lhs, const Stream& rhs) { - return lhs.index == rhs.index; -} - -inline bool operator!=(const Stream& lhs, const Stream& rhs) { - return !(lhs == rhs); -} - /* Synchronize with the default stream. */ MLX_API void synchronize(); diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index c7be012083..e8c124cfe8 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -218,13 +218,13 @@ array eval_impl(std::vector outputs, bool async) { } } - std::unordered_set open_streams; + std::set open_streams; while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); auto stream = arr.primitive().stream(); - open_streams.insert(stream.index); + open_streams.insert(stream); if (async) { // Lookup corresponding event @@ -265,8 +265,7 @@ array eval_impl(std::vector outputs, bool async) { (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0)) { // Commit any open streams - for (auto i : open_streams) { - auto s = get_stream(i); + for (auto& s : open_streams) { if (s.device == Device::gpu) { gpu::finalize(s); } @@ -302,9 +301,8 @@ array eval_impl(std::vector outputs, bool async) { } // Signal the event in its stream - for (auto i : open_streams) { - auto s = get_stream(i); - if (auto e = 
events.find(i); e != events.end()) { + for (auto& s : open_streams) { + if (auto e = events.find(s.index); e != events.end()) { e->second.signal(s); } if (s.device == Device::gpu) { diff --git a/python/src/device.cpp b/python/src/device.cpp index f15f7f92db..e70b69bd34 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -1,8 +1,10 @@ // Copyright © 2023-2025 Apple Inc. +#include #include #include +#include #include #include #include @@ -80,8 +82,10 @@ void init_device(nb::module_& m) { )pbdoc"); m.def( "device_info", - &mx::device_info, - nb::arg("d") = mx::default_device(), + [](std::optional d) { + return mx::device_info(d.value_or(mx::default_device())); + }, + "d"_a = nb::none(), R"pbdoc( Get information about a device. diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index eb10b2b98d..532f168616 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -36,6 +36,35 @@ TEST_CASE("test stream management") { } } +TEST_CASE("test default stream in threads") { + std::set thread_streams; + std::mutex mtx; + std::vector threads; + + auto all_streams = get_streams(); + size_t num_streams = all_streams.size(); + + size_t num_threads = 4; + for (size_t i = 0; i < num_threads; ++i) { + threads.emplace_back([&thread_streams, &mtx]() { + auto s = default_stream(gpu::is_available() ? Device::gpu : Device::cpu); + std::lock_guard lock(mtx); + thread_streams.insert(s); + }); + } + + for (auto& t : threads) { + t.join(); + } + CHECK_EQ(thread_streams.size(), num_threads); + + all_streams = get_streams(); + CHECK_EQ(all_streams.size() - num_streams, num_threads); + + std::set new_streams(all_streams.begin() + num_streams, all_streams.end()); + CHECK_EQ(new_streams, thread_streams); +} + TEST_CASE("test get streams") { auto streams = get_streams();