diff --git a/README.md b/README.md index af2d774..40aaacb 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ print(f"[{attn_type.upper()}] Generated shape: {out.shape}") A = model.recurrent.injection.get_A() rho = torch.linalg.eigvals(A).abs().max().item() print( - f"[{attn_type.upper()}] Spectral radius ρ(A) = {rho:.4f} (must be < 1)" + f"[{attn_type.upper()}] Spectral radius ρ(A): {rho:.4f} (must be < 1)" ) ``` diff --git a/examples/moda_example.py b/examples/moda_example.py index bffc92e..c9a8ff9 100644 --- a/examples/moda_example.py +++ b/examples/moda_example.py @@ -39,6 +39,7 @@ labels = torch.randint(0, cfg.vocab_size, (B, T), device=device) logits, loss = model(input_ids, labels) + assert not torch.isnan(loss), "Loss is NaN!" assert logits.shape == (B, T, cfg.vocab_size) print(f"Logits shape : {logits.shape}") print(f"Loss (LM + balance): {loss.item():.4f}")