diff --git a/.gitignore b/.gitignore
index 06c798b..8c818d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
 input.txt
+__pycache__/
+*.pyc
diff --git a/bdh.py b/bdh.py
index 3eefd86..9b05caf 100644
--- a/bdh.py
+++ b/bdh.py
@@ -2,6 +2,7 @@
 import dataclasses
 import math
+import time
 
 import torch
 import torch.nn.functional as F
@@ -14,7 +15,7 @@ class BDHConfig:
     n_embd: int = 256
     dropout: float = 0.1
     n_head: int = 4
-    mlp_internal_dim_multiplier: int = 128
+    mlp_internal_dim_multiplier: int = 64
     vocab_size: int = 256
@@ -53,25 +54,46 @@ def rope(phases, v):
         phases_cos, phases_sin = Attention.phases_cos_sin(phases)
         return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)
 
-    def forward(self, Q, K, V):
-        assert self.freqs.dtype == torch.float32
-        assert K is Q
-        _, _, T, _ = Q.size()
+    def forward(self, Q, K, V, state=None, t_offset=0):
+        # This forward method now supports both parallel training and stateful generation
+        B, nh, T, N = Q.size()
+        D = V.size(-1)
+
+        # Initialize state for the first step of generation or for training
+        if state is None:
+            state = torch.zeros((B, nh, N, D), device=Q.device, dtype=V.dtype)
+
+        # Calculate RoPE phases based on the current time offset
         r_phases = (
             torch.arange(
-                0,
-                T,
+                t_offset,
+                t_offset + T,
                 device=self.freqs.device,
                 dtype=self.freqs.dtype,
             ).view(1, 1, -1, 1)
         ) * self.freqs
+
         QR = self.rope(r_phases, Q)
-        KR = QR
-
-        # Current attention
-        scores = (QR @ KR.mT).tril(diagonal=-1)
-        return scores @ V
+        KR = self.rope(r_phases, K)  # In the original code KR = QR; kept separate here for clarity
+
+        if T > 1:  # Training or prompt-processing mode (parallel)
+            # The original causal attention, which now also contributes to the state update
+            scores = (QR @ KR.mT).tril(diagonal=-1)
+            output = scores @ V
+            # Update state with the entire sequence's K/V info
+            # Note: for pure BDH, V is broadcast to (B, nh, T, D)
+            state_update = KR.transpose(-2, -1) @ V.expand(B, nh, T, D)
+            new_state = state + state_update
+            return output, new_state
+        else:  # Generation mode (T = 1, sequential)
+            # Use the previous state to compute the output for the new token:
+            #   output = Q_new @ state_old
+            output = QR @ state
+            # Update state with the new token's K/V info:
+            #   state_new = state_old + K_new^T @ V_new
+            state_update = KR.transpose(-2, -1) @ V.expand(B, nh, T, D)
+            new_state = state + state_update
+            return output, new_state
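(Aside, not part of the patch: both branches of the stateful Attention.forward compute the same strictly causal linear attention, output_t = QR_t @ S_{t-1} with S_t = S_{t-1} + KR_t^T V_t, so token-by-token generation should reproduce the parallel path. A minimal consistency-check sketch, assuming bdh.Attention is importable as shown and that BDHConfig is a dataclass accepting these fields as keyword arguments; the small sizes are arbitrary.)

import torch
from bdh import Attention, BDHConfig

torch.manual_seed(0)
cfg = BDHConfig(n_embd=32, n_head=2, mlp_internal_dim_multiplier=8)  # small sizes for the check
attn = Attention(cfg).eval()

B, T = 1, 8
D = cfg.n_embd
N = cfg.mlp_internal_dim_multiplier * D // cfg.n_head

Q = torch.relu(torch.randn(B, cfg.n_head, T, N))  # sparse activations, as in BDH.forward
V = torch.randn(B, 1, T, D)                       # value stream, broadcast across heads

# Parallel path (training / prompt processing): whole sequence at once
out_par, state_par = attn(Q, Q, V)

# Sequential path (generation): one token at a time with a carried state
state, outs = None, []
for t in range(T):
    o, state = attn(Q[:, :, t:t + 1], Q[:, :, t:t + 1], V[:, :, t:t + 1], state=state, t_offset=t)
    outs.append(o)
out_seq = torch.cat(outs, dim=2)

print(torch.allclose(out_par, out_seq, rtol=1e-3, atol=1e-3))   # expected: True
print(torch.allclose(state_par, state, rtol=1e-3, atol=1e-3))   # expected: True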
@@ -85,7 +107,8 @@ def __init__(self, config: BDHConfig):
         self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
         self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
 
-        self.attn = Attention(config)
+        # We need a separate Attention module for each layer to hold its state
+        self.attns = nn.ModuleList([Attention(config) for _ in range(config.n_layer)])
 
         self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
         self.embed = nn.Embedding(config.vocab_size, D)
@@ -107,35 +130,156 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, idx, targets=None):
+    # FORWARD METHOD IS NOW STATEFUL
+    def forward(self, idx, targets=None, past_states=None, t_offset=0):
+        C = self.config
         B, T = idx.size()
         D = C.n_embd
         nh = C.n_head
         N = D * C.mlp_internal_dim_multiplier // nh
 
         x = self.embed(idx).unsqueeze(1)
-
-        # actually helps with training
         x = self.ln(x)  # B, 1, T, D
 
-        for level in range(C.n_layer):
-            x_latent = x @ self.encoder
+        if past_states is None:
+            past_states = [None] * C.n_layer
+
+        present_states = []
+        for i, attn_layer in enumerate(self.attns):
+            x_latent = x @ self.encoder
             x_sparse = F.relu(x_latent)  # B, nh, T, N
 
-            yKV = self.attn(
+            # Pass the time offset and the layer's cached state to attention
+            yKV, layer_state = attn_layer(
                 Q=x_sparse,
                 K=x_sparse,
                 V=x,
+                state=past_states[i],
+                t_offset=t_offset,
             )
+            present_states.append(layer_state)
+
             yKV = self.ln(yKV)
-
             y_latent = yKV @ self.encoder_v
             y_sparse = F.relu(y_latent)
             xy_sparse = x_sparse * y_sparse  # B, nh, T, N
-
             xy_sparse = self.drop(xy_sparse)
 
             yMLP = (
@@ -146,11 +290,12 @@ def forward(self, idx, targets=None):
         logits = x.view(B, T, D) @ self.lm_head
 
         loss = None
-        if targets is not None:
+        if targets is not None and T > 1:  # Calculate loss only during training
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
-        return logits, loss
-
+        return logits, loss, present_states
+
+    # GENERATE METHOD IS NOW STATEFUL AND EFFICIENT
     @torch.no_grad()
     def generate(
         self,
@@ -159,14 +304,43 @@ def generate(
         temperature: float = 1.0,
         top_k: int | None = None,
     ) -> torch.Tensor:
-        for _ in range(max_new_tokens):
-            idx_cond = idx
-            logits, _ = self(idx_cond)
+        start_time = time.perf_counter()
+        last_checkpoint = start_time
+        states = None
+        # The idx tensor will grow, but we only pass the newest token to the model
+        for i in range(max_new_tokens):
+            current_seq_len = idx.size(1)
+
+            # On the first pass, process the whole prompt. On subsequent passes, only the last token.
+            idx_cond = idx if i == 0 else idx[:, -1:]
+
+            # The time offset is the length of the sequence already processed.
+            t_offset = 0 if i == 0 else current_seq_len - 1
+
+            logits, _, states = self(idx_cond, past_states=states, t_offset=t_offset)
+
             logits = logits[:, -1, :] / temperature
             if top_k is not None:
                 values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < values[:, [-1]]] = float("-inf")
+
             probs = F.softmax(logits, dim=-1)
             idx_next = torch.multinomial(probs, num_samples=1)
             idx = torch.cat((idx, idx_next), dim=1)
+            if i % 100 == 0 and i > 0:
+                now = time.perf_counter()
+                elapsed = now - last_checkpoint
+                total_elapsed = now - start_time
+                print(f"Generation, token {i}, last 100 tokens took {elapsed:.2f}s (total {total_elapsed:.2f}s)")
+                last_checkpoint = now
         return idx
+
+
+def load_checkpoint(model, optimizer, checkpoint_path):
+    """Load model and optimizer from checkpoint."""
+    checkpoint = torch.load(checkpoint_path)
+    model.load_state_dict(checkpoint['model'])
+    optimizer.load_state_dict(checkpoint['optimizer'])
+    return checkpoint['step']
\ No newline at end of file
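(Aside, not part of the patch: the cache carried across generate() steps is one (B, nh, N, D) tensor per layer, so its size is constant in sequence length, unlike a growing transformer KV cache. A rough back-of-the-envelope sketch using the defaults after this change, assuming an fp32 state:)

# Per-layer recurrent state held during generation, per sequence (B = 1)
n_embd, n_head, mlp_mult = 256, 4, 64   # defaults after this patch
D = n_embd
N = mlp_mult * D // n_head              # 4096
floats_per_layer = n_head * N * D       # 4 * 4096 * 256 = 4,194,304
print(f"{floats_per_layer * 4 / 2**20:.0f} MiB per layer in fp32")  # 16 MiB; multiply by n_layer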
diff --git a/train.py b/train.py
index 8004c4b..8333de2 100644
--- a/train.py
+++ b/train.py
@@ -9,6 +9,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # On a Mac you can also try
@@ -38,15 +40,27 @@
 # Configuration
 BDH_CONFIG = bdh.BDHConfig()
-BLOCK_SIZE = 512
-BATCH_SIZE = 32
+TOKENS_TO_GENERATE = 4000
+BLOCK_SIZE = 1024
+EFFECTIVE_BATCH_SIZE = 32
 MAX_ITERS = 3000
 LEARNING_RATE = 1e-3
 WEIGHT_DECAY = 0.1
-LOG_FREQ = 100
+LOG_FREQ = 10
+CHECKPOINT_FREQ = 100
+EVAL_FREQ = 100
+EVAL_ITERS = 20
+
+# Training mode: 'scratch', 'continue', 'evaluate'
+mode = 'evaluate'  # Change this to 'continue' or 'evaluate'
+checkpoint_path = 'checkpoint_500.pt'  # Path to checkpoint for 'continue' and 'evaluate' modes
 
 input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
 
+# Global vocabulary mappings
+char_to_id = {}
+id_to_char = {}
+
 
 # Fetch the tiny Shakespeare dataset
 def fetch_data():
@@ -56,23 +70,29 @@ def fetch_data():
         f.write(requests.get(data_url).text)
 
 
+def build_vocabulary():
+    """Build character-level vocabulary from input data."""
+    global char_to_id, id_to_char
+    with open(input_file_path, 'r', encoding='utf-8') as f:
+        text = f.read()
+    chars = sorted(set(text))
+    char_to_id = {ch: i for i, ch in enumerate(chars)}
+    id_to_char = {i: ch for i, ch in enumerate(chars)}
+    return len(chars)
+
+
 def get_batch(split):
-    # treat the file as bytes
-    data = np.memmap(input_file_path, dtype=np.uint8, mode="r")
+    # treat the file as characters
+    with open(input_file_path, 'r', encoding='utf-8') as f:
+        text = f.read()
+    data = [char_to_id[ch] for ch in text]
     if split == "train":
         data = data[: int(0.9 * len(data))]
     else:
         data = data[int(0.9 * len(data)) :]
     ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
-    x = torch.stack(
-        [torch.from_numpy((data[i : i + BLOCK_SIZE]).astype(np.int64)) for i in ix]
-    )
-    y = torch.stack(
-        [
-            torch.from_numpy((data[i + 1 : i + 1 + BLOCK_SIZE]).astype(np.int64))
-            for i in ix
-        ]
-    )
+    x = torch.stack([torch.tensor(data[i : i + BLOCK_SIZE], dtype=torch.long) for i in ix])
+    y = torch.stack([torch.tensor(data[i + 1 : i + 1 + BLOCK_SIZE], dtype=torch.long) for i in ix])
     if torch.cuda.is_available():
         # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(
@@ -87,40 +107,145 @@ def eval(model):
     model.eval()
 
 
+@torch.no_grad()
+def estimate_loss(model):
+    """Estimate loss over multiple batches for both train and eval splits."""
+    out = {}
+    model.eval()
+    for split in ['train', 'eval']:
+        losses = torch.zeros(EVAL_ITERS)
+        for k in range(EVAL_ITERS):
+            x, y = get_batch(split)
+            with ctx:
+                logits, loss, _ = model(x, y)
+            losses[k] = loss.item()
+        out[split] = losses.mean()
+    model.train()
+    return out
+
+
+def get_optimal_batch_size(model, block_size: int, dtype: str):
+    """Get optimal batch size based on model size and available VRAM."""
+    if not torch.cuda.is_available():
+        return 16  # Return a default fallback
+
+    # It's better to get the config from the model itself
+    config = model.config
+
+    # --- 1. Calculate Model Memory Usage ---
+    param_count = sum(p.numel() for p in model.parameters())
+    # Parameters (4 bytes) + Gradients (4 bytes) + AdamW Optimizer (8 bytes)
+    model_memory = param_count * (4 + 4 + 8)
+
+    # --- 2. Calculate Activation Memory Usage (per batch item) ---
+    activation_bytes = 4 if dtype == "float32" else 2
+
+    # Large tensors in the main loop (shape ~ B, T, n)
+    N = config.mlp_internal_dim_multiplier * config.n_embd // config.n_head
+    n_total = config.n_head * N
+
+    main_activations = block_size * n_total * config.n_layer * 6 * activation_bytes
+
+    # Don't forget the final logits tensor (shape ~ B, T, vocab_size)
+    logits_memory = block_size * config.vocab_size * 4  # Logits are often float32
+
+    activation_per_batch_item = main_activations + logits_memory
+
+    # --- 3. Determine Max Batch Size ---
+    total_memory = torch.cuda.get_device_properties(0).total_memory
+    available_memory = total_memory * 0.95  # Use 95% of VRAM as a safety margin
+    memory_for_batches = available_memory - model_memory
+
+    if memory_for_batches <= 0:
+        max_batch = 0
+    else:
+        max_batch = int(memory_for_batches / activation_per_batch_item)
+
+    # --- Debug logging ---
+    print(f"VRAM Debug:")
+    print(f"  Total VRAM: {total_memory / 1e9:.2f} GB")
+    print(f"  Model memory: {model_memory / 1e9:.2f} GB")
+    print(f"  Activation per batch item: {activation_per_batch_item / 1e6:.2f} MB")
+    print(f"  Memory available for batches: {memory_for_batches / 1e9:.2f} GB")
+    print(f"  Max batch calculated: {max_batch}")
+
+    if max_batch == 0:
+        print("Warning: Model parameters alone exceed 95% of VRAM. Batch size set to 1.")
+        return 1
+
+    # --- Find nearest power of 2 for efficiency ---
+    optimal = 1
+    while optimal * 2 <= max_batch:
+        optimal *= 2
+    return max(1, optimal)
+
 
 if __name__ == "__main__":
     fetch_data()
+    # Build vocabulary
+    vocab_size = build_vocabulary()
+    BDH_CONFIG.vocab_size = vocab_size
+    print(f"Built vocabulary with {vocab_size} characters")
+
     model = bdh.BDH(BDH_CONFIG).to(device)
+
+    # Auto-adjust batch size based on model size and VRAM
+    BATCH_SIZE = get_optimal_batch_size(model, BLOCK_SIZE, dtype)
+    gradient_accumulation_steps = max(1, EFFECTIVE_BATCH_SIZE // BATCH_SIZE)
+    print(f"Using batch size: {BATCH_SIZE}, effective: {EFFECTIVE_BATCH_SIZE}, accumulation steps: {gradient_accumulation_steps}")
     model = torch.compile(model)
 
     optimizer = torch.optim.AdamW(
         model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
     )
-    x, y = get_batch("train")
-
-    loss_acc = 0
-    loss_steps = 0
-    for step in range(MAX_ITERS):
-        with ctx:
-            logits, loss = model(x, y)
-        x, y = get_batch("train")
-        loss_acc += loss
-        loss_steps += 1
-        scaler.scale(loss).backward()
+    start_step = 0
+    if mode == 'continue':
+        start_step = bdh.load_checkpoint(model, optimizer, checkpoint_path)
+        print(f"Loaded checkpoint from step {start_step}")
+    elif mode == 'evaluate':
+        bdh.load_checkpoint(model, optimizer, checkpoint_path)
+        print("Loaded checkpoint for evaluation")
+        model.eval()
+        prompt_text = "To be or "
+        prompt = torch.tensor([char_to_id[ch] for ch in prompt_text], dtype=torch.long, device=device).unsqueeze(0)
+        ret = model.generate(prompt, max_new_tokens=TOKENS_TO_GENERATE, top_k=3)
+        ret_decoded = ''.join([id_to_char[i.item()] for i in ret.squeeze(0)])
+        print(ret_decoded)
+        exit()
+
+    for step in range(start_step, MAX_ITERS):
+        loss_acc = 0
+        for micro_step in range(gradient_accumulation_steps):
+            x, y = get_batch("train")
+            with ctx:
+                logits, loss, _ = model(x, y)
+            loss = loss / gradient_accumulation_steps
+            loss_acc += loss
+            scaler.scale(loss).backward()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
         if step % LOG_FREQ == 0:
-            print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}")
-            loss_acc = 0
-            loss_steps = 0
+            print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item():.3}")
+        if step % EVAL_FREQ == 0 and step > 0:
+            losses = estimate_loss(model)
+            print(f"Step: {step}/{MAX_ITERS} train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
+        if step % CHECKPOINT_FREQ == 0 and step > 0:
+            torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'step': step}, f'checkpoint_{step}.pt')
+    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'step': MAX_ITERS}, 'final_checkpoint.pt')
 
     print("Training done, now generating a sample ")
     model.eval()
-    prompt = torch.tensor(
-        bytearray("To be or ", "utf-8"), dtype=torch.long, device=device
-    ).unsqueeze(0)
-    ret = model.generate(prompt, max_new_tokens=100, top_k=3)
-    ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode(
-        errors="backslashreplace"
-    )
+    prompt_text = "To be or "
+    prompt = torch.tensor([char_to_id[ch] for ch in prompt_text], dtype=torch.long, device=device).unsqueeze(0)
+    ret = model.generate(prompt, max_new_tokens=TOKENS_TO_GENERATE, top_k=3)
+    ret_decoded = ''.join([id_to_char[i.item()] for i in ret.squeeze(0)])
     print(ret_decoded)
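(Aside, not part of the patch: how the gradient-accumulation arithmetic in the training loop works out. BATCH_SIZE = 8 below is a hypothetical value returned by get_optimal_batch_size, used only for illustration.)

EFFECTIVE_BATCH_SIZE = 32
BATCH_SIZE = 8                                                            # hypothetical auto-detected value
gradient_accumulation_steps = max(1, EFFECTIVE_BATCH_SIZE // BATCH_SIZE)  # 4 micro-batches per optimizer step
# Each micro-batch loss is divided by gradient_accumulation_steps before backward(),
# so the accumulated gradients match a single batch of 32, and loss_acc, the sum of
# the scaled losses, is the mean loss over the effective batch.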