diff --git a/challenges/medium/84_swiglu_mlp_block/challenge.html b/challenges/medium/84_swiglu_mlp_block/challenge.html
new file mode 100644
index 00000000..b93b5801
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/challenge.html
@@ -0,0 +1,151 @@
+
+ Implement the SwiGLU MLP block — the feedforward network used in LLaMA, Mistral, Gemma, and most
+ modern large language models. Given an input matrix x of shape
+ [M, d_model] and three weight matrices W_gate, W_up
+ (each [d_model, d_ffn]), and W_down ([d_ffn, d_model]),
+ compute:
+ output = (SiLU(x × W_gate) ⊙ (x × W_up)) × W_down,
+ where SiLU(z) = z × sigmoid(z) and ⊙ denotes element-wise
+ multiplication. All tensors are float32.
+
+
+
+
+
+
+
+
+
+
+
+
+ x
+ [M, d_model]
+
+
+
+
+ x · W_gate
+ gate projection
+
+
+
+
+ x · W_up
+ up projection
+
+
+ [M, d_ffn]
+ [M, d_ffn]
+
+
+
+
+
+
+ SiLU
+
+
+
+
+
+
+
+
+
+ ⊙
+
+
+
+
+
+
+ · W_down
+ [M, d_ffn]
+
+
+
+
+
+
+ output
+ [M, d_model]
+
+
+ z · sigmoid(z)
+
+
+Implementation Requirements
+
+ Implement the solve function with the signature unchanged.
+ Do not use external libraries beyond the framework provided.
+ Write the result into output in-place.
+
+
+Example
+
+ Input: M = 2, d_model = 2, d_ffn = 4
+
+
+ \(x\) (float32, \(2 \times 2\)):
+ \[
+ x = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix}
+ \]
+ \(W_\text{gate}\) and \(W_\text{up}\) (both \(2 \times 4\)):
+ \[
+ W_\text{gate} = W_\text{up} =
+ \begin{bmatrix}
+ 1.0 & 0.0 & 0.0 & 0.0 \\
+ 0.0 & 1.0 & 0.0 & 0.0
+ \end{bmatrix}
+ \]
+ \(W_\text{down}\) (\(4 \times 2\)):
+ \[
+ W_\text{down} =
+ \begin{bmatrix}
+ 1.0 & 0.0 \\
+ 0.0 & 1.0 \\
+ 0.0 & 0.0 \\
+ 0.0 & 0.0
+ \end{bmatrix}
+ \]
+
+
+ Intermediate steps:
+ \[
+ \text{gate} = x \cdot W_\text{gate} =
+ \begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix}
+ \]
+ \[
+ \text{up} = x \cdot W_\text{up} =
+ \begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix}
+ \]
+ \[
+ \text{SiLU}(1.0) = 1.0 \times \sigma(1.0) \approx 0.7311
+ \]
+ \[
+ \text{hidden} = \text{SiLU}(\text{gate}) \odot \text{up} =
+ \begin{bmatrix} 0.7311 & 0.0 & 0.0 & 0.0 \\ 0.0 & 0.7311 & 0.0 & 0.0 \end{bmatrix}
+ \]
+
+
+ Output:
+ \[
+ \text{output} = \text{hidden} \cdot W_\text{down} \approx
+ \begin{bmatrix} 0.7311 & 0.0 \\ 0.0 & 0.7311 \end{bmatrix}
+ \]
+
+
+Constraints
+
+ 1 ≤ M ≤ 65,536
+ 1 ≤ d_model ≤ 8,192
+ 1 ≤ d_ffn ≤ 32,768
+ All tensors are float32 on the GPU.
+ Input values are in the range [-10, 10].
+
+ Performance is measured with M = 512, d_model = 4,096,
+ d_ffn = 14,336.
+
+
diff --git a/challenges/medium/84_swiglu_mlp_block/challenge.py b/challenges/medium/84_swiglu_mlp_block/challenge.py
new file mode 100644
index 00000000..87408f23
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/challenge.py
@@ -0,0 +1,162 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+import torch.nn.functional as F
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="SwiGLU MLP Block",
+ atol=1e-04,
+ rtol=1e-04,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(
+ self,
+ x: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ W_down: torch.Tensor,
+ output: torch.Tensor,
+ M: int,
+ d_model: int,
+ d_ffn: int,
+ ):
+ assert x.shape == (M, d_model)
+ assert W_gate.shape == (d_model, d_ffn)
+ assert W_up.shape == (d_model, d_ffn)
+ assert W_down.shape == (d_ffn, d_model)
+ assert output.shape == (M, d_model)
+ assert (
+ x.dtype == W_gate.dtype == W_up.dtype == W_down.dtype == output.dtype == torch.float32
+ )
+ assert x.device.type == "cuda"
+ assert W_gate.device.type == "cuda"
+ assert W_up.device.type == "cuda"
+ assert W_down.device.type == "cuda"
+ assert output.device.type == "cuda"
+
+ gate = x @ W_gate # [M, d_ffn]
+ up = x @ W_up # [M, d_ffn]
+ hidden = F.silu(gate) * up # [M, d_ffn]
+ output.copy_(hidden @ W_down) # [M, d_model]
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "x": (ctypes.POINTER(ctypes.c_float), "in"),
+ "W_gate": (ctypes.POINTER(ctypes.c_float), "in"),
+ "W_up": (ctypes.POINTER(ctypes.c_float), "in"),
+ "W_down": (ctypes.POINTER(ctypes.c_float), "in"),
+ "output": (ctypes.POINTER(ctypes.c_float), "out"),
+ "M": (ctypes.c_int, "in"),
+ "d_model": (ctypes.c_int, "in"),
+ "d_ffn": (ctypes.c_int, "in"),
+ }
+
+ def _make_test_case(self, M, d_model, d_ffn, zero_x=False):
+ device = "cuda"
+ dtype = torch.float32
+ if zero_x:
+ x = torch.zeros(M, d_model, device=device, dtype=dtype)
+ else:
+ x = torch.randn(M, d_model, device=device, dtype=dtype) * 0.1
+ W_gate = torch.randn(d_model, d_ffn, device=device, dtype=dtype) * 0.02
+ W_up = torch.randn(d_model, d_ffn, device=device, dtype=dtype) * 0.02
+ W_down = torch.randn(d_ffn, d_model, device=device, dtype=dtype) * 0.02
+ output = torch.empty(M, d_model, device=device, dtype=dtype)
+ return {
+ "x": x,
+ "W_gate": W_gate,
+ "W_up": W_up,
+ "W_down": W_down,
+ "output": output,
+ "M": M,
+ "d_model": d_model,
+ "d_ffn": d_ffn,
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ device = "cuda"
+ dtype = torch.float32
+ M, d_model, d_ffn = 2, 2, 4
+ # x: each row is a basis vector
+ x = torch.tensor(
+ [[1.0, 0.0], [0.0, 1.0]],
+ device=device,
+ dtype=dtype,
+ )
+        # W_gate: [d_model=2, d_ffn=4] — first two columns form the 2x2 identity, remaining columns zero
+ W_gate = torch.tensor(
+ [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
+ device=device,
+ dtype=dtype,
+ )
+ # W_up: same layout as W_gate
+ W_up = torch.tensor(
+ [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
+ device=device,
+ dtype=dtype,
+ )
+ # W_down: [d_ffn=4, d_model=2] — top 2x2 is identity, rest zeros
+ W_down = torch.tensor(
+ [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]],
+ device=device,
+ dtype=dtype,
+ )
+ output = torch.empty(M, d_model, device=device, dtype=dtype)
+ return {
+ "x": x,
+ "W_gate": W_gate,
+ "W_up": W_up,
+ "W_down": W_down,
+ "output": output,
+ "M": M,
+ "d_model": d_model,
+ "d_ffn": d_ffn,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ torch.manual_seed(42)
+ tests = []
+
+        # Edge case: single row
+ tests.append(self._make_test_case(1, 4, 8))
+
+ # Edge case: two rows
+ tests.append(self._make_test_case(2, 4, 8))
+
+ # Zero input
+ tests.append(self._make_test_case(4, 8, 16, zero_x=True))
+
+ # Power-of-2 sizes
+ tests.append(self._make_test_case(16, 32, 64))
+
+ # Power-of-2 larger
+ tests.append(self._make_test_case(64, 64, 128))
+
+ # Non-power-of-2 M
+ tests.append(self._make_test_case(30, 32, 64))
+
+ # Non-power-of-2 all dims
+ tests.append(self._make_test_case(100, 60, 120))
+
+ # Non-power-of-2 M, medium size
+ tests.append(self._make_test_case(255, 64, 128))
+
+ # Realistic small inference batch (LLaMA-style ratios)
+ tests.append(self._make_test_case(128, 256, 512))
+
+ # Realistic medium inference batch
+ tests.append(self._make_test_case(256, 512, 1024))
+
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ torch.manual_seed(0)
+ # LLaMA-3 8B style: d_model=4096, d_ffn=14336, M=512 (batch=4 x seq=128)
+ return self._make_test_case(512, 4096, 14336)
diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.cu b/challenges/medium/84_swiglu_mlp_block/starter/starter.cu
new file mode 100644
index 00000000..118724bc
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.cu
@@ -0,0 +1,5 @@
+#include
+
+// x, W_gate, W_up, W_down, output are device pointers
+extern "C" void solve(const float* x, const float* W_gate, const float* W_up, const float* W_down,
+ float* output, int M, int d_model, int d_ffn) {}
diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py
new file mode 100644
index 00000000..c0833020
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.cute.py
@@ -0,0 +1,17 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# x, W_gate, W_up, W_down, output are tensors on the GPU
+@cute.jit
+def solve(
+ x: cute.Tensor,
+ W_gate: cute.Tensor,
+ W_up: cute.Tensor,
+ W_down: cute.Tensor,
+ output: cute.Tensor,
+ M: cute.Int32,
+ d_model: cute.Int32,
+ d_ffn: cute.Int32,
+):
+ pass
diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py
new file mode 100644
index 00000000..219948bb
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.jax.py
@@ -0,0 +1,17 @@
+import jax
+import jax.numpy as jnp
+
+
+# x, W_gate, W_up, W_down are tensors on GPU
+@jax.jit
+def solve(
+ x: jax.Array,
+ W_gate: jax.Array,
+ W_up: jax.Array,
+ W_down: jax.Array,
+ M: int,
+ d_model: int,
+ d_ffn: int,
+) -> jax.Array:
+ # return output tensor directly
+ pass
diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.mojo b/challenges/medium/84_swiglu_mlp_block/starter/starter.mojo
new file mode 100644
index 00000000..184038dc
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# x, W_gate, W_up, W_down, output are device pointers
+@export
+def solve(x: UnsafePointer[Float32], W_gate: UnsafePointer[Float32], W_up: UnsafePointer[Float32], W_down: UnsafePointer[Float32], output: UnsafePointer[Float32], M: Int32, d_model: Int32, d_ffn: Int32):
+ pass
diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.pytorch.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.pytorch.py
new file mode 100644
index 00000000..2af6572f
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.pytorch.py
@@ -0,0 +1,15 @@
+import torch
+
+
+# x, W_gate, W_up, W_down, output are tensors on the GPU
+def solve(
+ x: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ W_down: torch.Tensor,
+ output: torch.Tensor,
+ M: int,
+ d_model: int,
+ d_ffn: int,
+):
+ pass
diff --git a/challenges/medium/84_swiglu_mlp_block/starter/starter.triton.py b/challenges/medium/84_swiglu_mlp_block/starter/starter.triton.py
new file mode 100644
index 00000000..47552d0c
--- /dev/null
+++ b/challenges/medium/84_swiglu_mlp_block/starter/starter.triton.py
@@ -0,0 +1,17 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# x, W_gate, W_up, W_down, output are tensors on the GPU
+def solve(
+ x: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ W_down: torch.Tensor,
+ output: torch.Tensor,
+ M: int,
+ d_model: int,
+ d_ffn: int,
+):
+ pass