Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
Expand Down
8 changes: 5 additions & 3 deletions mlx/backend/cuda/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

namespace mlx::core::gpu {

void new_stream(Stream s) {
// Force initalization of CUDA, so CUDA runtime get destroyed at last.
// One-time initialization of the CUDA backend.
// The two calls below are ordered on purpose: per the inline notes, the order
// in which these objects are set up determines the order in which they are
// torn down.
void init() {
// Force initialization of CUDA, so that the CUDA runtime gets destroyed last.
cudaFree(nullptr);
// Make sure the CUDA event pool gets destroyed after the device and stream.
cu::CudaEvent::init_pool();
// NOTE(review): no statement follows the "ensure the static stream objects
// get created" remark in this function; it looks like a leftover from the
// earlier new_stream() body. Confirm and drop it.
}

// Backend hook for a newly created stream: requests the command encoder for
// `s` and discards the result.
// NOTE(review): presumably the encoder is created on first request - confirm.
void new_stream(Stream s) {
cu::get_command_encoder(s);
}

Expand Down
1 change: 1 addition & 0 deletions mlx/backend/gpu/eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

namespace mlx::core::gpu {

void init();
void new_stream(Stream stream);
void eval(array& arr);
void finalize(Stream s);
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

namespace mlx::core::gpu {

// This backend performs no work in the gpu::init() hook.
void init() {}

void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
metal::device(stream.device).get_command_encoder(stream.index);
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/no_gpu/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

namespace mlx::core::gpu {

// No-op: this backend has nothing to initialize in the gpu::init() hook.
void init() {}

// No-op: the stream argument is ignored by this backend.
void new_stream(Stream) {}

void eval(array&) {
Expand Down
8 changes: 0 additions & 8 deletions mlx/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,6 @@ void set_default_device(const Device& d) {
mutable_default_device() = d;
}

// Two devices are equal only when both their type and their index match.
bool operator==(const Device& lhs, const Device& rhs) {
  if (lhs.type != rhs.type) {
    return false;
  }
  return lhs.index == rhs.index;
}

// Inequality is the negation of operator== for devices.
bool operator!=(const Device& lhs, const Device& rhs) {
  const bool same = (lhs == rhs);
  return !same;
}

bool is_available(const Device& d) {
switch (d.type) {
case Device::cpu:
Expand Down
13 changes: 10 additions & 3 deletions mlx/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@ struct MLX_API Device {

DeviceType type;
int index;

// TODO: Use the defaulted three-way comparison (operator<=>) once Xcode supports it.
bool operator==(const Device&) const = default;
/**
 * Orders devices lexicographically by (type, index).
 *
 * The index is consulted only when the types are equal. Comparing the two
 * fields independently (`type < rhs.type || index < rhs.index`) lets both
 * `a < b` and `b < a` be true when one device has the smaller type and the
 * other the smaller index, which violates the strict weak ordering that
 * ordered containers and sorting algorithms require.
 */
bool operator<(const Device& rhs) const {
  if (type != rhs.type) {
    return type < rhs.type;
  }
  return index < rhs.index;
}
};

/** Compares a device against a bare device type; the index is not considered. */
inline bool operator==(const Device& device, Device::DeviceType type) {
  const Device::DeviceType own_type = device.type;
  return own_type == type;
}

MLX_API const Device& default_device();

MLX_API void set_default_device(const Device& d);

MLX_API bool operator==(const Device& lhs, const Device& rhs);
MLX_API bool operator!=(const Device& lhs, const Device& rhs);

MLX_API bool is_available(const Device& d);

/** Get the number of available devices for the given device type. */
Expand Down
59 changes: 22 additions & 37 deletions mlx/scheduler.cpp
Original file line number Diff line number Diff line change
@@ -1,47 +1,10 @@
// Copyright © 2023 Apple Inc.

#include "mlx/scheduler.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/backend/gpu/eval.h"

namespace mlx::core {

// Returns the scheduler's default stream for device `d`.
// Throws std::invalid_argument when a gpu stream is requested but no gpu
// backend is available.
Stream default_stream(Device d) {
  const bool gpu_missing = !gpu::is_available();
  if (gpu_missing && d == Device::gpu) {
    throw std::invalid_argument(
        "[default_stream] Cannot get gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  return sched.get_default_stream(d);
}

// Makes `s` the scheduler's default stream for its device.
// Throws std::invalid_argument when `s` is a gpu stream but no gpu backend is
// available.
void set_default_stream(Stream s) {
  const bool gpu_missing = !gpu::is_available();
  if (gpu_missing && s.device == Device::gpu) {
    throw std::invalid_argument(
        "[set_default_stream] Cannot set gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  sched.set_default_stream(s);
}

// Looks up the stream registered under `index` in the scheduler.
Stream get_stream(int index) {
  auto& sched = scheduler::scheduler();
  return sched.get_stream(index);
}

// Returns the scheduler's list of streams.
std::vector<Stream> get_streams() {
  auto& sched = scheduler::scheduler();
  return sched.get_streams();
}

// Asks the scheduler for a new stream on device `d`.
// Throws std::invalid_argument when a gpu stream is requested but no gpu
// backend is available.
Stream new_stream(Device d) {
  const bool gpu_missing = !gpu::is_available();
  if (gpu_missing && d == Device::gpu) {
    throw std::invalid_argument(
        "[new_stream] Cannot make gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  return sched.new_stream(d);
}

// Asks the scheduler for a new stream on the current default device.
Stream new_stream() {
  auto& sched = scheduler::scheduler();
  return sched.new_stream(default_device());
}

void synchronize(Stream s) {
if (s.device == mlx::core::Device::cpu) {
auto p = std::make_shared<std::promise<void>>();
Expand All @@ -59,6 +22,28 @@ void synchronize() {

namespace scheduler {

// Constructing the scheduler runs the gpu backend's init() hook.
Scheduler::Scheduler() {
gpu::init();
}

// Best-effort drain: synchronize every stream returned by get_streams()
// before the scheduler is destroyed.
Scheduler::~Scheduler() {
  const auto streams = get_streams();
  for (const auto& stream : streams) {
    try {
      synchronize(stream);
    } catch (const std::runtime_error&) {
      // A failed synchronization is deliberately ignored during teardown.
    }
  }
}

// Appends one entry to threads_ for a newly created stream: a null
// placeholder for gpu streams, a fresh StreamThread for everything else.
void Scheduler::new_thread(Device::DeviceType type) {
  threads_.push_back(
      type == Device::gpu ? nullptr : std::make_unique<StreamThread>());
}

/** A singleton scheduler to manage devices, streams, and task execution. */
Scheduler& scheduler() {
// Intentionally leaked to avoid the "static destruction order fiasco":
Expand Down
58 changes: 8 additions & 50 deletions mlx/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,47 +66,18 @@ struct StreamThread {

class Scheduler {
public:
// Registers the default streams: one for the gpu when a gpu device is
// available, and always one for the cpu.
Scheduler() : n_active_tasks_(0) {
  const bool has_gpu = is_available(Device::gpu);
  if (has_gpu) {
    default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
  }
  default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
}
Scheduler();
~Scheduler();

// Not copyable or moveable
Scheduler(const Scheduler&) = delete;
Scheduler(Scheduler&&) = delete;
Scheduler& operator=(const Scheduler&) = delete;
Scheduler& operator=(Scheduler&&) = delete;

// Creates a stream on `d`; its index is its position in streams_.
// threads_ receives exactly one entry per stream.
Stream new_stream(const Device& d) {
  streams_.emplace_back(streams_.size(), d);
  if (!(d == Device::gpu)) {
    // Non-gpu streams get their own StreamThread.
    threads_.push_back(new StreamThread{});
  } else {
    // gpu streams get a null thread entry and are handed to the gpu backend.
    threads_.push_back(nullptr);
    gpu::new_stream(streams_.back());
  }
  return streams_.back();
}

template <typename F>
void enqueue(const Stream& stream, F&& f);

// Returns the default stream stored for d's device type (the index is not
// part of the key). Uses a checked lookup.
Stream get_default_stream(const Device& d) const {
  const Stream& found = default_streams_.at(d.type);
  return found;
}
// Returns the stream at position `index`, using a bounds-checked lookup.
Stream get_stream(int index) const {
  const Stream& found = streams_.at(index);
  return found;
}
// Returns a copy of the list of streams.
std::vector<Stream> get_streams() const {
  std::vector<Stream> snapshot = streams_;
  return snapshot;
}

void set_default_stream(const Stream& s) {
default_streams_.at(s.device.type) = s;
}

void notify_new_task(const Stream& stream) {
{
std::lock_guard<std::mutex> lk(mtx);
Expand Down Expand Up @@ -137,26 +108,13 @@ class Scheduler {
}
}

// Synchronizes every stream on a best-effort basis, then frees the
// StreamThread objects owned through raw pointers in threads_.
~Scheduler() {
  for (const auto& stream : streams_) {
    try {
      synchronize(stream);
    } catch (const std::runtime_error&) {
      // A failed synchronization is deliberately ignored during teardown.
    }
  }
  for (StreamThread* worker : threads_) {
    // Deleting a null pointer is a no-op, so gpu placeholders need no check.
    delete worker;
  }
}

private:
int n_active_tasks_;
std::vector<StreamThread*> threads_;
std::vector<Stream> streams_;
std::unordered_map<Device::DeviceType, Stream> default_streams_;
friend Stream mlx::core::new_stream(Device d);

void new_thread(Device::DeviceType type);

int n_active_tasks_{0};
std::vector<std::unique_ptr<StreamThread>> threads_;
std::condition_variable completion_cv;
std::mutex mtx;
};
Expand Down
76 changes: 76 additions & 0 deletions mlx/stream.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright © 2026 Apple Inc.

#include "mlx/stream.h"
#include "mlx/backend/cpu/device_info.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/scheduler.h"

#include <array>
#include <optional>
#include <shared_mutex>

namespace mlx::core {

namespace {

// Returns the calling thread's slot holding the default stream of device `d`.
// The slot is empty until a default stream has been assigned.
auto& default_stream_storage(Device d) {
  using Slots = std::vector<std::optional<Stream>>;
  // One table per thread, built on first use: one row per device type, one
  // slot per device of that type.
  static thread_local std::array<Slots, 2> per_thread_defaults = [] {
    std::array<Slots, 2> table;
    table[static_cast<size_t>(Device::cpu)].resize(cpu::device_count());
    table[static_cast<size_t>(Device::gpu)].resize(gpu::device_count());
    return table;
  }();
  auto& slots_for_type = per_thread_defaults[static_cast<size_t>(d.type)];
  // at() bounds-checks the device index instead of reading out of range.
  return slots_for_type.at(d.index);
}

auto& all_streams() {
static std::tuple<std::vector<Stream>, std::shared_mutex> streams_and_mtx;
return streams_and_mtx;
}

} // namespace

/**
 * Get the calling thread's default stream for device `d`, creating it on
 * first use.
 *
 * Throws std::invalid_argument when a gpu stream is requested but no gpu
 * backend is available. An out-of-range device index is reported by the
 * bounds-checked lookup in default_stream_storage().
 *
 * Review notes kept from the pull-request discussion on this function:
 * - A Device can be constructed with any index (for example
 *   Device(Device::gpu, 7)), which would otherwise read out of bounds here.
 * - Validating the index in the Device constructor was rejected because it
 *   would break code such as is_available(Device::gpu), which constructs a
 *   possibly invalid device first and then checks it. The bounds check is
 *   therefore done in default_stream_storage() via .at(d.index).
 * - An is_available(DeviceType type, int index) API was suggested as a
 *   cleaner design; it would require an API change in mlx-c and was left
 *   open.
 */
Stream default_stream(Device d) {
  if (!gpu::is_available() && d.type == Device::gpu) {
    throw std::invalid_argument(
        "[default_stream] Cannot get gpu stream without gpu backend.");
  }
  auto& s = default_stream_storage(d);
  if (!s.has_value()) {
    // Create the stream on `d` itself. Passing only d.type would not carry
    // d.index into the new stream, so the default stream stored for a
    // non-zero device index would not belong to that device.
    s = new_stream(d);
  }
  return s.value();
}

// Makes `s` the calling thread's default stream for its device.
// Throws std::invalid_argument when `s` is a gpu stream but no gpu backend is
// available.
void set_default_stream(Stream s) {
  const bool gpu_missing = !gpu::is_available();
  if (gpu_missing && s.device == Device::gpu) {
    throw std::invalid_argument(
        "[set_default_stream] Cannot set gpu stream without gpu backend.");
  }
  auto& slot = default_stream_storage(s.device);
  slot = s;
}

// Returns a copy of the stream registry, taken under a shared (reader) lock.
std::vector<Stream> get_streams() {
  auto& [registry, registry_mtx] = all_streams();
  std::shared_lock reader_guard(registry_mtx);
  std::vector<Stream> snapshot = registry;
  return snapshot;
}

// Creates a stream on device `d` and appends it to the registry.
// Throws std::invalid_argument when a gpu stream is requested but no gpu
// backend is available.
Stream new_stream(Device d) {
  if (!gpu::is_available() && d == Device::gpu) {
    throw std::invalid_argument(
        "[new_stream] Cannot make gpu stream without gpu backend.");
  }
  auto& [registry, registry_mtx] = all_streams();
  // Exclusive lock: the registry is modified below.
  std::unique_lock writer_guard(registry_mtx);
  // The new stream's index is its position in the registry.
  const int next_index = static_cast<int>(registry.size());
  auto& created = registry.emplace_back(next_index, d);
  // The scheduler adds its per-stream thread entry while the lock is held.
  scheduler::scheduler().new_thread(d.type);
  if (d == Device::gpu) {
    gpu::new_stream(created);
  }
  return created;
}

} // namespace mlx::core
21 changes: 8 additions & 13 deletions mlx/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,26 @@ struct MLX_API Stream {
int index;
Device device;
/** Constructs a stream with the given index on the given device. */
explicit Stream(int index, Device device) : index(index), device(device) {}

// TODO: Use the defaulted three-way comparison (operator<=>) once Xcode supports it.
bool operator==(const Stream&) const = default;
/**
 * Orders streams lexicographically by (device type, device index, stream
 * index).
 *
 * Each later key is consulted only when all earlier keys are equal. The
 * previous form (`device < rhs.device || index < rhs.index`) let both
 * `a < b` and `b < a` be true, which violates the strict weak ordering that
 * ordered containers and sorting algorithms require. The device fields are
 * compared directly so this ordering is valid on its own.
 */
bool operator<(const Stream& rhs) const {
  if (device.type != rhs.device.type) {
    return device.type < rhs.device.type;
  }
  if (device.index != rhs.device.index) {
    return device.index < rhs.device.index;
  }
  return index < rhs.index;
}
};

/** Get the default stream for the given device. */
/** Get the default stream of the current thread for the given device. */
MLX_API Stream default_stream(Device d);

/** Make the stream the default for its device. */
/** Make the stream the default for its device on the current thread. */
MLX_API void set_default_stream(Stream s);

/** Make a new stream on the given device. */
MLX_API Stream new_stream(Device d);

/** Get the stream with the given index. */
MLX_API Stream get_stream(int index);

/** Get all available streams. */
MLX_API std::vector<Stream> get_streams();

// Streams are compared by index only; the device is not considered.
inline bool operator==(const Stream& lhs, const Stream& rhs) {
  const int left_index = lhs.index;
  const int right_index = rhs.index;
  return left_index == right_index;
}

// Inequality is the negation of operator== for streams.
inline bool operator!=(const Stream& lhs, const Stream& rhs) {
  const bool same = (lhs == rhs);
  return !same;
}

/* Synchronize with the default stream. */
MLX_API void synchronize();

Expand Down
Loading
Loading