diff --git a/open_mythos/main.py b/open_mythos/main.py index 10de093..a176b73 100644 --- a/open_mythos/main.py +++ b/open_mythos/main.py @@ -823,7 +823,7 @@ def forward( B, T, D = h.shape halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) - cumulative_p = torch.zeros(B, T, device=h.device) + cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype) h_out = torch.zeros_like(h) for t in range(n_loops): @@ -848,10 +848,10 @@ def forward( remainder, p, ) - weight = weight * still_running.float() + weight = weight * still_running.to(h.dtype) h_out = h_out + weight.unsqueeze(-1) * h - cumulative_p = cumulative_p + p * still_running.float() + cumulative_p = cumulative_p + p * still_running.to(h.dtype) halted = halted | (cumulative_p >= self.cfg.act_threshold) # Only short-circuit when there is no KV cache to keep consistent. @@ -938,7 +938,7 @@ def _init_weights(self) -> None: nn.init.normal_(m.weight, std=0.02) @staticmethod - def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: + def _causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor: """ Build an additive causal mask: 0 on and below the diagonal, -inf above. @@ -949,7 +949,7 @@ def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: Returns: Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S) """ - mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device) + mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device, dtype=dtype) return torch.triu(mask, diagonal=1) def forward( @@ -983,7 +983,7 @@ def forward( freqs_cis = ( self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis )[start_pos : start_pos + T] - mask = self._causal_mask(T, device) if T > 1 else None + mask = self._causal_mask(T, device, dtype=x.dtype) if T > 1 else None for i, layer in enumerate(self.prelude): x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")