Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
*venv/
*.so
*.DS_Store
CLAUDE.md
8 changes: 8 additions & 0 deletions cpp/include/common.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <map>

enum class OpCode : int {
INPUT = 0,
Expand All @@ -14,3 +15,10 @@ enum class OpCode : int {
CONSTANT = 10,
COPY = 11
};

const std::map<OpCode, std::string> op_symbol = {
{OpCode::ADD, "+"},
{OpCode::SUB, "-"},
{OpCode::MUL, "*"},
{OpCode::DIV, "/"},
};
70 changes: 69 additions & 1 deletion cpp/src/compiler.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#include <sstream>
#include <numeric>
#include "../include/compiler.h"
#include "../include/common.h"

std::vector<Node> optimize_graph(std::vector<Node> raw_nodes) { return raw_nodes; }
// Could generate Fused Kernels, with special OpCodes
Expand All @@ -13,7 +16,8 @@ std::vector<Node> optimize_graph(std::vector<Node> raw_nodes) { return raw_nodes
// 5. Fusion, combine nodes into "blocks" that run in the same "way" (elementwise easiest)
// 6. loop fusion. like two for i in range(100) can be put together

void generateKernels(Graph& graph) {}


// Generates one huge string of all the kernel functions back to back
// if the op requires no gpu kernel (INPUT, VIEW, etc -> call it "no op"),
// we dont generate any string
Expand All @@ -23,3 +27,67 @@ void generateKernels(Graph& graph) {}
// Op metadata: Shape, strides (internal offset?) WONT be handled from execute()
// This means the kernel strings we generate needs to bake in/hardcode the loops for that
// kernels generated so that out buffer is idx 0, then the N inputs to the node (in order)


void generateKernels(Graph& graph) {
graph.shader_source = "#include <metal_stdlib>\nusing namespace metal;\n";
for (int64_t i = 0; i < graph.nodes.size(); i++) {
Node &node = graph.nodes[i];

if (node.op == OpCode::INPUT || node.op == OpCode::VIEW || node.op == OpCode::RESHAPE || node.op == OpCode::TRANSPOSE) {
// if Input, do a no-op
graph.configs.push_back({});
continue;
}

// function definition and inputs
std::ostringstream oss;
std::string kernel_name = "op_" + std::to_string(i);
oss << "kernel void " << kernel_name << " (\n\
device float* out [[ buffer(0) ]],\n";
for (int64_t j = 0; j < node.inputs.size(); j++) {
oss << "const device float* in" << j
<< " [[ buffer(" + std::to_string(j + 1) + ") ]],\n";
}
oss << "uint gid [[ thread_position_in_grid ]])\n{\n";

// hardcode constants: shape array
oss << "constant long shape[] = {";
for (int64_t j = 0; j < node.shape.size(); j++) {
oss << node.shape[j] << ",}"[j == node.shape.size() - 1];
}
oss << ";\n";

// hardcode constants: strides arrays
for (int64_t j = 0; j < node.inputs.size(); j++) {
oss << "constant long strides_in" << j << "[] = {";
Node &input = graph.nodes[node.inputs[j]];
for (int64_t k = 0; k < input.strides.size(); k++) {
oss << input.strides[k] << ",}"[k == input.strides.size() - 1];
}
oss << ";\n";
}

// declare variables for computing strided read indices
oss << "uint remaining = gid;\n";
for (int64_t j = 0; j < node.inputs.size(); j++) {
oss << "uint idx_in" << j << " = 0;\n"; // check that this is actually 0, is it actually zero??? Should be, I think
}

// compute strided read indices
oss << "for (int i = " << node.shape.size() - 1 << "; i >= 0; i--) {\n";
oss << "uint coord = remaining \% shape[i];\n";
for (int64_t j = 0; j < node.inputs.size(); j++) {
oss << "idx_in" << j << " += coord * strides_in" << j << "[i];\n";
}
oss << "remaining /= shape[i];\n}\n";

// note that I'm just putting the 4 binary ops for now. Need to do the others too but that's later
oss << "out[gid] = ";
oss << "in0[idx_in0] " << op_symbol.at(node.op) << " in1[idx_in1];\n}\n\n";

graph.shader_source += oss.str();
uint64_t total_elements = numel_from_shape(node.shape);
graph.configs.push_back({.name = kernel_name, .grid = {total_elements, 1, 1}, .group = {256, 1, 1}});
}
}
Loading