Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/dev/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ below.
auto kernel = d.get_kernel(kname, lib);

// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = mx::metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

// Kernel parameters are registered with buffer indices corresponding to
Expand Down Expand Up @@ -448,7 +448,7 @@ We can now call the :meth:`axpby` operation on both the CPU and the GPU!

A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
associated. We rely on :meth:`metal::get_command_encoder` to give us the active
Metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
Expand Down
2 changes: 1 addition & 1 deletion examples/extensions/axpby/axpby.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ void Axpby::eval_gpu(
auto kernel = d.get_kernel(kname, lib);

// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = mx::metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

// Kernel parameters are registered with buffer indices corresponding to
Expand Down
3 changes: 1 addition & 2 deletions mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ void fast::Quantize::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("Quantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto& enc = cu::get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
Expand Down
13 changes: 6 additions & 7 deletions mlx/backend/metal/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ void* Buffer::raw_ptr() {

namespace metal {

MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
MetalAllocator::MetalAllocator(Device& d)
: device_(d.mtl_device()),
residency_set_(d.residency_set()),
buffer_cache_(
vm_page_size,
[](MTL::Buffer* buf) { return buf->length(); },
Expand All @@ -42,8 +43,7 @@ MetalAllocator::MetalAllocator()
}
auto pool = metal::new_scoped_memory_pool();
buf->release();
}),
residency_set_(device_) {
}) {
const auto& info = gpu::device_info(0);
auto memsize = std::get<size_t>(info.at("memory_size"));
auto max_rec_size =
Expand All @@ -52,8 +52,6 @@ MetalAllocator::MetalAllocator()
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
bool is_vm = std::get<std::string>(info.at("device_name")) ==
"Apple Paravirtual device";
if (is_vm) {
Expand Down Expand Up @@ -226,7 +224,8 @@ MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static MetalAllocator* allocator_ = new MetalAllocator;
static MetalAllocator* allocator_ =
new MetalAllocator(device(mlx::core::Device::gpu));
return *allocator_;
}

Expand Down
5 changes: 2 additions & 3 deletions mlx/backend/metal/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"

namespace mlx::core::metal {

Expand Down Expand Up @@ -52,13 +51,13 @@ class MetalAllocator : public allocator::Allocator {
static constexpr int small_size_ = 256;
static constexpr int heap_size_ = 1 << 20;

MetalAllocator();
MetalAllocator(Device& d);
~MetalAllocator();

friend MetalAllocator& allocator();

NS::SharedPtr<MTL::Heap> heap_;
ResidencySet residency_set_;
ResidencySet& residency_set_;

// Caching allocator
BufferCache<MTL::Buffer> buffer_cache_;
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/metal/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void binary_op_gpu_inplace(
auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

int arg_idx = 0;
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/metal/compiled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ void Compiled::eval_gpu(
kernel_name += "_large";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

// Put the inputs in
Expand Down
26 changes: 13 additions & 13 deletions mlx/backend/metal/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
return x;
}
auto result = contiguous_copy_gpu(x, s);
d.add_temporary(result, s.index);
metal::get_command_encoder(s).add_temporary(result);
return result;
}

Expand All @@ -52,7 +52,7 @@ void explicit_gemm_conv_ND_gpu(
std::string kname;
kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -132,7 +132,7 @@ void explicit_gemm_conv_group_ND_gpu(
kname.reserve(32);
concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -286,7 +286,7 @@ void implicit_gemm_conv_2D_gpu(
small_filter ? 's' : 'l');

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = get_steel_conv_kernel(
d,
kname,
Expand Down Expand Up @@ -469,7 +469,7 @@ void implicit_gemm_conv_2D_general_gpu(
};

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = get_steel_conv_general_kernel(
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel);
Expand Down Expand Up @@ -595,7 +595,7 @@ void implicit_gemm_conv_3D_gpu(
small_filter ? 's' : 'l');

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel =
get_steel_conv_3d_kernel(d, kname, out, bm, bn, bk, wm, wn, small_filter);
compute_encoder.set_compute_pipeline_state(kernel);
Expand Down Expand Up @@ -644,7 +644,7 @@ void pad_and_slice_conv_3D_gpu(
array x_copy(xshape, x.dtype(), nullptr, {});
array zero(0, x.dtype());
pad_gpu(x, zero, x_copy, {0, -1}, {0, 0}, s);
d.add_temporary(x_copy, s.index);
metal::get_command_encoder(s).add_temporary(x_copy);

return x_copy;
};
Expand Down Expand Up @@ -804,7 +804,7 @@ void winograd_conv_2D_gpu(
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -837,7 +837,7 @@ void winograd_conv_2D_gpu(
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -889,7 +889,7 @@ void winograd_conv_2D_gpu(
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -950,7 +950,7 @@ void depthwise_conv_2D_gpu(
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on

auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -1044,7 +1044,7 @@ void depthwise_conv_1D_gpu(
type_to_name(out),
large ? "_large" : "");

auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
auto kernel = d.get_kernel(base_name);
compute_encoder.set_compute_pipeline_state(kernel);

Expand Down Expand Up @@ -1348,7 +1348,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {

// Record copies
if (!copies.empty()) {
d.add_temporaries(std::move(copies), s.index);
metal::get_command_encoder(s).add_temporaries(std::move(copies));
}
}

Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/metal/copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void copy_gpu_inplace(
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
: get_copy_kernel(d, kernel_name, in, out);

auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

inp_offset *= size_of(in.dtype());
Expand Down Expand Up @@ -190,7 +190,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(val, 0);
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/metal/custom_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ void CustomKernel::eval_gpu(

auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
Expand Down Expand Up @@ -424,7 +424,7 @@ void CustomKernel::eval_gpu(
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder.dispatch_threads(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
compute_encoder.add_temporaries(std::move(copies));
}

} // namespace mlx::core::fast
68 changes: 16 additions & 52 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ auto get_metal_version() {
}

NS::SharedPtr<MTL::Device> load_device() {
auto pool = new_scoped_memory_pool();
auto devices = NS::TransferPtr(MTL::CopyAllDevices());
auto device = NS::RetainPtr(static_cast<MTL::Device*>(devices->object(0)))
?: NS::TransferPtr(MTL::CreateSystemDefaultDevice());
Expand Down Expand Up @@ -249,16 +250,16 @@ MTL::Library* load_library(
CommandEncoder::CommandEncoder(
Device& d,
int index,
const MTL::ResidencySet* residency_set)
ResidencySet& residency_set)
: device_(d) {
auto pool = new_scoped_memory_pool();
queue_ = NS::TransferPtr(device_.mtl_device()->newCommandQueue());
if (!queue_) {
throw std::runtime_error(
"[metal::CommandEncoder] Failed to make new command queue.");
}
if (residency_set) {
queue_->addResidencySet(residency_set);
if (residency_set.mtl_residency_set()) {
queue_->addResidencySet(residency_set.mtl_residency_set());
}
debug_set_stream_queue_label(queue_.get(), index);
buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());
Expand Down Expand Up @@ -441,9 +442,8 @@ MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() {
return encoder_.get();
}

Device::Device() {
Device::Device() : device_(load_device()), residency_set_(device_.get()) {
auto pool = new_scoped_memory_pool();
device_ = load_device();
default_library_ = NS::TransferPtr(load_default_library(device_.get()));
arch_ = env::metal_gpu_arch();
if (arch_.empty()) {
Expand Down Expand Up @@ -487,38 +487,6 @@ Device::Device() {

Device::~Device() = default;

bool Device::command_buffer_needs_commit(int index) {
return get_command_encoder(index).needs_commit();
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
return get_command_encoder(index).get_command_buffer();
}

void Device::commit_command_buffer(int index) {
get_command_encoder(index).commit();
}

void Device::add_temporary(array arr, int index) {
get_command_encoder(index).add_temporary(std::move(arr));
}

void Device::add_temporaries(std::vector<array> arrays, int index) {
get_command_encoder(index).add_temporaries(std::move(arrays));
}

void Device::end_encoding(int index) {
get_command_encoder(index).end_encoding();
}

CommandEncoder& Device::get_command_encoder(int index) {
auto it = encoders_.find(index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(index, *this, index, residency_set_).first;
}
return it->second;
}

MTL::Library* Device::get_library(
const std::string& name,
const std::string& path /* = "" */) {
Expand Down Expand Up @@ -793,21 +761,6 @@ MTL::ComputePipelineState* Device::get_kernel(
linked_functions);
}

void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, encoder] : encoders_) {
encoder.get_command_queue()->addResidencySet(residency_set_);
}
}

Device& device(mlx::core::Device) {
// Leak singleton device intentionally, to avoid cases where a compute kernel
// returns and tries to access the object after it has been freed by the main
Expand All @@ -816,6 +769,17 @@ Device& device(mlx::core::Device) {
return *metal_device;
}

CommandEncoder& get_command_encoder(Stream s) {
  // Return the command encoder associated with the stream's index, creating
  // and caching one on first use.
  // The encoder map is intentionally leaked for the same reason as the device:
  // it avoids destruction-order problems when GPU work outlives main().
  // NOTE(review): the static map is not synchronized — presumably access is
  // serialized per stream by the scheduler; confirm before calling from
  // multiple threads concurrently.
  static auto* encoders = new std::unordered_map<int, CommandEncoder>;
  auto it = encoders->find(s.index);
  if (it == encoders->end()) {
    auto& d = device(s.device);
    // Encoder is keyed by the stream index and shares the device's
    // residency set so cached buffers stay resident for its command queue.
    it = encoders->try_emplace(s.index, d, s.index, d.residency_set()).first;
  }
  return it->second;
}

NS::SharedPtr<NS::AutoreleasePool> new_scoped_memory_pool() {
  // Allocate a fresh autorelease pool and hand ownership to a SharedPtr so
  // the pool is drained when the returned handle goes out of scope.
  auto* pool = NS::AutoreleasePool::alloc()->init();
  return NS::TransferPtr(pool);
}
Expand Down
Loading
Loading