151 changes: 151 additions & 0 deletions challenges/medium/84_swiglu_mlp_block/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
<p>
Implement the SwiGLU MLP block — the feedforward network used in LLaMA, Mistral, Gemma, and most
modern large language models. Given an input matrix <code>x</code> of shape
<code>[M, d_model]</code> and three weight matrices <code>W_gate</code>, <code>W_up</code>
(each <code>[d_model, d_ffn]</code>), and <code>W_down</code> (<code>[d_ffn, d_model]</code>),
compute:
<code>output = (SiLU(x &times; W_gate) &odot; (x &times; W_up)) &times; W_down</code>,
where <code>SiLU(z) = z &times; sigmoid(z)</code> and <code>&odot;</code> denotes element-wise
multiplication. All tensors are <code>float32</code>.
</p>
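The formula above can be sketched as a host-side NumPy reference (function names are illustrative; this is not the GPU solution):

```python
import numpy as np

def silu(z):
    # SiLU(z) = z * sigmoid(z) = z / (1 + exp(-z))
    return z / (1.0 + np.exp(-z))

def swiglu_block(x, W_gate, W_up, W_down):
    # gate and up projections: [M, d_model] @ [d_model, d_ffn] -> [M, d_ffn]
    gate = x @ W_gate
    up = x @ W_up
    # gated hidden state, then down projection back to [M, d_model]
    hidden = silu(gate) * up
    return hidden @ W_down
```

The same three matmuls and one element-wise product are what any kernel implementation ultimately has to reproduce.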

<svg width="680" height="220" viewBox="0 0 680 220" xmlns="http://www.w3.org/2000/svg"
style="display:block; margin:20px auto; font-family:monospace;">
<rect width="680" height="220" fill="#222" rx="8"/>
<defs>
<marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#888"/>
</marker>
</defs>

<!-- x box -->
<rect x="16" y="82" width="56" height="40" rx="4" fill="#2a4a7f" stroke="#5588cc" stroke-width="1.5"/>
<text x="44" y="106" fill="#ccc" font-size="12" text-anchor="middle">x</text>
<text x="44" y="136" fill="#666" font-size="8" text-anchor="middle">[M, d_model]</text>

<!-- Gate branch (top) -->
<line x1="72" y1="92" x2="108" y2="52" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<rect x="110" y="32" width="90" height="40" rx="4" fill="#2a4a7f" stroke="#5588cc" stroke-width="1.5"/>
<text x="155" y="56" fill="#ccc" font-size="10" text-anchor="middle">x &#xb7; W_gate</text>
<text x="155" y="22" fill="#5588cc" font-size="9" text-anchor="middle">gate projection</text>

<!-- Up branch (bottom) -->
<line x1="72" y1="112" x2="108" y2="152" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<rect x="110" y="132" width="90" height="40" rx="4" fill="#2a4a7f" stroke="#5588cc" stroke-width="1.5"/>
<text x="155" y="156" fill="#ccc" font-size="10" text-anchor="middle">x &#xb7; W_up</text>
<text x="155" y="184" fill="#5588cc" font-size="9" text-anchor="middle">up projection</text>

<!-- Shape labels after projections -->
<text x="155" y="82" fill="#666" font-size="8" text-anchor="middle">[M, d_ffn]</text>
<text x="155" y="130" fill="#666" font-size="8" text-anchor="middle">[M, d_ffn]</text>

<!-- Arrow gate → SiLU -->
<line x1="200" y1="52" x2="238" y2="52" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- SiLU box -->
<rect x="240" y="32" width="60" height="40" rx="4" fill="#1a5a3a" stroke="#44aa66" stroke-width="1.5"/>
<text x="270" y="56" fill="#ccc" font-size="11" text-anchor="middle">SiLU</text>

<!-- Arrow SiLU → element-wise multiply (goes down) -->
<line x1="300" y1="52" x2="370" y2="90" stroke="#44aa66" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Arrow up branch → element-wise multiply (goes up) -->
<line x1="200" y1="152" x2="370" y2="114" stroke="#5588cc" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Element-wise multiply box -->
<rect x="372" y="82" width="50" height="40" rx="4" fill="#5a3a1a" stroke="#cc8844" stroke-width="1.5"/>
<text x="397" y="107" fill="#ccc" font-size="16" text-anchor="middle">&#x2299;</text>

<!-- Arrow ⊙ → W_down -->
<line x1="422" y1="102" x2="458" y2="102" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- W_down box -->
<rect x="460" y="82" width="86" height="40" rx="4" fill="#2a4a7f" stroke="#5588cc" stroke-width="1.5"/>
<text x="503" y="106" fill="#ccc" font-size="10" text-anchor="middle">&#xb7; W_down</text>
<text x="503" y="76" fill="#666" font-size="8" text-anchor="middle">[M, d_ffn]</text>

<!-- Arrow W_down → output -->
<line x1="546" y1="102" x2="578" y2="102" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Output box -->
<rect x="580" y="82" width="80" height="40" rx="4" fill="#3a1a3a" stroke="#cc44cc" stroke-width="1.5"/>
<text x="620" y="106" fill="#ccc" font-size="11" text-anchor="middle">output</text>
<text x="620" y="136" fill="#666" font-size="8" text-anchor="middle">[M, d_model]</text>

<!-- SiLU formula -->
<text x="270" y="18" fill="#44aa66" font-size="9" text-anchor="middle">z &#xb7; sigmoid(z)</text>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the <code>solve</code> function without changing its signature.</li>
<li>Do not use external libraries beyond the framework provided.</li>
<li>Write the result into <code>output</code> in-place.</li>
</ul>

<h2>Example</h2>
<p>
Input: <code>M</code> = 2, <code>d_model</code> = 2, <code>d_ffn</code> = 4
</p>
<p>
\(x\) (float32, \(2 \times 2\)):
\[
x = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix}
\]
\(W_\text{gate}\) and \(W_\text{up}\) (both \(2 \times 4\)):
\[
W_\text{gate} = W_\text{up} =
\begin{bmatrix}
1.0 & 0.0 & 0.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 0.0
\end{bmatrix}
\]
\(W_\text{down}\) (\(4 \times 2\)):
\[
W_\text{down} =
\begin{bmatrix}
1.0 & 0.0 \\
0.0 & 1.0 \\
0.0 & 0.0 \\
0.0 & 0.0
\end{bmatrix}
\]
</p>
<p>
Intermediate steps:
\[
\text{gate} = x \cdot W_\text{gate} =
\begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix}
\]
\[
\text{up} = x \cdot W_\text{up} =
\begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix}
\]
\[
\text{SiLU}(1.0) = 1.0 \times \sigma(1.0) \approx 0.7311
\]
\[
\text{hidden} = \text{SiLU}(\text{gate}) \odot \text{up} =
\begin{bmatrix} 0.7311 & 0.0 & 0.0 & 0.0 \\ 0.0 & 0.7311 & 0.0 & 0.0 \end{bmatrix}
\]
</p>
<p>
Output:
\[
\text{output} = \text{hidden} \cdot W_\text{down} \approx
\begin{bmatrix} 0.7311 & 0.0 \\ 0.0 & 0.7311 \end{bmatrix}
\]
</p>
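The worked example can be checked numerically with a few lines of NumPy (a quick verification of the matrices above):

```python
import numpy as np

x = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
W_gate = W_up = np.array([[1.0, 0.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0, 0.0]], dtype=np.float32)
W_down = np.array([[1.0, 0.0], [0.0, 1.0],
                   [0.0, 0.0], [0.0, 0.0]], dtype=np.float32)

gate = x @ W_gate
up = x @ W_up
hidden = (gate / (1.0 + np.exp(-gate))) * up  # SiLU(gate) ⊙ up
output = hidden @ W_down                       # diag(0.7311, 0.7311)
```

Because each row of <code>gate</code> is a basis vector, the only nonzero activation per row is SiLU(1.0) ≈ 0.7311, which survives the identity-shaped down projection.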

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>M</code> &le; 65,536</li>
<li>1 &le; <code>d_model</code> &le; 8,192</li>
<li>1 &le; <code>d_ffn</code> &le; 32,768</li>
<li>All tensors are <code>float32</code> on the GPU.</li>
<li>Input values are in the range [-10, 10].</li>
<li>
Performance is measured with <code>M</code> = 512, <code>d_model</code> = 4,096,
<code>d_ffn</code> = 14,336
</li>
</ul>
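For a rough sense of the performance case: the three GEMMs dominate, at about 2·M·d_model·d_ffn FLOPs each (the SiLU and element-wise multiply are comparatively negligible). A back-of-the-envelope estimate:

```python
M, d_model, d_ffn = 512, 4096, 14336

# each matmul costs ~2 * (rows * inner * cols) FLOPs
gemm_flops = 2 * M * d_model * d_ffn   # gate; up and down cost the same
total_flops = 3 * gemm_flops

# weight traffic dominates memory: three d_model x d_ffn float32 matrices
weight_bytes = 3 * d_model * d_ffn * 4

print(f"{total_flops / 1e9:.1f} GFLOP, {weight_bytes / 1e6:.1f} MB of weights")
# → 180.4 GFLOP, 704.6 MB of weights
```

At this size the workload is large enough to be compute-bound on most GPUs, so GEMM efficiency, not the activation, is what the benchmark rewards.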
162 changes: 162 additions & 0 deletions challenges/medium/84_swiglu_mlp_block/challenge.py
@@ -0,0 +1,162 @@
import ctypes
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="SwiGLU MLP Block",
atol=1e-04,
rtol=1e-04,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
x: torch.Tensor,
W_gate: torch.Tensor,
W_up: torch.Tensor,
W_down: torch.Tensor,
output: torch.Tensor,
M: int,
d_model: int,
d_ffn: int,
):
assert x.shape == (M, d_model)
assert W_gate.shape == (d_model, d_ffn)
assert W_up.shape == (d_model, d_ffn)
assert W_down.shape == (d_ffn, d_model)
assert output.shape == (M, d_model)
assert (
x.dtype == W_gate.dtype == W_up.dtype == W_down.dtype == output.dtype == torch.float32
)
assert x.device.type == "cuda"
assert W_gate.device.type == "cuda"
assert W_up.device.type == "cuda"
assert W_down.device.type == "cuda"
assert output.device.type == "cuda"

gate = x @ W_gate # [M, d_ffn]
up = x @ W_up # [M, d_ffn]
hidden = F.silu(gate) * up # [M, d_ffn]
output.copy_(hidden @ W_down) # [M, d_model]

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"x": (ctypes.POINTER(ctypes.c_float), "in"),
"W_gate": (ctypes.POINTER(ctypes.c_float), "in"),
"W_up": (ctypes.POINTER(ctypes.c_float), "in"),
"W_down": (ctypes.POINTER(ctypes.c_float), "in"),
"output": (ctypes.POINTER(ctypes.c_float), "out"),
"M": (ctypes.c_int, "in"),
"d_model": (ctypes.c_int, "in"),
"d_ffn": (ctypes.c_int, "in"),
}

def _make_test_case(self, M, d_model, d_ffn, zero_x=False):
device = "cuda"
dtype = torch.float32
if zero_x:
x = torch.zeros(M, d_model, device=device, dtype=dtype)
else:
x = torch.randn(M, d_model, device=device, dtype=dtype) * 0.1
W_gate = torch.randn(d_model, d_ffn, device=device, dtype=dtype) * 0.02
W_up = torch.randn(d_model, d_ffn, device=device, dtype=dtype) * 0.02
W_down = torch.randn(d_ffn, d_model, device=device, dtype=dtype) * 0.02
output = torch.empty(M, d_model, device=device, dtype=dtype)
return {
"x": x,
"W_gate": W_gate,
"W_up": W_up,
"W_down": W_down,
"output": output,
"M": M,
"d_model": d_model,
"d_ffn": d_ffn,
}

def generate_example_test(self) -> Dict[str, Any]:
device = "cuda"
dtype = torch.float32
M, d_model, d_ffn = 2, 2, 4
# x: each row is a basis vector
x = torch.tensor(
[[1.0, 0.0], [0.0, 1.0]],
device=device,
dtype=dtype,
)
# W_gate: [d_model=2, d_ffn=4] — first two columns are identity, rest zeros
W_gate = torch.tensor(
[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
device=device,
dtype=dtype,
)
# W_up: same layout as W_gate
W_up = torch.tensor(
[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
device=device,
dtype=dtype,
)
# W_down: [d_ffn=4, d_model=2] — top 2x2 is identity, rest zeros
W_down = torch.tensor(
[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]],
device=device,
dtype=dtype,
)
output = torch.empty(M, d_model, device=device, dtype=dtype)
return {
"x": x,
"W_gate": W_gate,
"W_up": W_up,
"W_down": W_down,
"output": output,
"M": M,
"d_model": d_model,
"d_ffn": d_ffn,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
torch.manual_seed(42)
tests = []

# Edge case: single row
tests.append(self._make_test_case(1, 4, 8))

# Edge case: two rows
tests.append(self._make_test_case(2, 4, 8))

# Zero input
tests.append(self._make_test_case(4, 8, 16, zero_x=True))

# Power-of-2 sizes
tests.append(self._make_test_case(16, 32, 64))

# Power-of-2 larger
tests.append(self._make_test_case(64, 64, 128))

# Non-power-of-2 M
tests.append(self._make_test_case(30, 32, 64))

# Non-power-of-2 all dims
tests.append(self._make_test_case(100, 60, 120))

# Non-power-of-2 M, medium size
tests.append(self._make_test_case(255, 64, 128))

# Realistic small inference batch (LLaMA-style ratios)
tests.append(self._make_test_case(128, 256, 512))

# Realistic medium inference batch
tests.append(self._make_test_case(256, 512, 1024))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
torch.manual_seed(0)
# LLaMA-3 8B style: d_model=4096, d_ffn=14336, M=512 (batch=4 x seq=128)
return self._make_test_case(512, 4096, 14336)
5 changes: 5 additions & 0 deletions challenges/medium/84_swiglu_mlp_block/starter/starter.cu
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// x, W_gate, W_up, W_down, output are device pointers
extern "C" void solve(const float* x, const float* W_gate, const float* W_up, const float* W_down,
float* output, int M, int d_model, int d_ffn) {}
17 changes: 17 additions & 0 deletions challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# x, W_gate, W_up, W_down, output are tensors on the GPU
@cute.jit
def solve(
x: cute.Tensor,
W_gate: cute.Tensor,
W_up: cute.Tensor,
W_down: cute.Tensor,
output: cute.Tensor,
M: cute.Int32,
d_model: cute.Int32,
d_ffn: cute.Int32,
):
pass
17 changes: 17 additions & 0 deletions challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py
@@ -0,0 +1,17 @@
import jax
import jax.numpy as jnp


# x, W_gate, W_up, W_down are tensors on GPU
@jax.jit
def solve(
x: jax.Array,
W_gate: jax.Array,
W_up: jax.Array,
W_down: jax.Array,
M: int,
d_model: int,
d_ffn: int,
) -> jax.Array:
# return output tensor directly
pass
9 changes: 9 additions & 0 deletions challenges/medium/84_swiglu_mlp_block/starter/starter.mojo
@@ -0,0 +1,9 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# x, W_gate, W_up, W_down, output are device pointers
@export
def solve(x: UnsafePointer[Float32], W_gate: UnsafePointer[Float32], W_up: UnsafePointer[Float32], W_down: UnsafePointer[Float32], output: UnsafePointer[Float32], M: Int32, d_model: Int32, d_ffn: Int32):
pass
@@ -0,0 +1,15 @@
import torch


# x, W_gate, W_up, W_down, output are tensors on the GPU
def solve(
x: torch.Tensor,
W_gate: torch.Tensor,
W_up: torch.Tensor,
W_down: torch.Tensor,
output: torch.Tensor,
M: int,
d_model: int,
d_ffn: int,
):
pass
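For this PyTorch starter, a correct baseline can mirror the challenge's own reference_impl (a sketch using torch built-ins; writes into the preallocated output buffer in place):

```python
import torch
import torch.nn.functional as F

def solve(x, W_gate, W_up, W_down, output, M, d_model, d_ffn):
    # three matmuls plus the SiLU gate, exactly as in the reference:
    # output = (SiLU(x @ W_gate) * (x @ W_up)) @ W_down
    hidden = F.silu(x @ W_gate) * (x @ W_up)
    output.copy_(hidden @ W_down)
```

This will pass correctness checks; beating the cuBLAS-backed matmuls on the performance case then comes down to fusing the SiLU and element-wise multiply into the surrounding GEMMs.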