From e0220b409ad53b8247c88b28a3f38336a7c18fc7 Mon Sep 17 00:00:00 2001 From: BillJJ Date: Sun, 15 Mar 2026 23:05:19 -0700 Subject: [PATCH 1/2] codegen for binary ops --- .gitignore | 1 + cpp/include/common.h | 8 ++++ cpp/src/compiler.cpp | 93 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index dc4a36d..4eebf4a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ *venv/ *.so *.DS_Store +CLAUDE.md diff --git a/cpp/include/common.h b/cpp/include/common.h index 69d077a..b43c8fc 100644 --- a/cpp/include/common.h +++ b/cpp/include/common.h @@ -1,4 +1,5 @@ #pragma once +#include enum class OpCode : int { INPUT = 0, @@ -14,3 +15,10 @@ enum class OpCode : int { CONSTANT = 10, COPY = 11 }; + +const std::map op_symbol = { + {OpCode::ADD, "+"}, + {OpCode::SUB, "-"}, + {OpCode::MUL, "*"}, + {OpCode::DIV, "/"}, +}; diff --git a/cpp/src/compiler.cpp b/cpp/src/compiler.cpp index 5a60f43..c79a055 100644 --- a/cpp/src/compiler.cpp +++ b/cpp/src/compiler.cpp @@ -1,4 +1,7 @@ +#include +#include #include "../include/compiler.h" +#include "../include/common.h" std::vector optimize_graph(std::vector raw_nodes) { return raw_nodes; } // Could generate Fused Kernels, with special OpCodes @@ -13,7 +16,8 @@ std::vector optimize_graph(std::vector raw_nodes) { return raw_nodes // 5. Fusion, combine nodes into "blocks" that run in the same "way" (elementwise easiest) // 6. loop fusion. like two for i in range(100) can be put together -void generateKernels(Graph& graph) {} + + // Generates one huge string of all the kernel functions back to back // if the op requires no gpu kernel (INPUT, VIEW, etc -> call it "no op"), // we dont generate any string @@ -23,3 +27,90 @@ void generateKernels(Graph& graph) {} // Op metadata: Shape, strides (internal offset?) WONT be handled from execute() // This means the kernel strings we generate needs to bake in/hardcode the loops for that // kernels generated so that out buffer is idx 0, then the N inputs to the node (in order) +/* + // Node 1 (B): shape={2,3}, strides={3,1}, offset=0 + // Bake in as constants: + constant long shape[] = {2, 3}; + constant long strides_in0[] = {3, 1}; + constant long offset_in0 = 0; + + // Node 2 (C): shape={2,3}, strides={3,1}, offset=0 + constant long strides_in1[] = {3, 1}; + constant long offset_in1 = 0; + + // Compute strided read indices + uint remaining = gid; + uint idx_in0 = offset_in0; + uint idx_in1 = offset_in1; + for (int i = 1; i >= 0; --i) { + uint coord = remaining % shape[i]; + idx_in0 += coord * strides_in0[i]; + idx_in1 += coord * strides_in1[i]; + remaining /= shape[i]; + } + + out[gid] = in0[idx_in0] * in1[idx_in1]; + } +*/ +void generateKernels(Graph& graph) { + graph.shader_source = "#include \nusing namespace metal;\n"; + for (int64_t i = 0; i < graph.nodes.size(); i++) { + Node &node = graph.nodes[i]; + + if (node.op == OpCode::INPUT || node.op == OpCode::VIEW || node.op == OpCode::RESHAPE || node.op == OpCode::TRANSPOSE) { + // if Input, do a no-op + graph.configs.push_back({}); + continue; + } + + // function definition and inputs + std::ostringstream oss; + std::string kernel_name = "op_" + std::to_string(i); + oss << "kernel void " << kernel_name << " (\n\ + device float* out [[ buffer(0) ]],\n"; + for (int64_t j = 0; j < node.inputs.size(); j++) { + oss << "const device float* in" << j + << " [[ buffer(" + std::to_string(j + 1) + ") ]],\n"; + } + oss << "uint gid [[ thread_position_in_grid ]])\n{\n"; + + // hardcode constants: shape array + oss << "constant long shape[] = {"; + for (int64_t j = 0; j < node.shape.size(); j++) { + oss << node.shape[j] << ",}"[j == node.shape.size() - 1]; + } + oss << ";\n"; + + // hardcode constants: strides arrays + for (int64_t j = 0; j < node.inputs.size(); j++) { + oss << "constant long strides_in" << j << "[] = {"; + Node &input = graph.nodes[node.inputs[j]]; + for (int64_t k = 0; k < input.strides.size(); k++) { + oss << input.strides[k] << ",}"[k == input.strides.size() - 1]; + } + oss << ";\n"; + } + + // declare variables for computing strided read indices + oss << "uint remaining = gid;\n"; + for (int64_t j = 0; j < node.inputs.size(); j++) { + oss << "uint idx_in" << j << " = 0;\n"; // check that this is actually 0, is it actually zero??? Should be, I think + } + + // compute strided read indices + oss << "for (int i = " << node.shape.size() - 1 << "; i >= 0; i--) {\n"; + oss << "uint coord = remaining \% shape[i];\n"; + for (int64_t j = 0; j < node.inputs.size(); j++) { + oss << "idx_in" << j << " += coord * strides_in" << j << "[i];\n"; + } + oss << "remaining /= shape[i];\n}\n"; + + // note that I'm just putting the 4 binary ops for now. Need to do the others too but that's later + oss << "out[gid] = "; + oss << "in0[idx_in0] " << op_symbol.at(node.op) << " in1[idx_in1];\n}\n\n"; + + graph.shader_source += oss.str(); + uint64_t total_elements = numel_from_shape(node.shape); + graph.configs.push_back({.name = kernel_name, .grid = {total_elements, 1, 1}, .group = {256, 1, 1}}); + } +} From 0ed2319cdb52115afe6a5f4714352b0e4710262f Mon Sep 17 00:00:00 2001 From: BillJJ Date: Sun, 15 Mar 2026 23:11:03 -0700 Subject: [PATCH 2/2] delete comment --- cpp/src/compiler.cpp | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/cpp/src/compiler.cpp b/cpp/src/compiler.cpp index c79a055..c278b40 100644 --- a/cpp/src/compiler.cpp +++ b/cpp/src/compiler.cpp @@ -27,31 +27,8 @@ std::vector optimize_graph(std::vector raw_nodes) { return raw_nodes // Op metadata: Shape, strides (internal offset?) WONT be handled from execute() // This means the kernel strings we generate needs to bake in/hardcode the loops for that // kernels generated so that out buffer is idx 0, then the N inputs to the node (in order) -/* - // Node 1 (B): shape={2,3}, strides={3,1}, offset=0 - // Bake in as constants: - constant long shape[] = {2, 3}; - constant long strides_in0[] = {3, 1}; - constant long offset_in0 = 0; - // Node 2 (C): shape={2,3}, strides={3,1}, offset=0 - constant long strides_in1[] = {3, 1}; - constant long offset_in1 = 0; - // Compute strided read indices - uint remaining = gid; - uint idx_in0 = offset_in0; - uint idx_in1 = offset_in1; - for (int i = 1; i >= 0; --i) { - uint coord = remaining % shape[i]; - idx_in0 += coord * strides_in0[i]; - idx_in1 += coord * strides_in1[i]; - remaining /= shape[i]; - } - - out[gid] = in0[idx_in0] * in1[idx_in1]; - } -*/ void generateKernels(Graph& graph) { graph.shader_source = "#include \nusing namespace metal;\n"; for (int64_t i = 0; i < graph.nodes.size(); i++) {