From 5354ddcb8b7c7fee066babcce1534a7c8adf3134 Mon Sep 17 00:00:00 2001
From: tonyzdev
Date: Tue, 21 Apr 2026 20:00:09 +0800
Subject: [PATCH] Fix dtype leaks that break bfloat16 training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Two small dtype-promotion bugs silently upcast the hidden state from bf16
to fp32 partway through the forward pass, producing

  RuntimeError: expected mat1 and mat2 to have the same dtype,
  but got: float != c10::BFloat16

at the next nn.Linear call. Both places omit an explicit `dtype=` argument
and fall back to torch's default float32, poisoning any tensor they
subsequently combine with.

1) `OpenMythos._causal_mask` builds its additive mask with
   `torch.full(..., float("-inf"), device=device)`, i.e. fp32. When the
   attention adds this mask to a bf16 `attn` tensor, `attn` becomes fp32,
   and the subsequent `torch.matmul(attn, v)` crashes because `v` is still
   bf16.

2) `RecurrentBlock.forward` allocates the ACT accumulator
   `cumulative_p = torch.zeros(B, T, device=h.device)` and uses
   `still_running.float()` in the weight update. Both are fp32 regardless
   of `h.dtype`, so `h_out = h_out + weight.unsqueeze(-1) * h` silently
   upcasts the returned hidden state to fp32. The next Coda layer then
   fails at `q_down(h)` for the same dtype-mismatch reason.

The fix threads `h.dtype` / `x.dtype` through both sites. With both patches
applied, `mythos_1b` and a custom 150M MLA/MoE variant train end-to-end in
bf16 on H100 / A40 with no dtype errors, ρ(A) stable at 0.357, and zero
impact on fp32 behaviour. Tested on the latest main (torch 2.8.0+cu128,
H100 SXM).
---
 open_mythos/main.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/open_mythos/main.py b/open_mythos/main.py
index 10de093..a176b73 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -823,7 +823,7 @@ def forward(
         B, T, D = h.shape
         halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
-        cumulative_p = torch.zeros(B, T, device=h.device)
+        cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
         h_out = torch.zeros_like(h)
 
         for t in range(n_loops):
@@ -848,10 +848,10 @@ def forward(
                 remainder,
                 p,
             )
-            weight = weight * still_running.float()
+            weight = weight * still_running.to(h.dtype)
 
             h_out = h_out + weight.unsqueeze(-1) * h
-            cumulative_p = cumulative_p + p * still_running.float()
+            cumulative_p = cumulative_p + p * still_running.to(h.dtype)
             halted = halted | (cumulative_p >= self.cfg.act_threshold)
 
             # Only short-circuit when there is no KV cache to keep consistent.
@@ -938,7 +938,7 @@ def _init_weights(self) -> None:
             nn.init.normal_(m.weight, std=0.02)
 
     @staticmethod
-    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
+    def _causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor:
         """
         Build an additive causal mask: 0 on and below the diagonal, -inf above.
 
@@ -949,7 +949,7 @@ def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
         Returns:
             Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S)
         """
-        mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)
+        mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device, dtype=dtype)
         return torch.triu(mask, diagonal=1)
 
     def forward(
@@ -983,7 +983,7 @@ def forward(
         freqs_cis = (
             self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
         )[start_pos : start_pos + T]
-        mask = self._causal_mask(T, device) if T > 1 else None
+        mask = self._causal_mask(T, device, dtype=x.dtype) if T > 1 else None
 
         for i, layer in enumerate(self.prelude):
             x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")
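
Note for reviewers: below is a minimal standalone sketch of bug (1), not
part of open_mythos/main.py -- the tensor names and shapes are invented for
illustration. It shows how a tensor created without dtype= (fp32 by default)
silently promotes the bf16 tensor it is added to, so the next matmul against
a still-bf16 tensor hits the dtype mismatch:

    import torch

    attn = torch.randn(2, 8, 8, dtype=torch.bfloat16)   # bf16 attention scores
    v = torch.randn(2, 8, 16, dtype=torch.bfloat16)     # bf16 values

    # No dtype= given, so torch.full() falls back to the default float32.
    bad_mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)
    promoted = attn + bad_mask          # bf16 + fp32 -> fp32, silently
    assert promoted.dtype == torch.float32
    try:
        torch.matmul(promoted, v)       # fp32 x bf16: dtype-mismatch RuntimeError
    except RuntimeError as e:
        print(e)

    # Threading the dtype through, as this patch does, keeps everything bf16.
    good_mask = torch.triu(
        torch.full((8, 8), float("-inf"), dtype=attn.dtype), diagonal=1
    )
    out = torch.matmul(attn + good_mask, v)
    assert out.dtype == torch.bfloat16

Bug (2) is the same mechanism via torch.zeros and .float() instead of
torch.full, which is why both fixes reduce to `dtype=h.dtype` / `.to(h.dtype)`.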