Record: CROWN-Q + Full GPTQ + SWA/EMA Blend — val_bpb 1.1186 (3-seed mean)#693
EthanYangTW wants to merge 5 commits into openai:main from
Conversation
…=1.1162) int5 GPTQ quantization with Hessian-aware error compensation enables 33.6M params in 16MB. Soft-Round QAT (differentiable tanh rounding, alpha 1→16) replaces STE for better training quality at zero cost.

3-seed results:
- Seed 1337: val_bpb=1.1155, artifact=15.82MB
- Seed 42: val_bpb=1.1163, artifact=15.42MB
- Seed 7: val_bpb=1.1167, artifact=15.37MB
- Mean: 1.1162 (std 0.0006)
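The "Soft-Round QAT (differentiable tanh rounding)" mentioned above is not spelled out in this excerpt; a minimal sketch of one common tanh-based soft-rounding formulation, with the stated alpha 1→16 ramp, could look like the following (the function names and the linear schedule shape are assumptions, not code from the PR):

```python
import math
import torch

# Hedged sketch: a common differentiable tanh-based soft rounding. Small alpha
# behaves like the identity; large alpha approaches hard round(), so ramping
# alpha from 1 to 16 anneals training toward the quantizer's true rounding.
def soft_round(x: torch.Tensor, alpha: float) -> torch.Tensor:
    floor_x = torch.floor(x)
    r = x - floor_x - 0.5  # fractional offset in [-0.5, 0.5)
    return floor_x + 0.5 + 0.5 * torch.tanh(alpha * r) / math.tanh(alpha / 2)

def alpha_at(step: int, start: int, end: int, a0: float = 1.0, a1: float = 16.0) -> float:
    """Linear alpha ramp over [start, end]; the schedule shape is an assumption."""
    t = min(max((step - start) / max(end - start, 1), 0.0), 1.0)
    return a0 + t * (a1 - a0)
```

Unlike STE, this keeps the rounding step differentiable everywhere, so gradients reflect how close each weight sits to a rounding boundary.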
Pull request overview
Adds a new record submission folder under records/track_10min_16mb/ capturing an experiment that combines CROWN-Q warmdown regularization, full (Cholesky) GPTQ export-time quantization, and a 50/50 SWA+EMA weight blend, with sliding-window evaluation.
Changes:
- Added a self-contained `train_gpt.py` implementing CROWN-Q, SWA/EMA blending, GPTQ calibration/quantization, and sliding-window eval (plus optional TTT routines).
- Added record metadata (`submission.json`) and documentation (`README.md`) describing the method and results.
- Added three seed logs intended to substantiate the reported mean/stdev.
Reviewed changes
Copilot reviewed 3 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_gpt.py | Core training/export/eval script for the submission (CROWN-Q, GPTQ, SWA/EMA, sliding-window eval). |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/README.md | Submission write-up, config notes, and results summary. |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/submission.json | Leaderboard/record metadata (metrics, seeds, artifact sizes). |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_seed1337.log | Training/eval log for seed 1337 supporting reported numbers. |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_seed42.log | Training/eval log for seed 42 supporting reported numbers. |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_seed7.log | Training/eval log for seed 7 supporting reported numbers. |
```python
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP
from flash_attn_interface import flash_attn_func as flash_attn_3_func
```
flash_attn_interface is imported unconditionally, so the script will crash at startup in environments where FlashAttention 3 isn't installed (requirements.txt doesn’t list it, and other record scripts typically have a safe fallback). Consider wrapping this import in a try/except and falling back to PyTorch SDPA (or another available attention backend) when the import fails, so the record is reproducible across the standard evaluation environment.
Suggested change:

```python
try:
    # Optional FlashAttention 3 backend; used when available.
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    HAS_FLASH_ATTN_3 = True
except Exception:
    # Fall back to PyTorch scaled_dot_product_attention when FlashAttention 3
    # is not installed.
    HAS_FLASH_ATTN_3 = False

    def flash_attn_3_func(q: Tensor, k: Tensor, v: Tensor, *args, **kwargs) -> Tensor:
        """Compatibility wrapper that mimics flash_attn_func using PyTorch SDPA.

        Accepts extra *args/**kwargs for flexibility and ignores unsupported
        options. flash_attn_func takes (B, S, H, D) tensors, while SDPA expects
        (B, H, S, D), so we transpose on the way in and out.
        """
        # Best-effort mapping for positional arguments commonly passed to
        # flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, ...).
        dropout_p = float(args[0]) if len(args) >= 1 else 0.0
        causal = bool(args[2]) if len(args) >= 3 else False
        # Keyword arguments override positional ones if provided.
        dropout_p = float(kwargs.get("dropout_p", dropout_p))
        causal = bool(kwargs.get("causal", causal))
        attn_mask = kwargs.get("attn_mask")
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=causal,
        )
        # Restore the (B, S, H, D) layout expected by flash_attn_func callers.
        return out.transpose(1, 2)
```

Note: the original suggestion tried to detect the layout with `q.size(1) != q.size(2)`, which silently fails whenever the sequence length equals the head count; since flash_attn_func's layout is fixed at (B, S, H, D), the wrapper should transpose unconditionally.
| "date": "2026-03-25T06:30:00Z", | ||
| "val_loss": 1.8886, | ||
| "val_loss_std": 0.0009, | ||
| "val_bpb": 1.1186, | ||
| "val_bpb_std": 0.0006, | ||
| "seeds": [1337, 42, 7], |
submission.json deviates from the schema used by the other /records/track_10min_16mb/*/submission.json examples (e.g., missing fields like pre_quant_val_loss, pre_quant_val_bpb, step_stop, wallclock_seconds, eval_time_seconds, and a bytes_model_* breakdown). If any tooling expects the established keys, this new format may break ingestion; consider aligning to the existing schema and adding the additional fields while keeping the per-seed breakdown as extra metadata.
```python
# === TTT BURST: Late-stage sharpening on recent training data ===
if args.ttt_burst_enabled and hasattr(train_loader, '_ttt_buffer') and len(train_loader._ttt_buffer) > 0:
    ttt_buffer = train_loader._ttt_buffer
    log0(f"ttt_burst:start epochs:{args.ttt_burst_epochs} buffer_size:{len(ttt_buffer)} lr_factor:{args.ttt_burst_lr_factor}")
    ttt_lr_scale = args.ttt_burst_lr_factor
    for ttt_epoch in range(args.ttt_burst_epochs):
        ttt_epoch_loss = 0.0
        for ttt_i, (bx, by) in enumerate(ttt_buffer):
            zero_grad_all()
            for opt in optimizers:
                for group in opt.param_groups:
                    group["lr"] = group["base_lr"] * ttt_lr_scale
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                ttt_loss = model(bx, by)
            (ttt_loss * grad_scale).backward()
```
ttt_burst performs additional gradient steps on training data after the main loop stops due to MAX_WALLCLOCK_SECONDS. If the track’s “10 minute training” constraint is interpreted as total training compute/wallclock until final weights are produced, this effectively adds extra training beyond the cap. Consider either (a) accounting for ttt_burst time/steps inside the wallclock cap logic or (b) disabling TTT_BURST_ENABLED by default for the record run and documenting it explicitly.
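One way to implement option (a), charging the burst against the same cap, is a simple budget check before entering the burst loop. A hedged sketch follows; the names `t_start`, `cap_seconds`, and `est_burst_seconds` are placeholders, not identifiers from this PR:

```python
import time

# Illustrative sketch: charge the TTT burst against the same wallclock cap
# as the main training loop, instead of running it after the cap is hit.
def remaining_budget(t_start: float, cap_seconds: float) -> float:
    """Seconds left before the training wallclock cap is exhausted."""
    return cap_seconds - (time.perf_counter() - t_start)

def should_run_ttt_burst(t_start: float, cap_seconds: float, est_burst_seconds: float) -> bool:
    """Run the burst only if its estimated cost fits in the remaining budget."""
    return remaining_budget(t_start, cap_seconds) >= est_burst_seconds
```

The estimate could come from timing one buffer pass; skipping the burst when the budget is short keeps the final weights strictly inside the cap.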
```python
# GPTQ: collect Hessians for calibration-based quantization
hessians = None
if args.gptq_enabled:
    log0(f"gptq:collecting hessians batches={args.gptq_calib_batches}")
    t_hess = time.perf_counter()
    calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
    hessians = collect_hessians(
        base_model, calib_loader, args, device, grad_accum_steps,
        num_batches=args.gptq_calib_batches,
    )
    log0(f"gptq:hessians collected layers={len(hessians)} time={time.perf_counter() - t_hess:.1f}s")
    del calib_loader
    torch.cuda.empty_cache()
quant_result, quant_meta = mixed_quantize_int6(
    sd_cpu, {"mlp", "attn"}, hessians=hessians, gptq_block_size=args.gptq_block_size,
)
# Selective +/-1 pruning: zero out least-impactful quantized values to fit target size
target_bytes = 16_000_000
code_bytes = len(code.encode("utf-8"))
target_model_bytes = target_bytes - code_bytes - 50_000  # headroom
def _serialize_and_compress(qr, qm):
    buf = io.BytesIO()
    torch.save({"w": qr, "m": qm}, buf)
    return lzma.compress(buf.getvalue(), preset=6)
test_blob = _serialize_and_compress(quant_result, quant_meta)
log0(f"gptq:pre_prune artifact={len(test_blob)} target={target_model_bytes}")
```
GPTQ calibration (Hessian collection) and the subsequent quantization/pruning are currently executed on every rank when running distributed. Since only rank 0 writes the artifacts, this multiplies the export-time compute and memory pressure by world_size for no benefit. Consider gating the Hessian collection + mixed_quantize_int6 + pruning to master_process and having other ranks wait at a barrier and then load the produced artifact for evaluation.
Suggested change:

```python
if master_process:
    # GPTQ: collect Hessians for calibration-based quantization
    hessians = None
    if args.gptq_enabled:
        log0(f"gptq:collecting hessians batches={args.gptq_calib_batches}")
        t_hess = time.perf_counter()
        calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
        hessians = collect_hessians(
            base_model, calib_loader, args, device, grad_accum_steps,
            num_batches=args.gptq_calib_batches,
        )
        log0(
            f"gptq:hessians collected layers={len(hessians)} "
            f"time={time.perf_counter() - t_hess:.1f}s"
        )
        del calib_loader
        torch.cuda.empty_cache()
    quant_result, quant_meta = mixed_quantize_int6(
        sd_cpu, {"mlp", "attn"}, hessians=hessians, gptq_block_size=args.gptq_block_size,
    )
    # Selective +/-1 pruning: zero out least-impactful quantized values to fit target size
    target_bytes = 16_000_000
    code_bytes = len(code.encode("utf-8"))
    target_model_bytes = target_bytes - code_bytes - 50_000  # headroom
    def _serialize_and_compress(qr, qm):
        buf = io.BytesIO()
        torch.save({"w": qr, "m": qm}, buf)
        return lzma.compress(buf.getvalue(), preset=6)
    test_blob = _serialize_and_compress(quant_result, quant_meta)
    log0(f"gptq:pre_prune artifact={len(test_blob)} target={target_model_bytes}")
```
- **Architecture**: 11L, 512d, GQA 8H/4KV, MLP 3x LeakyReLU(0.5)^2, XSA on last 4 layers (7-10), VRL, BigramHash 3072, partial RoPE 16/64.
- **Eval**: Sliding window with stride=64. No test-time training.
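For context, sliding-window evaluation with stride 64 scores every token exactly once while re-reading overlapping context, so each prediction gets (near-)maximal context. A hedged sketch, assuming the model maps a batch of token ids to logits and treating one token as one byte for the bits-per-byte conversion (both assumptions, not the PR's exact eval loop):

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_bpb(model, tokens: torch.Tensor, ctx: int = 1024, stride: int = 64) -> float:
    """Score each next-token prediction once, each with up to `ctx` context."""
    seq_len = tokens.size(0)
    total_nll, prev_end = 0.0, 0
    for begin in range(0, seq_len - 1, stride):
        end = min(begin + ctx, seq_len)
        window = tokens[begin:end]
        logits = model(window.unsqueeze(0))[0]                    # (T, vocab)
        losses = F.cross_entropy(logits[:-1], window[1:], reduction="none")
        new = end - max(prev_end, begin + 1)                      # targets not yet scored
        total_nll += losses[-new:].sum().item()                   # count only the new tail
        prev_end = end
        if end == seq_len:
            break
    # bits-per-byte = nats-per-token / ln(2), under the 1 token == 1 byte assumption
    return total_nll / (seq_len - 1) / math.log(2)
```

A smaller stride raises quality (more context per scored token) at proportionally higher eval cost, since each token's logits are recomputed ctx/stride times.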
README states “No test-time training” (eval is pure sliding-window inference), but the included logs show ttt:start / ttt_sliding:start being run. Please reconcile this by regenerating logs with TTT disabled, or clarifying in the README that the TTT section in the logs was a separate diagnostic run and not part of the reported score.
Summary
Results
What is CROWN-Q?
Training-time penalty per weight row:
`lambda * mean(w^2) * delta^2 / 12`, where `delta = row_max / 15`. The CROWN-Q step size (`row_max/15`) is intentionally larger than the actual quantizer step size (`row_max/31`, clip_range=31): this over-penalization pushes weights further into flat basins, providing extra robustness margin against quantization damage. Applied only during warmdown when QAT is active. Zero eval-time cost.

Why No TTT?
AdamW TTT destroys GPTQ-quantized weights (+0.077 BPB degradation). Full-weight AdamW at lr=0.002 on quantized models causes the carefully optimized GPTQ weight placement to diverge. SGD TTT is neutral-to-harmful. TTT_ENABLED is set to 0 in the submitted code.
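For reference, the CROWN-Q per-row penalty described above can be sketched as follows; the (rows x columns) weight layout and summing the per-row terms into a scalar are assumptions, since the write-up only gives the per-row formula:

```python
import torch

def crown_q_penalty(w: torch.Tensor, lam: float) -> torch.Tensor:
    # Per-row penalty from the write-up: lambda * mean(w^2) * delta^2 / 12,
    # with delta = row_max / 15 (deliberately coarser than the real quantizer
    # step row_max / 31, to over-penalize and favor flatter basins).
    row_max = w.abs().amax(dim=1)                       # (rows,)
    delta = row_max / 15.0
    per_row = w.pow(2).mean(dim=1) * delta.pow(2) / 12.0
    return lam * per_row.sum()                          # aggregation is an assumption
```

The `delta^2 / 12` factor is the variance of uniform rounding noise with step `delta`, so the penalty approximates the expected quantization damage weighted by weight magnitude.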
Compliance