diff --git a/challenges/medium/84_swiglu_mlp_block/challenge.html b/challenges/medium/84_swiglu_mlp_block/challenge.html new file mode 100644 index 00000000..b93b5801 --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/challenge.html @@ -0,0 +1,151 @@ +

+ Implement the SwiGLU MLP block — the feedforward network used in LLaMA, Mistral, Gemma, and most + modern large language models. Given an input matrix x of shape + [M, d_model] and three weight matrices W_gate, W_up + (each [d_model, d_ffn]), and W_down ([d_ffn, d_model]), + compute: + output = (SiLU(x × W_gate) ⊙ (x × W_up)) × W_down, + where SiLU(z) = z × sigmoid(z) and denotes element-wise + multiplication. All tensors are float32. +

+ + + + + + + + + + + + x + [M, d_model] + + + + + x · W_gate + gate projection + + + + + x · W_up + up projection + + + [M, d_ffn] + [M, d_ffn] + + + + + + + SiLU + + + + + + + + + + + + + + + + + · W_down + [M, d_ffn] + + + + + + + output + [M, d_model] + + + z · sigmoid(z) + + +

Implementation Requirements

+ + +

Example

+

+ Input: M = 2, d_model = 2, d_ffn = 4 +

+

+ \(x\) (float32, \(2 \times 2\)): + \[ + x = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} + \] + \(W_\text{gate}\) and \(W_\text{up}\) (both \(2 \times 4\)): + \[ + W_\text{gate} = W_\text{up} = + \begin{bmatrix} + 1.0 & 0.0 & 0.0 & 0.0 \\ + 0.0 & 1.0 & 0.0 & 0.0 + \end{bmatrix} + \] + \(W_\text{down}\) (\(4 \times 2\)): + \[ + W_\text{down} = + \begin{bmatrix} + 1.0 & 0.0 \\ + 0.0 & 1.0 \\ + 0.0 & 0.0 \\ + 0.0 & 0.0 + \end{bmatrix} + \] +

+

+ Intermediate steps: + \[ + \text{gate} = x \cdot W_\text{gate} = + \begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix} + \] + \[ + \text{up} = x \cdot W_\text{up} = + \begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix} + \] + \[ + \text{SiLU}(1.0) = 1.0 \times \sigma(1.0) \approx 0.7311 + \] + \[ + \text{hidden} = \text{SiLU}(\text{gate}) \odot \text{up} = + \begin{bmatrix} 0.7311 & 0.0 & 0.0 & 0.0 \\ 0.0 & 0.7311 & 0.0 & 0.0 \end{bmatrix} + \] +

+

+ Output: + \[ + \text{output} = \text{hidden} \cdot W_\text{down} \approx + \begin{bmatrix} 0.7311 & 0.0 \\ 0.0 & 0.7311 \end{bmatrix} + \] +

+ +

Constraints

+ diff --git a/challenges/medium/84_swiglu_mlp_block/challenge.py b/challenges/medium/84_swiglu_mlp_block/challenge.py new file mode 100644 index 00000000..87408f23 --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/challenge.py @@ -0,0 +1,162 @@ +import ctypes +from typing import Any, Dict, List + +import torch +import torch.nn.functional as F +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="SwiGLU MLP Block", + atol=1e-04, + rtol=1e-04, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + x: torch.Tensor, + W_gate: torch.Tensor, + W_up: torch.Tensor, + W_down: torch.Tensor, + output: torch.Tensor, + M: int, + d_model: int, + d_ffn: int, + ): + assert x.shape == (M, d_model) + assert W_gate.shape == (d_model, d_ffn) + assert W_up.shape == (d_model, d_ffn) + assert W_down.shape == (d_ffn, d_model) + assert output.shape == (M, d_model) + assert ( + x.dtype == W_gate.dtype == W_up.dtype == W_down.dtype == output.dtype == torch.float32 + ) + assert x.device.type == "cuda" + assert W_gate.device.type == "cuda" + assert W_up.device.type == "cuda" + assert W_down.device.type == "cuda" + assert output.device.type == "cuda" + + gate = x @ W_gate # [M, d_ffn] + up = x @ W_up # [M, d_ffn] + hidden = F.silu(gate) * up # [M, d_ffn] + output.copy_(hidden @ W_down) # [M, d_model] + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "x": (ctypes.POINTER(ctypes.c_float), "in"), + "W_gate": (ctypes.POINTER(ctypes.c_float), "in"), + "W_up": (ctypes.POINTER(ctypes.c_float), "in"), + "W_down": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "M": (ctypes.c_int, "in"), + "d_model": (ctypes.c_int, "in"), + "d_ffn": (ctypes.c_int, "in"), + } + + def _make_test_case(self, M, d_model, d_ffn, zero_x=False): + device = "cuda" + dtype = torch.float32 + if zero_x: + x = torch.zeros(M, d_model, device=device, dtype=dtype) + else: + x = torch.randn(M, d_model, device=device, dtype=dtype) * 0.1 + W_gate = torch.randn(d_model, d_ffn, device=device, dtype=dtype) * 0.02 + W_up = torch.randn(d_model, d_ffn, device=device, dtype=dtype) * 0.02 + W_down = torch.randn(d_ffn, d_model, device=device, dtype=dtype) * 0.02 + output = torch.empty(M, d_model, device=device, dtype=dtype) + return { + "x": x, + "W_gate": W_gate, + "W_up": W_up, + "W_down": W_down, + "output": output, + "M": M, + "d_model": d_model, + "d_ffn": d_ffn, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + dtype = torch.float32 + M, d_model, d_ffn = 2, 2, 4 + # x: each row is a basis vector + x = torch.tensor( + [[1.0, 0.0], [0.0, 1.0]], + device=device, + dtype=dtype, + ) + # W_gate: [d_model=2, d_ffn=4] — first two columns are identity, rest zeros + W_gate = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + device=device, + dtype=dtype, + ) + # W_up: same layout as W_gate + W_up = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + device=device, + dtype=dtype, + ) + # W_down: [d_ffn=4, d_model=2] — top 2x2 is identity, rest zeros + W_down = torch.tensor( + [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]], + device=device, + dtype=dtype, + ) + output = torch.empty(M, d_model, device=device, dtype=dtype) + return { + "x": x, + "W_gate": W_gate, + "W_up": W_up, + "W_down": W_down, + "output": output, + "M": M, + "d_model": d_model, + "d_ffn": d_ffn, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge cases: single row + tests.append(self._make_test_case(1, 4, 8)) + + # Edge case: two rows + tests.append(self._make_test_case(2, 4, 8)) + + # Zero input + tests.append(self._make_test_case(4, 8, 16, zero_x=True)) + + # Power-of-2 sizes + tests.append(self._make_test_case(16, 32, 64)) + + # Power-of-2 larger + tests.append(self._make_test_case(64, 64, 128)) + + # Non-power-of-2 M + tests.append(self._make_test_case(30, 32, 64)) + + # Non-power-of-2 all dims + tests.append(self._make_test_case(100, 60, 120)) + + # Non-power-of-2 M, medium size + tests.append(self._make_test_case(255, 64, 128)) + + # Realistic small inference batch (LLaMA-style ratios) + tests.append(self._make_test_case(128, 256, 512)) + + # Realistic medium inference batch + tests.append(self._make_test_case(256, 512, 1024)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # LLaMA-3 8B style: d_model=4096, d_ffn=14336, M=512 (batch=4 x seq=128) + return self._make_test_case(512, 4096, 14336) diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.cu b/challenges/medium/84_swiglu_mlp_block/starter/starter.cu new file mode 100644 index 00000000..118724bc --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// x, W_gate, W_up, W_down, output are device pointers +extern "C" void solve(const float* x, const float* W_gate, const float* W_up, const float* W_down, + float* output, int M, int d_model, int d_ffn) {} diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py new file mode 100644 index 00000000..c0833020 --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py @@ -0,0 +1,17 @@ +import cutlass +import cutlass.cute as cute + + +# x, W_gate, W_up, W_down, output are tensors on the GPU +@cute.jit +def solve( + x: cute.Tensor, + W_gate: cute.Tensor, + W_up: cute.Tensor, + W_down: cute.Tensor, + output: cute.Tensor, + M: cute.Int32, + d_model: cute.Int32, + d_ffn: cute.Int32, +): + pass diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py new file mode 100644 index 00000000..219948bb --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py @@ -0,0 +1,17 @@ +import jax +import jax.numpy as jnp + + +# x, W_gate, W_up, W_down are tensors on GPU +@jax.jit +def solve( + x: jax.Array, + W_gate: jax.Array, + W_up: jax.Array, + W_down: jax.Array, + M: int, + d_model: int, + d_ffn: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.mojo b/challenges/medium/84_swiglu_mlp_block/starter/starter.mojo new file mode 100644 index 00000000..184038dc --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.mojo @@ -0,0 +1,9 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from memory import UnsafePointer +from math import ceildiv + +# x, W_gate, W_up, W_down, output are device pointers +@export +def solve(x: UnsafePointer[Float32], W_gate: UnsafePointer[Float32], W_up: UnsafePointer[Float32], W_down: UnsafePointer[Float32], output: UnsafePointer[Float32], M: Int32, d_model: Int32, d_ffn: Int32): + pass diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.pytorch.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.pytorch.py new file mode 100644 index 00000000..2af6572f --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.pytorch.py @@ -0,0 +1,15 @@ +import torch + + +# x, W_gate, W_up, W_down, output are tensors on the GPU +def solve( + x: torch.Tensor, + W_gate: torch.Tensor, + W_up: torch.Tensor, + W_down: torch.Tensor, + output: torch.Tensor, + M: int, + d_model: int, + d_ffn: int, +): + pass diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.triton.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.triton.py new file mode 100644 index 00000000..47552d0c --- /dev/null +++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.triton.py @@ -0,0 +1,17 @@ +import torch +import triton +import triton.language as tl + + +# x, W_gate, W_up, W_down, output are tensors on the GPU +def solve( + x: torch.Tensor, + W_gate: torch.Tensor, + W_up: torch.Tensor, + W_down: torch.Tensor, + output: torch.Tensor, + M: int, + d_model: int, + d_ffn: int, +): + pass