From 48eb7bb7789199d848f730540ace362aea4b1bd9 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 14 Nov 2025 01:07:51 +0000
Subject: [PATCH 1/3] tune ops per buffer based on device

---
 mlx/backend/cuda/device.cpp | 20 +++++++++++++++++---
 mlx/backend/cuda/device.h   |  4 +++-
 mlx/backend/cuda/eval.cpp   |  6 +-----
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 539704fe1c..682670d262 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
         "Device {} does not support synchronization in managed memory.",
         device_));
   }
+  // The cublasLt handle is used by matmul.
   make_current();
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
 
@@ -189,12 +190,25 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
   }
 }
 
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+int get_max_ops_per_graph(Device& d) {
+  auto cc = d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
+  int n = 20;
+  switch (cc) {
+    case 1000: // B200
+      n = 50;
+      break;
+  }
+  return env::max_ops_per_buffer(n);
+}
+
 CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400),
+      max_ops_per_graph_(get_max_ops_per_graph(d)) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -301,8 +315,8 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }
 
-int CommandEncoder::get_num_ops() {
-  return node_count_;
+bool CommandEncoder::needs_commit() {
+  return node_count_ > max_ops_per_graph_;
 }
 
 void CommandEncoder::commit() {
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index c049734842..4b9a5099a7 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -84,7 +84,7 @@ class CommandEncoder {
   }
 
   void add_completed_handler(std::function<void()> task);
-  int get_num_ops();
+  bool needs_commit();
   void commit();
 
   Device& device() {
@@ -131,6 +131,7 @@ class CommandEncoder {
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  int max_ops_per_graph_;
 };
 
 class Device {
@@ -166,6 +167,7 @@ class Device {
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
+  std::string device_name_;
   cublasLtHandle_t lt_;
   cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index 07b3ad63e9..ef58f4a7ed 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -11,9 +11,6 @@
 
 namespace mlx::core::gpu {
 
-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-constexpr int default_max_nodes_per_graph = 20;
-
 bool is_available() {
   return true;
 }
@@ -53,8 +50,7 @@ void eval(array& arr) {
       encoder.add_temporary(s);
     }
 
-    if (encoder.get_num_ops() >=
-        env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    if (encoder.needs_commit()) {
       scheduler::notify_new_task(stream);
       encoder.add_completed_handler(
          [stream]() { scheduler::notify_task_completion(stream); });
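Editor's note (not part of the series): patch 1 keys the default graph size off
an encoded compute capability, major * 100 + minor * 10, so sm_90 (H100) reads
as 900 and sm_100 (B200) as 1000. Below is a minimal standalone sketch of that
encoding against the plain CUDA runtime API; the helper name and the hard-coded
defaults (20 ops, 50 on B200) mirror the patch, but the program itself is
illustrative and not MLX code.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Encode compute capability the way the patch does:
    // major * 100 + minor * 10, so 9.0 -> 900 and 10.0 -> 1000.
    static int encoded_compute_capability(int device) {
      cudaDeviceProp prop;
      if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return -1; // no usable device
      }
      return prop.major * 100 + prop.minor * 10;
    }

    int main() {
      int cc = encoded_compute_capability(0); // assumes device 0 exists
      int ops = (cc == 1000) ? 50 : 20; // B200 gets a larger graph budget
      std::printf("cc=%d, default ops per graph=%d\n", cc, ops);
      return 0;
    }
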
From 7ed3c82f0ae24285ea53182bf2638f31dea5825d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 14 Nov 2025 21:05:01 +0000
Subject: [PATCH 2/3] tune memory limit as well

---
 mlx/backend/cuda/device.cpp  | 23 +++++++++++++++--------
 mlx/backend/cuda/device.h    |  2 ++
 mlx/backend/metal/device.cpp |  7 ++-----
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 682670d262..58bce09ffb 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -190,16 +190,18 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
   }
 }
 
-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-int get_max_ops_per_graph(Device& d) {
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
+std::pair<int, int> get_graph_limits(Device& d) {
   auto cc = d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
-  int n = 20;
+  int ops = 20;
+  int mb = 100;
   switch (cc) {
     case 1000: // B200
-      n = 50;
+      ops = 50;
+      mb = 500;
       break;
   }
-  return env::max_ops_per_buffer(n);
+  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
 }
 
 CommandEncoder::CommandEncoder(Device& d)
@@ -207,8 +209,10 @@ CommandEncoder::CommandEncoder(Device& d)
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400),
-      max_ops_per_graph_(get_max_ops_per_graph(d)) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400)
+      {
+  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
+}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -218,6 +222,7 @@ void CommandEncoder::set_input_array(const array& arr) {
   if (!use_cuda_graphs()) {
     return;
   }
+  bytes_in_graph_ += arr.data_size();
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
@@ -316,7 +321,8 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
 }
 
 bool CommandEncoder::needs_commit() {
-  return node_count_ > max_ops_per_graph_;
+  return (node_count_ > max_ops_per_graph_) ||
+      ((bytes_in_graph_ >> 20) > max_mb_per_graph_);
 }
 
 void CommandEncoder::commit() {
@@ -379,6 +385,7 @@ void CommandEncoder::commit() {
   // Put completion handlers in a batch.
   worker_.commit(stream_);
   node_count_ = 0;
+  bytes_in_graph_ = 0;
 }
 
 void CommandEncoder::synchronize() {
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 4b9a5099a7..196cd799f6 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -131,7 +131,9 @@ class CommandEncoder {
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  size_t bytes_in_graph_{0};
   int max_ops_per_graph_;
+  int max_mb_per_graph_;
 };
 
 class Device {
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index f9c8b8052d..65cd32d30f 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -382,11 +382,8 @@ MTL::CommandQueue* Device::get_queue(Stream stream) {
 
 bool Device::command_buffer_needs_commit(int index) {
   auto& stream = get_stream_(index);
-  if (stream.buffer_ops > max_ops_per_buffer_ ||
-      (stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
-    return true;
-  }
-  return false;
+  return (stream.buffer_ops > max_ops_per_buffer_) ||
+      ((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
 }
 
 MTL::CommandBuffer* Device::get_command_buffer(int index) {
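Editor's note (not part of the series): patch 2 adds a second commit trigger.
Each input array's data_size() is accumulated into bytes_in_graph_, and
needs_commit() converts the running total to MiB with a right shift by 20
before comparing it against the per-device megabyte limit, mirroring the Metal
backend's command_buffer_needs_commit. A self-contained sketch of that dual
gate, with illustrative names rather than MLX's actual classes:

    #include <cstddef>
    #include <cstdio>

    // Flush once either the op count or the accumulated input size
    // (bytes -> MiB via >> 20) exceeds its limit.
    struct CommitGate {
      int node_count = 0;
      size_t bytes_in_graph = 0;
      int max_ops_per_graph = 20;
      int max_mb_per_graph = 100;

      bool needs_commit() const {
        return (node_count > max_ops_per_graph) ||
            ((bytes_in_graph >> 20) > static_cast<size_t>(max_mb_per_graph));
      }
    };

    int main() {
      CommitGate gate;
      gate.node_count = 5; // well under the op limit
      gate.bytes_in_graph = 120ull << 20; // 120 MiB of inputs
      std::printf("needs_commit=%d\n", gate.needs_commit()); // prints 1
      return 0;
    }
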
From e2694beb803333686b9efc38ed5c982504907f2e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 14 Nov 2025 16:13:25 -0800
Subject: [PATCH 3/3] add tuning for spark

---
 mlx/backend/cuda/device.cpp | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 58bce09ffb..24b85821e1 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -192,14 +192,27 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 
 // Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
 std::pair<int, int> get_graph_limits(Device& d) {
-  auto cc = d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
+  auto cc =
+      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
   int ops = 20;
   int mb = 100;
   switch (cc) {
+    case 800: // A100
+      ops = 20;
+      mb = 400;
+      break;
+    case 900: // H100
+      ops = 30;
+      mb = 400;
+      break;
     case 1000: // B200
       ops = 50;
       mb = 500;
       break;
+    case 1210: // DGX Spark
+      ops = 20;
+      mb = 25;
+      break;
   }
   return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
 }
@@ -209,9 +222,8 @@ CommandEncoder::CommandEncoder(Device& d)
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400)
-      {
-  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
+  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
 }
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
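Editor's note (not part of the series): the per-device numbers above are only
defaults. Both limits stay user-tunable because the tuned values are passed
through env::max_ops_per_buffer and env::max_mb_per_buffer, which the patch
comments say honor MLX_MAX_OPS_PER_BUFFER and MLX_MAX_MB_PER_BUFFER, so
exporting MLX_MAX_MB_PER_BUFFER=500 on a DGX Spark should lift the 25 MB
default without a rebuild. A hedged sketch of that override pattern; MLX's
real env helpers may differ in detail:

    #include <cstdlib>
    #include <string>

    // Assumed override pattern: the compute-capability-tuned value is only
    // a default, and the environment variable wins when it is set.
    static int env_override(const char* name, int tuned_default) {
      if (const char* value = std::getenv(name)) {
        return std::stoi(value); // sketch; MLX's env:: helpers may differ
      }
      return tuned_default;
    }

    // e.g. env_override("MLX_MAX_MB_PER_BUFFER", 25) returns 25 on a DGX
    // Spark unless the variable is set, matching env::max_mb_per_buffer(mb).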