diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 9585167463..226e3306c3 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -57,7 +57,7 @@ std::string build_kernel( const std::vector& output_names, const std::vector& output_dtypes, const std::vector>& template_args, - const std::vector& shape_infos) { + const std::vector>& shape_infos) { std::string kernel_source; kernel_source.reserve(header.size() + source.size() + 8192); kernel_source += default_header; @@ -81,17 +81,17 @@ std::string build_kernel( kernel_source += ",\n"; // Add input shape, strides and ndim if present in the source if (arr.ndim() > 0) { - if (shape_infos[i].shape) { + if (std::get<0>(shape_infos[i])) { kernel_source += " const __grid_constant__ Shape "; kernel_source += name; kernel_source += "_shape,\n"; } - if (shape_infos[i].strides) { + if (std::get<1>(shape_infos[i])) { kernel_source += " const __grid_constant__ Strides "; kernel_source += name; kernel_source += "_strides,\n"; } - if (shape_infos[i].ndim) { + if (std::get<2>(shape_infos[i])) { kernel_source += " const __grid_constant__ int "; kernel_source += name; kernel_source += "_ndim,\n"; @@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel( "[custom_kernel] Must specify at least one output."); } - std::vector shape_infos; + std::vector> shape_infos; for (auto& n : input_names) { - CustomKernelShapeInfo shape_info; - shape_info.shape = source.find(n + "_shape") != std::string::npos; - shape_info.strides = source.find(n + "_strides") != std::string::npos; - shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + std::tuple shape_info; + std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos; + std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos; + std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos; shape_infos.push_back(shape_info); } @@ -254,8 +254,8 @@ std::vector 
precompiled_cuda_kernel( std::optional init_value, bool ensure_row_contiguous, StreamOrDevice s) { - std::vector shape_infos( - inputs.size(), CustomKernelShapeInfo{false, false, false}); + std::vector> shape_infos( + inputs.size(), {false, false, false}); return array::make_arrays( output_shapes, output_dtypes, @@ -327,13 +327,13 @@ void CustomKernel::eval_gpu( const array& in = checked_inputs[i]; auto& shape_info = shape_infos_[i]; args.append(in); - if (shape_info.shape) { + if (std::get<0>(shape_info)) { args.append_ndim(in.shape()); } - if (shape_info.strides) { + if (std::get<1>(shape_info)) { args.append_ndim(in.strides()); } - if (shape_info.ndim) { + if (std::get<2>(shape_info)) { args.append(in.ndim()); } } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index c48b93c913..eec6645bdd 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -32,7 +32,7 @@ std::string write_signature( const std::vector& output_dtypes, const std::vector>& template_args, const std::vector& attributes, - const std::vector& shape_infos, + const std::vector>& shape_infos, bool atomic_outputs) { std::string kernel_source; kernel_source.reserve(header.size() + source.size() + 16384); @@ -88,19 +88,19 @@ std::string write_signature( index++; // Add input shape, strides and ndim if present in the source if (arr.ndim() > 0) { - if (shape_infos[i].shape) { + if (std::get<0>(shape_infos[i])) { kernel_source += (" const constant int* " + name + "_shape [[buffer(" + std::to_string(index) + ")]],\n"); index++; } - if (shape_infos[i].strides) { + if (std::get<1>(shape_infos[i])) { kernel_source += (" const constant int64_t* " + name + "_strides [[buffer(" + std::to_string(index) + ")]],\n"); index++; } - if (shape_infos[i].ndim) { + if (std::get<2>(shape_infos[i])) { kernel_source += (" const constant int& " + name + "_ndim [[buffer(" + std::to_string(index) + ")]],\n"); @@ -184,12 +184,12 @@ CustomKernelFunction 
metal_kernel( throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); } - std::vector shape_infos; + std::vector> shape_infos; for (auto& n : input_names) { - CustomKernelShapeInfo shape_info; - shape_info.shape = source.find(n + "_shape") != std::string::npos; - shape_info.strides = source.find(n + "_strides") != std::string::npos; - shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + std::tuple shape_info; + std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos; + std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos; + std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos; shape_infos.push_back(shape_info); } const std::vector> metal_attributes = { @@ -388,15 +388,15 @@ void CustomKernel::eval_gpu( index++; if (in.ndim() > 0) { int ndim = in.ndim(); - if (shape_info.shape) { + if (std::get<0>(shape_info)) { compute_encoder.set_vector_bytes(in.shape(), ndim, index); index++; } - if (shape_info.strides) { + if (std::get<1>(shape_info)) { compute_encoder.set_vector_bytes(in.strides(), ndim, index); index++; } - if (shape_info.ndim) { + if (std::get<2>(shape_info)) { compute_encoder.set_bytes(ndim, index); index++; } diff --git a/mlx/export.cpp b/mlx/export.cpp index 3448178e24..972febc2b2 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -75,6 +75,14 @@ constexpr bool is_pair = is_specialization_of>; template constexpr bool is_tuple = is_specialization_of>; +template +inline constexpr bool is_optional = + is_specialization_of>; + +template +inline constexpr bool is_variant = + is_specialization_of>; + template constexpr bool dependent_false = false; @@ -96,6 +104,12 @@ void reverse_bytes(T& data) { } } +template +void serialize_variant(Writer& os, T v); + +template +T deserialize_variant(Reader& is); + template void serialize(Writer& os, T v) { if constexpr (std::is_arithmetic_v) { @@ -113,6 +127,13 @@ void serialize(Writer& os, T v) { } } else if constexpr (is_pair 
|| is_tuple) { std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v); + } else if constexpr (is_variant) { + serialize_variant(os, v); + } else if constexpr (is_optional) { + serialize(os, v.has_value()); + if (v.has_value()) { + serialize(os, *v); + } } else { NotSerializable(); } } @@ -145,11 +166,58 @@ T deserialize(Reader& is) { } else if constexpr (is_pair || is_tuple) { return deserialize_tuple( is, std::make_index_sequence>>{}); + } else if constexpr (is_optional) { + auto has_value = deserialize(is); + if (has_value) { + return deserialize(is); + } else { + return std::nullopt; + } + } else if constexpr (is_variant) { + return deserialize_variant(is); } else { NotDeserializable(); } } +enum class VariantType { Int = 0, Float = 1, Bool = 2 }; + +template +void serialize_variant(Writer& os, T v) { + std::visit( + [&](auto&& x) { + using ElemT = std::decay_t; + if constexpr (std::is_same_v) { + serialize(os, VariantType::Int); + } else if constexpr (std::is_same_v) { + serialize(os, VariantType::Float); + } else if constexpr (std::is_same_v) { + serialize(os, VariantType::Bool); + } else { + static_assert( + std::is_same_v, "Can't serialize variant type."); + } + serialize(os, x); + }, + v); +} + +template +T deserialize_variant(Reader& is) { + auto vt = deserialize(is); + switch (vt) { + case VariantType::Int: + return deserialize(is); + case VariantType::Float: + return deserialize(is); + case VariantType::Bool: + return deserialize(is); + default: + throw std::runtime_error( + "[deserialize_variant] Unknown variant type tag."); + } +} + template decltype(auto) deserialize_tuple(Reader& is, std::index_sequence) { return T{deserialize>(is)...}; } @@ -374,7 +442,8 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(LayerNorm), SERIALIZE_PRIMITIVE(LayerNormVJP), SERIALIZE_PRIMITIVE(RoPE), - SERIALIZE_PRIMITIVE(ScaledDotProductAttention)}; + SERIALIZE_PRIMITIVE(ScaledDotProductAttention), + SERIALIZE_PRIMITIVE(CustomKernel)}; std::unordered_map name_remap;
PrimitiveFactory() { diff --git a/mlx/export.h b/mlx/export.h index 715dac2c87..0a8e9fb0cc 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include #include @@ -24,6 +25,9 @@ using StateT = std::variant< Strides, std::vector, std::vector, + std::vector>, + std::vector>, + std::optional, std::string>; using ExportCallbackInput = std::unordered_map< diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 649e554e67..584860e526 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -315,12 +315,6 @@ class Quantize : public Custom { bool dequantize_; }; -struct CustomKernelShapeInfo { - bool shape = false; - bool strides = false; - bool ndim = false; -}; - using ScalarArg = std::variant; class CustomKernel : public Primitive { @@ -331,15 +325,15 @@ class CustomKernel : public Primitive { std::string source, std::tuple grid, std::tuple threadgroup, - std::vector shape_infos, + std::vector> shape_infos, bool ensure_row_contiguous, std::optional init_value, std::vector scalar_arguments, bool is_precompiled, int shared_memory) : Primitive(stream), - source_(std::move(source)), name_(std::move(name)), + source_(std::move(source)), grid_(grid), threadgroup_(threadgroup), shape_infos_(std::move(shape_infos)), @@ -358,13 +352,26 @@ class CustomKernel : public Primitive { override; DEFINE_NAME(CustomKernel); + auto state() const { + return std::make_tuple( + name_, + source_, + grid_, + threadgroup_, + shape_infos_, + ensure_row_contiguous_, + init_value_, + scalar_arguments_, + is_precompiled_, + shared_memory_); + } private: - std::string source_; std::string name_; + std::string source_; std::tuple grid_; std::tuple threadgroup_; - std::vector shape_infos_; + std::vector> shape_infos_; bool ensure_row_contiguous_; std::optional init_value_; std::vector scalar_arguments_; diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 4a4ca82a56..1aa251b6d9 100644 --- 
a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -531,6 +531,50 @@ def callback(args): self.assertEqual(keywords[0][0], "y") self.assertEqual(primitives, ["Subtract", "Abs", "Log"]) + @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available") + def test_export_import_custom_kernel(self): + if mx.metal.is_available(): + source = """ + uint elem = thread_position_in_grid.x; + out1[elem] = a[elem]; + """ + custom_kernel = mx.fast.metal_kernel + elif mx.cuda.is_available(): + source = """ + auto elem = cooperative_groups::this_grid().thread_rank(); + out1[elem] = a[elem]; + """ + custom_kernel = mx.fast.cuda_kernel + + kernel = custom_kernel( + name="basic", + input_names=["a"], + output_names=["out1"], + source=source, + ) + + def call(a): + return kernel( + inputs=[a], + grid=(4, 1, 1), + threadgroup=(2, 1, 1), + output_shapes=[(2, 2)], + output_dtypes=[mx.float32], + stream=mx.gpu, + )[0] + + mx.random.seed(7) + a = mx.random.normal(shape=(2, 2)) + + path = os.path.join(self.test_dir, "fn.mlxfn") + expected = call(a) + mx.export_function(path, call, a) + + imported = mx.import_function(path) + + out = imported(a)[0] + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": mlx_tests.MLXTestRunner()