From abc43c5ddc936be34960ce4b158af2ef05167cc6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 16 Mar 2026 20:27:43 +0900 Subject: [PATCH] Manage Metal objects with smart pointers --- mlx/backend/metal/allocator.cpp | 22 ++--- mlx/backend/metal/allocator.h | 8 +- mlx/backend/metal/device.cpp | 154 ++++++++++++++++---------------- mlx/backend/metal/device.h | 54 ++++------- mlx/backend/metal/metal.cpp | 3 +- mlx/backend/metal/resident.cpp | 12 +-- mlx/backend/metal/resident.h | 4 +- 7 files changed, 108 insertions(+), 149 deletions(-) diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index c15bf3bdf7..3f9d9b197d 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -40,10 +40,10 @@ MetalAllocator::MetalAllocator() if (!buf->heap()) { residency_set_.erase(buf); } + auto pool = metal::new_scoped_memory_pool(); buf->release(); }), residency_set_(device_) { - auto pool = metal::new_scoped_memory_pool(); const auto& info = gpu::device_info(0); auto memsize = std::get(info.at("memory_size")); auto max_rec_size = @@ -59,21 +59,15 @@ MetalAllocator::MetalAllocator() if (is_vm) { return; } - auto heap_desc = MTL::HeapDescriptor::alloc()->init(); + auto pool = metal::new_scoped_memory_pool(); + auto heap_desc = MTL::HeapDescriptor::alloc()->init()->autorelease(); heap_desc->setResourceOptions(resource_options); heap_desc->setSize(heap_size_); - heap_ = device_->newHeap(heap_desc); - heap_desc->release(); - residency_set_.insert(heap_); + heap_ = NS::TransferPtr(device_->newHeap(heap_desc)); + residency_set_.insert(heap_.get()); } -MetalAllocator::~MetalAllocator() { - auto pool = metal::new_scoped_memory_pool(); - if (heap_) { - heap_->release(); - } - buffer_cache_.clear(); -} +MetalAllocator::~MetalAllocator() = default; size_t MetalAllocator::set_cache_limit(size_t limit) { std::unique_lock lk(mutex_); @@ -128,8 +122,6 @@ Buffer MetalAllocator::malloc(size_t size) { if (!buf) { size_t mem_required = get_active_memory() + get_cache_memory() + size; - auto pool = metal::new_scoped_memory_pool(); - // If we have a lot of memory pressure try to reclaim memory from the cache if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) { num_resources_ -= @@ -167,7 +159,6 @@ Buffer MetalAllocator::malloc(size_t size) { // Maintain the cache below the requested limit if (get_cache_memory() > max_pool_size_) { - auto pool = metal::new_scoped_memory_pool(); num_resources_ -= buffer_cache_.release_cached_buffers( get_cache_memory() - max_pool_size_); } @@ -177,7 +168,6 @@ Buffer MetalAllocator::malloc(size_t size) { void MetalAllocator::clear_cache() { std::unique_lock lk(mutex_); - auto pool = metal::new_scoped_memory_pool(); num_resources_ -= buffer_cache_.clear(); } diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 5e177b3d3e..49e09d7f6b 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -51,16 +51,18 @@ class MetalAllocator : public allocator::Allocator { // the heap, a heap can have at most heap.size() / 256 buffers. static constexpr int small_size_ = 256; static constexpr int heap_size_ = 1 << 20; - MTL::Heap* heap_; + MetalAllocator(); ~MetalAllocator(); + friend MetalAllocator& allocator(); + NS::SharedPtr heap_; + ResidencySet residency_set_; + // Caching allocator BufferCache buffer_cache_; - ResidencySet residency_set_; - // Allocation stats size_t block_limit_; size_t gc_limit_; diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index aad42f4471..bdf2df3936 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -45,15 +45,16 @@ auto get_metal_version() { return metal_version_; } -auto load_device() { - auto devices = MTL::CopyAllDevices(); - auto device = static_cast(devices->object(0)) - ?: MTL::CreateSystemDefaultDevice(); +NS::SharedPtr load_device() { + auto devices = NS::TransferPtr(MTL::CopyAllDevices()); + auto device = NS::RetainPtr(static_cast(devices->object(0))) + ?: NS::TransferPtr(MTL::CreateSystemDefaultDevice()); if (!device) { throw std::runtime_error("Failed to load device"); } return device; } + std::pair load_library_from_path( MTL::Device* device, const char* path) { @@ -443,7 +444,7 @@ MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); - default_library_ = load_default_library(device_); + default_library_ = NS::TransferPtr(load_default_library(device_.get())); arch_ = env::metal_gpu_arch(); if (arch_.empty()) { arch_ = std::string(device_->architecture()->name()->utf8String()); @@ -484,17 +485,7 @@ Device::Device() { max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_); } -Device::~Device() { - auto pool = new_scoped_memory_pool(); - for (auto& [l, kernel_map] : library_kernels_) { - l->release(); - for (auto& [_, k] : kernel_map) { - k->release(); - } - } - encoders_.clear(); - device_->release(); -} +Device::~Device() = default; bool Device::command_buffer_needs_commit(int index) { return get_command_encoder(index).needs_commit(); @@ -534,28 +525,29 @@ MTL::Library* Device::get_library( { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { - return it->second; + return it->second.get(); } } std::unique_lock wlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { - return it->second; + return it->second.get(); } - auto new_lib = load_library(device_, name, path.c_str()); - library_map_.insert({name, new_lib}); + auto new_lib = load_library(device_.get(), name, path.c_str()); + library_map_.insert({name, NS::TransferPtr(new_lib)}); return new_lib; } -MTL::Library* Device::build_library_(const std::string& source_string) { +NS::SharedPtr Device::build_library_( + const std::string& source_string) { auto pool = new_scoped_memory_pool(); auto ns_code = NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding); NS::Error* error = nullptr; - auto options = MTL::CompileOptions::alloc()->init(); + auto options = MTL::CompileOptions::alloc()->init()->autorelease(); options->setFastMathEnabled(false); options->setLanguageVersion(get_metal_version()); #ifndef NDEBUG @@ -563,8 +555,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) { options->setEnableLogging(true); } #endif - auto mtl_lib = device_->newLibrary(ns_code, options, &error); - options->release(); + auto mtl_lib = NS::TransferPtr(device_->newLibrary(ns_code, options, &error)); // Throw error if unable to compile library if (!mtl_lib) { @@ -579,17 +570,16 @@ MTL::Library* Device::build_library_(const std::string& source_string) { return mtl_lib; } -MTL::Function* Device::get_function_( +NS::SharedPtr Device::get_function_( const std::string& name, MTL::Library* mtl_lib) { + auto pool = new_scoped_memory_pool(); // Pull kernel from library auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding); - auto mtl_function = mtl_lib->newFunction(ns_name); - - return mtl_function; + return NS::TransferPtr(mtl_lib->newFunction(ns_name)); } -MTL::Function* Device::get_function_( +NS::SharedPtr Device::get_function_( const std::string& name, const std::string& specialized_name, const MTLFCList& func_consts, @@ -598,8 +588,11 @@ MTL::Function* Device::get_function_( return get_function_(name, mtl_lib); } + auto pool = new_scoped_memory_pool(); + // Prepare function constants - auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init(); + auto mtl_func_consts = + MTL::FunctionConstantValues::alloc()->init()->autorelease(); for (auto [value, type, index] : func_consts) { mtl_func_consts->setConstantValue(value, type, index); @@ -614,7 +607,7 @@ MTL::Function* Device::get_function_( // Pull kernel from library NS::Error* error = nullptr; - auto mtl_function = mtl_lib->newFunction(desc, &error); + auto mtl_function = NS::TransferPtr(mtl_lib->newFunction(desc, &error)); // Throw error if unable to build metal function if (!mtl_function) { @@ -626,20 +619,19 @@ MTL::Function* Device::get_function_( throw std::runtime_error(msg.str()); } - mtl_func_consts->release(); - return mtl_function; } -MTL::ComputePipelineState* Device::get_kernel_( +NS::SharedPtr Device::get_kernel_( const std::string& name, const MTL::Function* mtl_function) { // Compile kernel to compute pipeline NS::Error* error = nullptr; - MTL::ComputePipelineState* kernel; + NS::SharedPtr kernel; if (mtl_function) { - kernel = device_->newComputePipelineState(mtl_function, &error); + kernel = + NS::TransferPtr(device_->newComputePipelineState(mtl_function, &error)); } // Throw error if unable to compile metal function @@ -655,7 +647,7 @@ MTL::ComputePipelineState* Device::get_kernel_( return kernel; } -MTL::ComputePipelineState* Device::get_kernel_( +NS::SharedPtr Device::get_kernel_( const std::string& name, const MTL::Function* mtl_function, const MTL::LinkedFunctions* linked_functions) { @@ -670,15 +662,17 @@ MTL::ComputePipelineState* Device::get_kernel_( throw std::runtime_error(msg.str()); } + auto pool = new_scoped_memory_pool(); + // Prepare compute pipeline state descriptor - auto desc = MTL::ComputePipelineDescriptor::alloc()->init(); + auto desc = MTL::ComputePipelineDescriptor::alloc()->init()->autorelease(); desc->setComputeFunction(mtl_function); desc->setLinkedFunctions(linked_functions); // Compile kernel to compute pipeline NS::Error* error = nullptr; - auto kernel = device_->newComputePipelineState( - desc, MTL::PipelineOptionNone, nullptr, &error); + auto kernel = NS::TransferPtr(device_->newComputePipelineState( + desc, MTL::PipelineOptionNone, nullptr, &error)); // Throw error if unable to compile metal function if (!kernel) { @@ -693,62 +687,45 @@ MTL::ComputePipelineState* Device::get_kernel_( return kernel; } -MTL::Library* Device::get_library_(const std::string& name) { - std::shared_lock lock(library_mtx_); - auto it = library_map_.find(name); - return (it != library_map_.end()) ? it->second : nullptr; -} - MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { - return it->second; + return it->second.get(); } } std::unique_lock wlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { - return it->second; + return it->second.get(); } auto mtl_lib = build_library_(builder()); library_map_.insert({name, mtl_lib}); - return mtl_lib; + return mtl_lib.get(); } void Device::clear_library(const std::string& name) { std::unique_lock wlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { - auto kernel_map_it = library_kernels_.find(it->second); - for (auto& [_, kernel] : kernel_map_it->second) { - kernel->release(); - } - library_kernels_.erase(kernel_map_it); - it->second->release(); + library_kernels_.erase(it->second.get()); library_map_.erase(it); } } -MTL::LinkedFunctions* Device::get_linked_functions_( +NS::SharedPtr Device::get_linked_functions_( const std::vector& funcs) { if (funcs.empty()) { return nullptr; } - auto lfuncs = MTL::LinkedFunctions::linkedFunctions(); - - std::vector objs(funcs.size()); - for (int i = 0; i < funcs.size(); i++) { - objs[i] = funcs[i]; - } - - NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size()); - + auto pool = new_scoped_memory_pool(); + auto lfuncs = NS::TransferPtr(MTL::LinkedFunctions::linkedFunctions()); + NS::Array* funcs_arr = NS::Array::array( + reinterpret_cast(funcs.data()), funcs.size()); lfuncs->setPrivateFunctions(funcs_arr); - return lfuncs; } @@ -764,7 +741,7 @@ MTL::ComputePipelineState* Device::get_kernel_( // Try loading again to avoid loading twice auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) { - return it->second; + return it->second.get(); } auto pool = new_scoped_memory_pool(); @@ -774,15 +751,13 @@ MTL::ComputePipelineState* Device::get_kernel_( // Compile kernel to compute pipeline auto mtl_linked_funcs = get_linked_functions_(linked_functions); - auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs); - - mtl_function->release(); - mtl_linked_funcs->release(); + auto kernel = + get_kernel_(hash_name, mtl_function.get(), mtl_linked_funcs.get()); // Add kernel to cache kernel_map_.insert({hash_name, kernel}); - return kernel; + return kernel.get(); } MTL::ComputePipelineState* Device::get_kernel( @@ -799,7 +774,7 @@ MTL::ComputePipelineState* Device::get_kernel( // Look for cached kernel auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { - return it->second; + return it->second.get(); } } return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions); @@ -811,7 +786,11 @@ MTL::ComputePipelineState* Device::get_kernel( const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* = {} */) { return get_kernel( - base_name, default_library_, hash_name, func_consts, linked_functions); + base_name, + default_library_.get(), + hash_name, + func_consts, + linked_functions); } void Device::set_residency_set(const MTL::ResidencySet* residency_set) { @@ -837,12 +816,29 @@ Device& device(mlx::core::Device) { return *metal_device; } -std::unique_ptr> new_scoped_memory_pool() { - auto dtor = [](void* ptr) { - static_cast(ptr)->release(); +NS::SharedPtr new_scoped_memory_pool() { + return NS::TransferPtr(NS::AutoreleasePool::alloc()->init()); +} + +bool is_nax_available() { +#ifdef MLX_METAL_NO_NAX + return false; +#else + auto _check_nax = []() { + bool can_use_nax = false; + if (__builtin_available( + macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + can_use_nax = true; + } + auto& d = metal::device(mlx::core::Device::gpu); + auto arch = d.get_architecture().back(); + auto gen = d.get_architecture_gen(); + can_use_nax &= gen >= (arch == 'p' ? 18 : 17); + return can_use_nax; }; - return std::unique_ptr>( - NS::AutoreleasePool::alloc()->init(), dtor); + static bool is_nax_available_ = _check_nax(); + return is_nax_available_; +#endif } } // namespace mlx::core::metal diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 0e705a6560..4889887e2e 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -135,7 +135,7 @@ class MLX_API Device { ~Device(); MTL::Device* mtl_device() { - return device_; + return device_.get(); }; const std::string& get_architecture() const { @@ -184,27 +184,24 @@ class MLX_API Device { void set_residency_set(const MTL::ResidencySet* residency_set); private: - MTL::Library* get_library_cache_(const std::string& name); + NS::SharedPtr build_library_(const std::string& source_string); - MTL::Library* get_library_(const std::string& name); - MTL::Library* build_library_(const std::string& source_string); - - MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib); - - MTL::Function* get_function_( + NS::SharedPtr get_function_( + const std::string& name, + MTL::Library* mtl_lib); + NS::SharedPtr get_function_( const std::string& name, const std::string& specialized_name, const MTLFCList& func_consts, MTL::Library* mtl_lib); - MTL::LinkedFunctions* get_linked_functions_( + NS::SharedPtr get_linked_functions_( const std::vector& funcs); - MTL::ComputePipelineState* get_kernel_( + NS::SharedPtr get_kernel_( const std::string& name, const MTL::Function* mtl_function); - - MTL::ComputePipelineState* get_kernel_( + NS::SharedPtr get_kernel_( const std::string& name, const MTL::Function* mtl_function, const MTL::LinkedFunctions* linked_functions); @@ -216,16 +213,16 @@ class MLX_API Device { const MTLFCList& func_consts = {}, const std::vector& linked_functions = {}); - MTL::Device* device_; + NS::SharedPtr device_; std::unordered_map encoders_; std::shared_mutex kernel_mtx_; std::shared_mutex library_mtx_; - std::unordered_map library_map_; - MTL::Library* default_library_; + std::unordered_map> library_map_; + NS::SharedPtr default_library_; std::unordered_map< MTL::Library*, - std::unordered_map> + std::unordered_map>> library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; @@ -236,27 +233,8 @@ class MLX_API Device { MLX_API Device& device(mlx::core::Device); -std::unique_ptr> new_scoped_memory_pool(); - -inline bool is_nax_available() { -#ifdef MLX_METAL_NO_NAX - return false; -#else - auto _check_nax = []() { - bool can_use_nax = false; - if (__builtin_available( - macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - can_use_nax = true; - } - auto& d = metal::device(mlx::core::Device::gpu); - auto arch = d.get_architecture().back(); - auto gen = d.get_architecture_gen(); - can_use_nax &= gen >= (arch == 'p' ? 18 : 17); - return can_use_nax; - }; - static bool is_nax_available_ = _check_nax(); - return is_nax_available_; -#endif -} +NS::SharedPtr new_scoped_memory_pool(); + +bool is_nax_available(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 51bd2e62ff..6bf3b895b9 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -14,7 +14,7 @@ bool is_available() { void start_capture(std::string path, NS::Object* object) { auto pool = new_scoped_memory_pool(); - auto descriptor = MTL::CaptureDescriptor::alloc()->init(); + auto descriptor = MTL::CaptureDescriptor::alloc()->init()->autorelease(); descriptor->setCaptureObject(object); if (!path.empty()) { @@ -27,7 +27,6 @@ void start_capture(std::string path, NS::Object* object) { auto manager = MTL::CaptureManager::sharedCaptureManager(); NS::Error* error; bool started = manager->startCapture(descriptor, &error); - descriptor->release(); if (!started) { std::ostringstream msg; msg << "[metal::start_capture] Failed to start: " diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 798824c2fb..97187b05b9 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -9,10 +9,9 @@ ResidencySet::ResidencySet(MTL::Device* d) { return; } else if (__builtin_available(macOS 15, iOS 18, *)) { auto pool = new_scoped_memory_pool(); - auto desc = MTL::ResidencySetDescriptor::alloc()->init(); + auto desc = MTL::ResidencySetDescriptor::alloc()->init()->autorelease(); NS::Error* error; - wired_set_ = d->newResidencySet(desc, &error); - desc->release(); + wired_set_ = NS::TransferPtr(d->newResidencySet(desc, &error)); if (!wired_set_) { std::ostringstream msg; msg << "[metal::Device] Unable to construct residency set.\n"; @@ -90,11 +89,6 @@ void ResidencySet::resize(size_t size) { } } -ResidencySet::~ResidencySet() { - if (wired_set_) { - auto pool = new_scoped_memory_pool(); - wired_set_->release(); - } -} +ResidencySet::~ResidencySet() = default; } // namespace mlx::core::metal diff --git a/mlx/backend/metal/resident.h b/mlx/backend/metal/resident.h index 5db5582863..9961d722d7 100644 --- a/mlx/backend/metal/resident.h +++ b/mlx/backend/metal/resident.h @@ -15,7 +15,7 @@ class ResidencySet { ResidencySet& operator=(const ResidencySet&) = delete; const MTL::ResidencySet* mtl_residency_set() { - return wired_set_; + return wired_set_.get(); } void insert(MTL::Allocation* buf); @@ -24,7 +24,7 @@ class ResidencySet { void resize(size_t size); private: - MTL::ResidencySet* wired_set_{nullptr}; + NS::SharedPtr wired_set_; std::unordered_set unwired_set_; size_t capacity_{0}; };