diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 539704fe1c..24b85821e1 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
         "Device {} does not support synchronization in managed memory.",
         device_));
   }
+  // The cublasLt handle is used by matmul.
   make_current();
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
@@ -189,12 +190,41 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
   }
 }
 
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
+std::pair<int, int> get_graph_limits(Device& d) {
+  auto cc =
+      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
+  int ops = 20;
+  int mb = 100;
+  switch (cc) {
+    case 800: // A100
+      ops = 20;
+      mb = 400;
+      break;
+    case 900: // H100
+      ops = 30;
+      mb = 400;
+      break;
+    case 1000: // B200
+      ops = 50;
+      mb = 500;
+      break;
+    case 1210: // DGX Spark
+      ops = 20;
+      mb = 25;
+      break;
+  }
+  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
+}
+
 CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
+  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
+}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
   if (!use_cuda_graphs()) {
     return;
   }
+  bytes_in_graph_ += arr.data_size();
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
@@ -301,8 +332,9 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }
 
-int CommandEncoder::get_num_ops() {
-  return node_count_;
+bool CommandEncoder::needs_commit() {
+  return (node_count_ > max_ops_per_graph_) ||
+      ((bytes_in_graph_ >> 20) >
+          max_mb_per_graph_);
 }
 
 void CommandEncoder::commit() {
@@ -365,6 +397,7 @@ void CommandEncoder::commit() {
   // Put completion handlers in a batch.
   worker_.commit(stream_);
   node_count_ = 0;
+  bytes_in_graph_ = 0;
 }
 
 void CommandEncoder::synchronize() {
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index c049734842..196cd799f6 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -84,7 +84,7 @@ class CommandEncoder {
   }
 
   void add_completed_handler(std::function<void()> task);
-  int get_num_ops();
+  bool needs_commit();
   void commit();
 
   Device& device() {
@@ -131,6 +131,9 @@ class CommandEncoder {
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  size_t bytes_in_graph_{0};
+  int max_ops_per_graph_;
+  int max_mb_per_graph_;
 };
 
 class Device {
@@ -166,6 +169,7 @@ class Device {
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
+  std::string device_name_;
   cublasLtHandle_t lt_;
   cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index 07b3ad63e9..ef58f4a7ed 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -11,9 +11,6 @@
 namespace mlx::core::gpu {
 
-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-constexpr int default_max_nodes_per_graph = 20;
-
 bool is_available() {
   return true;
 }
@@ -53,8 +50,7 @@ void eval(array& arr) {
     encoder.add_temporary(s);
   }
 
-  if (encoder.get_num_ops() >=
-      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+  if (encoder.needs_commit()) {
     scheduler::notify_new_task(stream);
     encoder.add_completed_handler(
         [stream]() { scheduler::notify_task_completion(stream); });
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index f9c8b8052d..65cd32d30f 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -382,11 +382,8 @@ MTL::CommandQueue* Device::get_queue(Stream stream) {
 
 bool Device::command_buffer_needs_commit(int index) {
   auto&
       stream = get_stream_(index);
-  if (stream.buffer_ops > max_ops_per_buffer_ ||
-      (stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
-    return true;
-  }
-  return false;
+  return (stream.buffer_ops > max_ops_per_buffer_) ||
+      ((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
 }
 
 MTL::CommandBuffer* Device::get_command_buffer(int index) {