Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions mlx/backend/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
"Device {} does not support synchronization in managed memory.",
device_));
}

// The cublasLt handle is used by matmul.
make_current();
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
Expand Down Expand Up @@ -189,12 +190,41 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
}

// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
std::pair<int, int> get_graph_limits(Device& d) {
  // Encode the compute capability as a single integer (e.g. 9.0 -> 900,
  // 12.1 -> 1210) so each architecture can be matched directly.
  auto cc =
      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;

  // Fallback limits for architectures without an explicit entry below.
  int ops = 20;
  int mb = 100;

  if (cc == 800) { // A100
    ops = 20;
    mb = 400;
  } else if (cc == 900) { // H100
    ops = 30;
    mb = 400;
  } else if (cc == 1000) { // B200
    ops = 50;
    mb = 500;
  } else if (cc == 1210) { // DGX Spark
    ops = 20;
    mb = 25;
  }

  // Environment variables override the per-architecture defaults.
  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
}

CommandEncoder::CommandEncoder(Device& d)
    : device_(d),
      stream_(d),
      graph_(d),
      worker_(d),
      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
  // Graph flush thresholds are architecture-dependent (with env overrides).
  auto [ops_limit, mb_limit] = get_graph_limits(d);
  max_ops_per_graph_ = ops_limit;
  max_mb_per_graph_ = mb_limit;
}

void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
Expand All @@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
bytes_in_graph_ += arr.data_size();
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
Expand Down Expand Up @@ -301,8 +332,9 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
insert_graph_dependencies(GraphNode{node, 'G'});
}

int CommandEncoder::get_num_ops() {
return node_count_;
bool CommandEncoder::needs_commit() {
return (node_count_ > max_ops_per_graph_) ||
((bytes_in_graph_ >> 20) > max_mb_per_graph_);
}

void CommandEncoder::commit() {
Expand Down Expand Up @@ -365,6 +397,7 @@ void CommandEncoder::commit() {
// Put completion handlers in a batch.
worker_.commit(stream_);
node_count_ = 0;
bytes_in_graph_ = 0;
}

void CommandEncoder::synchronize() {
Expand Down
6 changes: 5 additions & 1 deletion mlx/backend/cuda/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class CommandEncoder {
}

void add_completed_handler(std::function<void()> task);
int get_num_ops();
bool needs_commit();
void commit();

Device& device() {
Expand Down Expand Up @@ -131,6 +131,9 @@ class CommandEncoder {
std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
size_t bytes_in_graph_{0};
int max_ops_per_graph_;
int max_mb_per_graph_;
};

class Device {
Expand Down Expand Up @@ -166,6 +169,7 @@ class Device {
int device_;
int compute_capability_major_;
int compute_capability_minor_;
std::string device_name_;
cublasLtHandle_t lt_;
cudnnHandle_t cudnn_;
std::unordered_map<int, CommandEncoder> encoders_;
Expand Down
6 changes: 1 addition & 5 deletions mlx/backend/cuda/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

namespace mlx::core::gpu {

// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;

bool is_available() {
return true;
}
Expand Down Expand Up @@ -53,8 +50,7 @@ void eval(array& arr) {
encoder.add_temporary(s);
}

if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
if (encoder.needs_commit()) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
Expand Down
7 changes: 2 additions & 5 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,8 @@ MTL::CommandQueue* Device::get_queue(Stream stream) {

bool Device::command_buffer_needs_commit(int index) {
  auto& stream = get_stream_(index);
  // Commit when the op count crosses its limit...
  if (stream.buffer_ops > max_ops_per_buffer_) {
    return true;
  }
  // ...or when the accumulated buffer size (converted to MB) does.
  return (stream.buffer_sizes >> 20) > max_mb_per_buffer_;
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
Expand Down