Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,24 +266,24 @@ def setup_method(self):

def test_output_shape(self):
x = torch.randn(B, T, self.cfg.dim)
out = self.attn(x, self.freqs)
out = self.attn(x, self.freqs[:T])
assert out.shape == (B, T, self.cfg.dim)

def test_kv_cache_accumulates(self):
cache = {}
x = torch.randn(B, T, self.cfg.dim)
self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="layer0")
assert "layer0" in cache
k_len = cache["layer0"]["k"].shape[1]
# second call adds T more tokens
self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="layer0")
assert cache["layer0"]["k"].shape[1] == k_len + T

def test_with_causal_mask(self):
x = torch.randn(B, T, self.cfg.dim)
mask = torch.full((1, 1, T, T), float("-inf"))
mask = torch.triu(mask, diagonal=1)
out = self.attn(x, self.freqs, mask=mask)
out = self.attn(x, self.freqs[:T], mask=mask)
assert out.shape == (B, T, self.cfg.dim)


Expand All @@ -302,13 +302,13 @@ def setup_method(self):

def test_output_shape(self):
x = torch.randn(B, T, self.cfg.dim)
out = self.attn(x, self.freqs)
out = self.attn(x, self.freqs[:T])
assert out.shape == (B, T, self.cfg.dim)

def test_cache_stores_compressed_kv(self):
cache = {}
x = torch.randn(B, T, self.cfg.dim)
self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="mla0")
assert "c_kv" in cache["mla0"]
assert "k_rope" in cache["mla0"]
# c_kv should have kv_lora_rank as last dim, not full K/V
Expand All @@ -317,15 +317,15 @@ def test_cache_stores_compressed_kv(self):
def test_cache_accumulates_across_steps(self):
cache = {}
x = torch.randn(B, T, self.cfg.dim)
self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="mla0")
first_len = cache["mla0"]["c_kv"].shape[1]
self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="mla0")
assert cache["mla0"]["c_kv"].shape[1] == first_len + T

def test_with_causal_mask(self):
x = torch.randn(B, T, self.cfg.dim)
mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)
out = self.attn(x, self.freqs, mask=mask)
out = self.attn(x, self.freqs[:T], mask=mask)
assert out.shape == (B, T, self.cfg.dim)


Expand Down Expand Up @@ -432,21 +432,21 @@ def test_gqa_output_shape(self):
block = TransformerBlock(cfg, use_moe=False)
freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
x = torch.randn(B, T, cfg.dim)
assert block(x, freqs).shape == (B, T, cfg.dim)
assert block(x, freqs[:T]).shape == (B, T, cfg.dim)

def test_mla_output_shape(self):
cfg = mla_cfg()
block = TransformerBlock(cfg, use_moe=False)
freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len)
x = torch.randn(B, T, cfg.dim)
assert block(x, freqs).shape == (B, T, cfg.dim)
assert block(x, freqs[:T]).shape == (B, T, cfg.dim)

def test_moe_block_output_shape(self):
cfg = gqa_cfg()
block = TransformerBlock(cfg, use_moe=True)
freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
x = torch.randn(B, T, cfg.dim)
assert block(x, freqs).shape == (B, T, cfg.dim)
assert block(x, freqs[:T]).shape == (B, T, cfg.dim)

def test_attn_type_selection(self):
assert isinstance(TransformerBlock(gqa_cfg()).attn, GQAttention)
Expand Down Expand Up @@ -486,7 +486,10 @@ def test_spectral_radius_stable_after_large_grad_step(self):
loss.backward()
opt.step()
A = self.inj.get_A()
assert A.max().item() < 1.0
# ZOH discretization: exp(-exp(clamp(x, -20, 20))). At the clamp
# boundary, exp(-exp(-20)) ≈ 1.0 in float32. The stability guarantee
# is A ∈ (0, 1] — 1.0 means no decay (neutral), not divergence.
assert A.max().item() <= 1.0


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -526,20 +529,20 @@ def setup_method(self):
def test_output_shape(self):
h = torch.randn(B, T, self.cfg.dim)
e = torch.randn(B, T, self.cfg.dim)
out = self.block(h, e, self.freqs)
out = self.block(h, e, self.freqs[:T])
assert out.shape == (B, T, self.cfg.dim)

def test_more_loops_changes_output(self):
h = torch.randn(B, T, self.cfg.dim)
e = torch.randn(B, T, self.cfg.dim)
out1 = self.block(h.clone(), e.clone(), self.freqs, n_loops=1)
out3 = self.block(h.clone(), e.clone(), self.freqs, n_loops=3)
out1 = self.block(h.clone(), e.clone(), self.freqs[:T], n_loops=1)
out3 = self.block(h.clone(), e.clone(), self.freqs[:T], n_loops=3)
assert not torch.allclose(out1, out3)

def test_single_loop_runs(self):
h = torch.randn(B, T, self.cfg.dim)
e = torch.randn(B, T, self.cfg.dim)
out = self.block(h, e, self.freqs, n_loops=1)
out = self.block(h, e, self.freqs[:T], n_loops=1)
assert out.shape == (B, T, self.cfg.dim)


Expand Down