28 changes: 14 additions & 14 deletions mlx/backend/cuda/custom_kernel.cpp
@@ -57,7 +57,7 @@ std::string build_kernel(
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
"[custom_kernel] Must specify at least one output.");
}

std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}

@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
std::vector<std::tuple<bool, bool, bool>> shape_infos(
inputs.size(), {false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
args.append<int32_t>(in.ndim());
}
}
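Note on this change: CustomKernelShapeInfo was a plain three-bool struct that the serializer in mlx/export.cpp has no overload for, while std::tuple<bool, bool, bool> flows through the existing tuple path unchanged. The three std::get<N> assignments could also be collapsed into a single construction under the same <name>_shape / <name>_strides / <name>_ndim convention — a sketch, not part of the diff:

for (auto& n : input_names) {
  // One flag per suffix: true when the kernel source references it.
  shape_infos.emplace_back(
      source.find(n + "_shape") != std::string::npos,
      source.find(n + "_strides") != std::string::npos,
      source.find(n + "_ndim") != std::string::npos);
}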
24 changes: 12 additions & 12 deletions mlx/backend/metal/custom_kernel.cpp
@@ -32,7 +32,7 @@ std::string write_signature(
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
const std::vector<std::tuple<bool, bool, bool>>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
@@ -88,19 +88,19 @@ std::string write_signature(
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
@@ -184,12 +184,12 @@ CustomKernelFunction metal_kernel(
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
@@ -388,15 +388,15 @@ void CustomKernel::eval_gpu(
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
compute_encoder.set_bytes(ndim, index);
index++;
}
71 changes: 70 additions & 1 deletion mlx/export.cpp
@@ -75,6 +75,14 @@ constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;
template <typename T>
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;

template <typename T>
inline constexpr bool is_optional =
is_specialization_of<std::optional, std::decay_t<T>>;

template <typename T>
inline constexpr bool is_variant =
is_specialization_of<std::variant, std::decay_t<T>>;

template <typename>
constexpr bool dependent_false = false;

@@ -96,6 +104,12 @@ void reverse_bytes(T& data) {
}
}

template <typename T>
void serialize_variant(Writer& os, T v);

template <typename T>
T deserialize_variant(Reader& is);

template <typename T>
void serialize(Writer& os, T v) {
if constexpr (std::is_arithmetic_v<T>) {
@@ -113,6 +127,13 @@ void serialize(Writer& os, T v) {
}
} else if constexpr (is_pair<T> || is_tuple<T>) {
std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
} else if constexpr (is_variant<T>) {
serialize_variant(os, v);
} else if constexpr (is_optional<T>) {
serialize(os, v.has_value());
if (v.has_value()) {
serialize(os, *v);
}
} else {
NotSerializable<T>();
}
@@ -145,11 +166,58 @@ T deserialize(Reader& is) {
} else if constexpr (is_pair<T> || is_tuple<T>) {
return deserialize_tuple<T>(
is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
} else if constexpr (is_optional<T>) {
auto has_value = deserialize<bool>(is);
if (has_value) {
return deserialize<typename T::value_type>(is);
} else {
return std::nullopt;
}
} else if constexpr (is_variant<T>) {
return deserialize_variant<T>(is);
} else {
NotDeserializable<T>();
}
}

enum class VariantType { Int = 0, Float = 1, Bool = 2 };

template <typename T>
void serialize_variant(Writer& os, T v) {
std::visit(
[&](auto&& x) {
using ElemT = std::decay_t<decltype(x)>;
if constexpr (std::is_same_v<ElemT, int>) {
serialize(os, VariantType::Int);
} else if constexpr (std::is_same_v<ElemT, float>) {
serialize(os, VariantType::Float);
} else if constexpr (std::is_same_v<ElemT, bool>) {
serialize(os, VariantType::Bool);
} else {
static_assert(
std::is_same_v<ElemT, void>, "Can't serialize variant type.");
}
serialize(os, x);
},
v);
}

template <typename T>
T deserialize_variant(Reader& is) {
auto vt = deserialize<VariantType>(is);
switch (vt) {
case VariantType::Int:
return deserialize<int>(is);
case VariantType::Float:
return deserialize<float>(is);
case VariantType::Bool:
return deserialize<bool>(is);
default:
throw std::runtime_error(
"[deserialize_variant] Unknonw variant type tag.");
}
}

template <typename T, std::size_t... I>
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
return T{deserialize<std::tuple_element_t<I, T>>(is)...};
@@ -374,7 +442,8 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(LayerNorm),
SERIALIZE_PRIMITIVE(LayerNormVJP),
SERIALIZE_PRIMITIVE(RoPE),
SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
SERIALIZE_PRIMITIVE(CustomKernel)};
std::unordered_map<std::string, std::string> name_remap;

PrimitiveFactory() {
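Note: the scheme here is tag-then-payload — serialize_variant writes a VariantType tag naming the active alternative, then the value; deserialize_variant reads the tag first so it knows which deserialize<T> to dispatch to. A self-contained sketch of the same idea, with std::stringstream standing in for MLX's Writer/Reader (an assumption made purely for illustration):

#include <cassert>
#include <cstdint>
#include <sstream>
#include <variant>

using Scalar = std::variant<int, float, bool>;

void write_scalar(std::ostream& os, const Scalar& v) {
  std::visit(
      [&](auto x) {
        // Tag first (0 = int, 1 = float, 2 = bool), then the raw payload.
        auto tag = static_cast<uint8_t>(v.index());
        os.write(reinterpret_cast<const char*>(&tag), sizeof(tag));
        os.write(reinterpret_cast<const char*>(&x), sizeof(x));
      },
      v);
}

Scalar read_scalar(std::istream& is) {
  uint8_t tag;
  is.read(reinterpret_cast<char*>(&tag), sizeof(tag));
  // The tag tells us which alternative's bytes follow.
  if (tag == 0) { int x; is.read(reinterpret_cast<char*>(&x), sizeof(x)); return x; }
  if (tag == 1) { float x; is.read(reinterpret_cast<char*>(&x), sizeof(x)); return x; }
  bool x;
  is.read(reinterpret_cast<char*>(&x), sizeof(x));
  return x;
}

int main() {
  std::stringstream ss;
  write_scalar(ss, Scalar{2.5f});
  assert(std::get<float>(read_scalar(ss)) == 2.5f); // exact binary round-trip
}

The sketch uses v.index() as the tag for brevity; the diff's explicit VariantType constants are more robust, since the on-disk format stays stable even if the variant's alternatives are ever reordered.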
4 changes: 4 additions & 0 deletions mlx/export.h
@@ -2,6 +2,7 @@

#pragma once

#include <optional>
#include <set>
#include <unordered_map>
#include <variant>
@@ -24,6 +25,9 @@ using StateT = std::variant<
Strides,
std::vector<int>,
std::vector<size_t>,
std::vector<std::tuple<bool, bool, bool>>,
std::vector<std::variant<bool, int, float>>,
std::optional<float>,
std::string>;

using ExportCallbackInput = std::unordered_map<
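Note: the three new StateT alternatives map one-to-one onto the CustomKernel fields that must survive export — shape_infos_, scalar_arguments_, and init_value_ (see the state() method added in fast_primitives.h below). As a quick illustration (not additional API), each alternative holds the corresponding field directly:

StateT shapes = std::vector<std::tuple<bool, bool, bool>>{{true, false, true}};
StateT scalars = std::vector<std::variant<bool, int, float>>{true, 2, 3.0f};
StateT init = std::optional<float>{0.0f};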
27 changes: 17 additions & 10 deletions mlx/fast_primitives.h
@@ -315,12 +315,6 @@ class Quantize : public Custom {
bool dequantize_;
};

struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;
bool ndim = false;
};

using ScalarArg = std::variant<bool, int, float>;

class CustomKernel : public Primitive {
@@ -331,15 +325,15 @@ class CustomKernel : public Primitive {
std::string source,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
std::vector<std::tuple<bool, bool, bool>> shape_infos,
bool ensure_row_contiguous,
std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
int shared_memory)
: Primitive(stream),
source_(std::move(source)),
name_(std::move(name)),
source_(std::move(source)),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(std::move(shape_infos)),
@@ -358,13 +352,26 @@
override;

DEFINE_NAME(CustomKernel);
auto state() const {
return std::make_tuple(
name_,
source_,
grid_,
threadgroup_,
shape_infos_,
ensure_row_contiguous_,
init_value_,
scalar_arguments_,
is_precompiled_,
shared_memory_);
}

private:
std::string source_;
std::string name_;
std::string source_;
std::tuple<int, int, int> grid_;
std::tuple<int, int, int> threadgroup_;
std::vector<CustomKernelShapeInfo> shape_infos_;
std::vector<std::tuple<bool, bool, bool>> shape_infos_;
bool ensure_row_contiguous_;
std::optional<float> init_value_;
std::vector<ScalarArg> scalar_arguments_;
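Note: state() returns the constructor arguments (everything after stream) in constructor order, which is presumably what lets the SERIALIZE_PRIMITIVE(CustomKernel) entry in export.cpp round-trip the primitive through the generic tuple serializer; the name_/source_ swap in the initializer list and member declarations simply brings both in line with that parameter order. A hypothetical sketch of the round trip — the surrounding names (os, is, s) are illustrative, not MLX's actual factory API:

// Serialize: the state tuple flows through the element-wise tuple path.
serialize(os, kernel.state());

// Deserialize and rebuild, assuming the Stream s is recovered separately.
using StateTuple = decltype(std::declval<CustomKernel>().state());
auto rebuilt = std::apply(
    [&](auto&&... args) {
      return std::make_shared<CustomKernel>(
          s, std::forward<decltype(args)>(args)...);
    },
    deserialize<StateTuple>(is));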
44 changes: 44 additions & 0 deletions python/tests/test_export_import.py
@@ -531,6 +531,50 @@ def callback(args):
self.assertEqual(keywords[0][0], "y")
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])

@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_export_import_custom_kernel(self):
if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.cuda_kernel

kernel = custom_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source=source,
)

def call(a):
return kernel(
inputs=[a],
grid=(4, 1, 1),
threadgroup=(2, 1, 1),
output_shapes=[(2, 2)],
output_dtypes=[mx.float32],
stream=mx.gpu,
)[0]

mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))

path = os.path.join(self.test_dir, "fn.mlxfn")
expected = call(a)
mx.export_function(path, call, a)

imported = mx.import_function(path)

out = imported(a)[0]
self.assertTrue(mx.allclose(expected, out))


if __name__ == "__main__":
mlx_tests.MLXTestRunner()