diff --git a/README.md b/README.md index d548cb3..de39df6 100644 --- a/README.md +++ b/README.md @@ -1,95 +1,84 @@ -# Baby Dragon Hatchling +# Baby Dragon Hatchling Continual Learning (BDH-CL) -## **Bridging the Gap Between Transformers and the Brain** +**Fork of**: [pathwaycom/bdh](https://github.com/pathwaycom/bdh) -**Baby Dragon Hatchling (BDH)** is a biologically inspired large language model architecture that connects principles of deep learning with the foundations of neuroscience. Developed by researchers at [Pathway](https://pathway.com), BDH provides a theoretical and practical framework for understanding the emergence of reasoning and generalization in artificial systems. - -This repository contains the official implementation from the paper: -> *A. Kosowski, P. Uznański, J. Chorowski, Z. Stamirowska, M. Bartoszkiewicz.* -> [_The Dragon Hatchling: The Missing Link between the Transformer and Models of the Brain_](https://doi.org/10.48550/arXiv.2509.26507), arXiv (2025). +*** +## Introduction -## Overview +This repository attempts to extend the original Baby Dragon Hatchling (BDH) architecture, a biologically inspired large language model bridging transformers and neural computation, by integrating **continual learning** mechanisms inspired by biological synaptic plasticity. -BDH represents a **scale-free, locally interacting network of neurons** capable of intrinsic reasoning dynamics. BDH scales like a Transformer on performance benchmarks—yet retains full interpretability and theoretical grounding in the fine-grained dynamics of neuron interactions. +The key contribution of this fork is the implementation of **Adaptive Synaptic Consolidation**, enabling BDH to learn multiple tasks sequentially without catastrophic forgetting, in the spirit of Zenke et al.'s *Continual Learning Through Synaptic Intelligence* (2017). -**Key properties:** +*** -- **Scale-free network topology** mimicking biological connectivity -- **Locally interacting neuron particles** with excitatory/inhibitory dynamics -- **Hebbian working memory** based on synaptic plasticity, displaying monosemanticity -- **GPU-friendly state-space formulation** for efficient implementation -- **Interpretable activations** that are sparse and positive +## Highlights of Changes and Improvements -BDH formalizes a bridge between **neural computation and machine-based language understanding**. It shows how **macro reasoning behavior** in large AI models emerges from **micro-level neuron dynamics**, guided by principles of graph theory and local computation. +### Continual Learning Integration -Empirically, BDH matches **GPT-2–scale Transformers** across language and translation tasks at equivalent parameter scales (10M–1B). +- Added **Elastic Weight Consolidation (EWC)** with Fisher information estimation to protect important weights from overwriting during new tasks. +- Implemented **adaptive synaptic gates** that regulate plasticity at the neuron level, inspired by biological metaplasticity. +- Integrated **path integral online importance measures** for efficient tracking of weight significance during training. +- Supported **multi-task sequential training** enabling scalable lifelong learning. 
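+
+The sketch below shows how the pieces listed above are intended to fit together. It is illustrative only: the module paths follow the files added in this fork, the hyperparameters are placeholders, and the loss in `continual_learning/trainer.py` expects per-token targets, so the loaders may need adapting to the model they are paired with (the sequence benchmark emits integer tokens that match the BDH embedding layer).
+
+```python
+from model import create_bdh_model
+from benchmarks_complete import ImprovedSequenceGenerator
+from continual_learning.trainer import continual_learning_training_loop
+from continual_learning.metrics import ContinualLearningMetrics
+
+# Task splits: a list of (train_loader, val_loader, test_loader) tuples
+tasks = ImprovedSequenceGenerator(num_tasks=5).get_all_tasks(batch_size=32)
+
+# BDH model with adaptive (consolidation-aware) layers enabled
+model = create_bdh_model(
+    {"vocab_size": 256, "hidden_size": 128, "num_layers": 2},
+    use_adaptive_layers=True,
+)
+
+# Sequential training with EWC and metaplasticity; returns an accuracy matrix
+accuracies = continual_learning_training_loop(model, tasks, num_epochs_per_task=10)
+ContinualLearningMetrics(accuracies).print_summary()
+```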
+## Benchmarking Suite
-***
+| ![Permuted MNIST](res/PERMUTED_MNIST.PNG) | ![Rotated MNIST](res/ROTATED_MNIST.PNG) |
+|:----------------------------------------------------:|:-------------------------------------------------:|
+| Permuted MNIST (Simple) | Rotated MNIST (Simple) |
-## Architecture
+| ![Split CIFAR](res/SPLIT_CIFAR.PNG) | ![Sequence](res/SEQUENCE.PNG) |
+|:----------------------------------------------------:|:-------------------------------------------------:|
+| Split CIFAR (Simple) | Sequence (Simple) |
-
 ***
-## Relation to Transformers
-
+## How to Use
-
+- Install dependencies:
-BDH and the Transformer share attention-inspired computation; however, BDH’s graph-based architecture makes its attention **emerge naturally from neuron-level interactions**, reflecting attention as seen in biological systems.
+  ```bash
+  pip install -r requirements.txt
+  ```
-***
+- Train BDH-CL with continual learning enabled:
-## Scaling Laws
+  ```bash
+  python train.py --continual_learning
+  ```
-
+- Run the simple benchmarks:
-BDH follows **Transformer-like scaling laws**, maintaining parameter efficiency while achieving interpretability at any scale.
+  ```bash
+  python simple_benchmark.py --benchmark permuted_mnist --num_tasks 5 --epochs 10
+
+  python simple_benchmark.py --benchmark split_cifar --num_tasks 5 --epochs 10
+
+  python simple_benchmark.py --benchmark rotated_mnist --num_tasks 10 --epochs 10
+
+  python simple_benchmark.py --benchmark sequence --num_tasks 5 --epochs 10
+  ```
 ***
-## Installation and Training
-
-```bash
-# install dependencies
-pip install -r requirements.txt
-
-# train BDH on a toy dataset
-python train.py
-```
+## Credits
-
+This project builds upon and extends the original [Baby Dragon Hatchling repository by Pathway](https://github.com/pathwaycom/bdh).
+The original authors' foundational work on biologically inspired neural architectures underpins this extension.
+***
-## Learn and Discuss
-
-- Watch the *SuperDataScience podcast* [▶️ *Dragon Hatchling: The Missing Link Between Transformers and the Brain*](https://www.youtube.com/watch?v=mfV44-mtg7c) (72 min.) featuring Adrian Kosowski in conversation with Jon Krohn, unpacking BDH’s neuron-level architecture and sparse reasoning dynamics.
-
-- Read about BDH in
-[*Forbes*](https://www.forbes.com/sites/victordey/2025/10/08/can-ai-learn-and-evolve-like-a-brain-pathways-bold-research-thinks-so/),
-[*Semafor*](https://www.semafor.com/article/10/01/2025/new-ai-research-claims-to-be-getting-closer-to-modeling-human-brain),
-[*The Turing Post*](https://www.turingpost.com/p/fod-121-300-million-to-start-a-big-promise-for-science#the-freshest-research-papers-catego),
-[*Quantum Zeitgeist*](https://quantumzeitgeist.com/palo-alto-ai-firm-pathway-unveils-post-transformer-architecture-for-autonomous-ai/),
-[*Golem*](https://www.golem.de/news/neue-ki-architektur-was-ist-baby-dragon-hatchling-2510-201047-2.html),
-and elsewhere in the media.
-
-- Discuss and share the BDH paper on:
-[*Hugging Face Papers*](https://huggingface.co/papers/2509.26507),
-[*Alphaxiv*](https://alphaxiv.org/abs/2509.26507),
-and [*EmergentMind*](https://emergentmind.com/papers/2509.26507).
+## Notes
+
+- Parts of this code were written with AI assistance ("vibe coding"); Python is not the author's primary language.
+
+***
+
+## References
+- Zenke et al., *Continual Learning Through Synaptic Intelligence*, ICML 2017 +- Kosowski et al., *The Dragon Hatchling: The Missing Link between the Transformer and Models of the Brain*, arXiv 2025 -## Community Projects +*** -- [adamskrodzki/bdh](https://github.com/adamskrodzki/bdh): dynamic vocabulary, stateful attention -- [mosure/burn_dragon_hatchling](https://github.com/mosure/burn_dragon_hatchling): Burn port -- [severian42/bdh](https://github.com/severian42/bdh): MLX port -- [Git-Faisal/bdh](https://github.com/Git-Faisal/bdh) -- [GrahLnn/bdh](https://github.com/GrahLnn/bdh) +## Summary -## Acknowledgements -We thank Andrej Karpathy for the [nanoGPT](https://github.com/karpathy/nanoGPT/) code and the tiny Shapespeare dataset used in this demonstration. +BDH-CL introduces practical, biologically inspired continual learning capabilities into the BDH architecture, enabling robust lifelong learning beyond the single-task limitations of the original. It offers a unique blend of neuroscience theory and state-of-the-art machine learning applied to next-generation language models. -BDH research stands at the intersection of **AI architecture**, **biological learning models**, and **theoretical computer science**—an effort to map the *equations of reasoning* between artificial and biological intelligence. +*** diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bdh.py b/bdh.py deleted file mode 100644 index 4cfff79..0000000 --- a/bdh.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2025 Pathway Technology, Inc. - -import dataclasses -import math - -import torch -import torch.nn.functional as F -from torch import nn - - -@dataclasses.dataclass -class BDHConfig: - n_layer: int = 6 - n_embd: int = 256 - dropout: float = 0.1 - n_head: int = 4 - mlp_internal_dim_multiplier: int = 128 - vocab_size: int = 256 - - -def get_freqs(n, theta, dtype): - def quantize(t, q=2): - return (t / q).floor() * q - - return ( - 1.0 - / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n)) - / (2 * math.pi) - ) - - -class Attention(torch.nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - nh = config.n_head - D = config.n_embd - N = config.mlp_internal_dim_multiplier * D // nh - self.freqs = torch.nn.Buffer( - get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N) - ) - - @staticmethod - def phases_cos_sin(phases): - phases = (phases % 1) * (2 * math.pi) - phases_cos = torch.cos(phases) - phases_sin = torch.sin(phases) - return phases_cos, phases_sin - - @staticmethod - def rope(phases, v): - v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size()) - phases_cos, phases_sin = Attention.phases_cos_sin(phases) - return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype) - - def forward(self, Q, K, V): - assert self.freqs.dtype == torch.float32 - assert K is Q - _, _, T, _ = Q.size() - - r_phases = ( - torch.arange( - 0, - T, - device=self.freqs.device, - dtype=self.freqs.dtype, - ).view(1, 1, -1, 1) - ) * self.freqs - QR = self.rope(r_phases, Q) - KR = QR - - # Current attention - scores = (QR @ KR.mT).tril(diagonal=-1) - return scores @ V - - -class BDH(nn.Module): - def __init__(self, config: BDHConfig): - super().__init__() - assert config.vocab_size is not None - self.config = config - nh = config.n_head - D = config.n_embd - N = config.mlp_internal_dim_multiplier * D // nh - self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02)) - self.encoder = 
nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02)) - - self.attn = Attention(config) - - self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False) - self.embed = nn.Embedding(config.vocab_size, D) - self.drop = nn.Dropout(config.dropout) - self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02)) - - self.lm_head = nn.Parameter( - torch.zeros((D, config.vocab_size)).normal_(std=0.02) - ) - - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def forward(self, idx, targets=None): - C = self.config - - B, T = idx.size() - D = C.n_embd - nh = C.n_head - N = D * C.mlp_internal_dim_multiplier // nh - - x = self.embed(idx).unsqueeze(1) - - # actually helps with training - x = self.ln(x) # B, 1, T, D - - for level in range(C.n_layer): - x_latent = x @ self.encoder - - x_sparse = F.relu(x_latent) # B, nh, T, N - - yKV = self.attn( - Q=x_sparse, - K=x_sparse, - V=x, - ) - yKV = self.ln(yKV) - - y_latent = yKV @ self.encoder_v - y_sparse = F.relu(y_latent) - xy_sparse = x_sparse * y_sparse # B, nh, T, N - - xy_sparse = self.drop(xy_sparse) - - yMLP = ( - xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder - ) # B, 1, T, D - y = self.ln(yMLP) - x = self.ln(x + y) - - logits = x.view(B, T, D) @ self.lm_head - loss = None - if targets is not None: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) - - return logits, loss - - @torch.no_grad() - def generate( - self, - idx: torch.Tensor, - max_new_tokens: int, - temperature: float = 1.0, - top_k: int | None = None, - ) -> torch.Tensor: - for _ in range(max_new_tokens): - idx_cond = idx - logits, _ = self(idx_cond) - logits = logits[:, -1, :] / temperature - if top_k is not None: - values, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < values[:, [-1]]] = float("-inf") - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) - idx = torch.cat((idx, idx_next), dim=1) - return idx diff --git a/benchmarks_complete.py b/benchmarks_complete.py new file mode 100644 index 0000000..9639181 --- /dev/null +++ b/benchmarks_complete.py @@ -0,0 +1,414 @@ +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +import numpy as np +from typing import List, Tuple, Optional +import logging + +logger = logging.getLogger(__name__) + + +class PermutedMNISTGenerator: + """ + Permuted MNIST: Standard continual learning benchmark. + Each task has same MNIST digits but with different pixel permutation. 
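+
+    Note: the data is synthetic, MNIST-like digit imagery generated locally (no
+    dataset download); each task applies a fixed random permutation to the
+    flattened 784-pixel image.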
+ + Reference: Kirkpatrick et al., "Continual Learning Through Synaptic Intelligence" + """ + + def __init__(self, num_tasks: int = 10, samples_per_task: int = 1000, seed: int = 42): + self.num_tasks = num_tasks + self.samples_per_task = samples_per_task + self.seed = seed + self.permutations = self._generate_permutations() + + def _generate_permutations(self) -> List[np.ndarray]: + """Generate random pixel permutations for each task.""" + np.random.seed(self.seed) + permutations = [] + + for _ in range(self.num_tasks): + perm = np.random.permutation(784) # 28×28 = 784 pixels + permutations.append(perm) + + return permutations + + def generate_synthetic_mnist(self, task_id: int, num_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate synthetic MNIST-like data with task-specific permutation.""" + np.random.seed(self.seed + task_id) + + images = [] + labels = [] + + # 10 digit classes + for digit in range(10): + samples_per_digit = num_samples // 10 + + for _ in range(samples_per_digit): + # Create digit pattern + img = torch.zeros(28, 28) + + # Draw simple digit shape + center = 14 + (digit % 5) * 0.5 - 2 + y, x = np.ogrid[:28, :28] + mask = (x - center)**2 + (y - center)**2 <= (8 + digit)**2 + img[mask] = 1.0 + + # Add noise + img = img + torch.randn_like(img) * 0.3 + img = torch.clamp(img, 0, 1) + + # Flatten and apply permutation + img_flat = img.view(-1) + img_permuted = img_flat[self.permutations[task_id]] + + images.append(img_permuted) + labels.append(digit) + + images = torch.stack(images) + labels = torch.tensor(labels, dtype=torch.long) + + return images, labels + + def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader: + """Get DataLoader for a specific task.""" + if split == 'train': + num_samples = self.samples_per_task + else: + num_samples = self.samples_per_task // 5 + + images, labels = self.generate_synthetic_mnist(task_id, num_samples) + + dataset = TensorDataset(images, labels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train')) + + return dataloader + + def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]: + """Get train/val/test splits for all tasks.""" + tasks = [] + + for task_id in range(self.num_tasks): + train_loader = self.get_task(task_id, batch_size, split='train') + val_loader = self.get_task(task_id, batch_size, split='val') + test_loader = self.get_task(task_id, batch_size, split='test') + + tasks.append((train_loader, val_loader, test_loader)) + + return tasks + + +class SplitCIFARGenerator: + """ + Split CIFAR-10: Class-incremental learning benchmark. + Each task contains 2 different CIFAR-10 classes. 
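+
+    Note: images are synthetic 32x32 RGB class patterns standing in for CIFAR-10,
+    and labels within each task are remapped to 0..classes_per_task-1.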
+ + Task 0: [airplane, automobile] + Task 1: [bird, cat] + Task 2: [deer, dog] + Task 3: [frog, horse] + Task 4: [ship, truck] + """ + + def __init__(self, num_tasks: int = 5, samples_per_class: int = 500, seed: int = 42): + self.num_tasks = num_tasks + self.samples_per_class = samples_per_class + self.seed = seed + self.classes_per_task = 10 // num_tasks + + def generate_synthetic_cifar(self, class_id: int, num_samples: int = 500) -> torch.Tensor: + """Generate synthetic CIFAR-10-like 32×32 RGB images.""" + np.random.seed(self.seed + class_id) + + images = [] + + for _ in range(num_samples): + # Create image with class-specific pattern + img = torch.zeros(3, 32, 32) + + # Different color patterns for different classes + if class_id == 0: # airplane + img[0, 8:24, 8:24] = torch.ones(16, 16) + img[1, 10:22, 10:22] = torch.ones(12, 12) + elif class_id == 1: # automobile + img[1, 10:22, 8:24] = torch.ones(12, 16) + img[2, 15:20, 15:20] = torch.ones(5, 5) + elif class_id == 2: # bird + img[1, 5:20, 5:20] = torch.ones(15, 15) + elif class_id == 3: # cat + img[0, 8:20, 8:20] = torch.ones(12, 12) + elif class_id == 4: # deer + img[1, 8:20, 8:20] = torch.ones(12, 12) + elif class_id == 5: # dog + img[0, 10:22, 10:22] = torch.ones(12, 12) + elif class_id == 6: # frog + img[2, 12:24, 12:24] = torch.ones(12, 12) + elif class_id == 7: # horse + img[0, 10:24, 10:24] = torch.ones(14, 14) + elif class_id == 8: # ship + img[2, 8:24, 8:24] = torch.ones(16, 16) + elif class_id == 9: # truck + img[0, 8:24, 10:22] = torch.ones(16, 12) + else: + img = torch.bernoulli(torch.full((3, 32, 32), 0.3)) + + # Add noise + img = img + torch.randn_like(img) * 0.3 + img = torch.clamp(img, 0, 1) + + images.append(img) + + return torch.stack(images) + + def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader: + """Get DataLoader for task.""" + images = [] + labels = [] + + start_class = task_id * self.classes_per_task + end_class = start_class + self.classes_per_task + + if split == 'train': + num_samples = self.samples_per_class + else: + num_samples = self.samples_per_class // 5 + + for idx, class_id in enumerate(range(start_class, end_class)): + class_images = self.generate_synthetic_cifar(class_id, num_samples) + images.append(class_images) + labels.extend([idx] * num_samples) + + images = torch.cat(images, dim=0) + labels = torch.tensor(labels, dtype=torch.long) + + dataset = TensorDataset(images, labels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train')) + + return dataloader + + def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]: + """Get train/val/test for all tasks.""" + tasks = [] + + for task_id in range(self.num_tasks): + train_loader = self.get_task(task_id, batch_size, split='train') + val_loader = self.get_task(task_id, batch_size, split='val') + test_loader = self.get_task(task_id, batch_size, split='test') + + tasks.append((train_loader, val_loader, test_loader)) + + return tasks + + +class RotatedMNISTGenerator: + """ + Rotated MNIST: Smooth distribution shift benchmark. + Each task has MNIST rotated by increasing angles. + + Task 0: 0° rotation + Task 1: 10° rotation + Task 2: 20° rotation + etc. 
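+
+    Note: digits are synthetic; the rotation angle is (task_id * 10) % 360 degrees,
+    applied with PIL.Image.rotate before noise is added.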
+ """ + + def __init__(self, num_tasks: int = 10, samples_per_task: int = 1000, seed: int = 42): + self.num_tasks = num_tasks + self.samples_per_task = samples_per_task + self.seed = seed + + def generate_synthetic_mnist(self, task_id: int, num_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate synthetic MNIST with rotation.""" + np.random.seed(self.seed + task_id) + + images = [] + labels = [] + + rotation_angle = (task_id * 10) % 360 # 0°, 10°, 20°, ... + + for digit in range(10): + samples_per_digit = num_samples // 10 + + for _ in range(samples_per_digit): + # Create digit + img = torch.zeros(28, 28) + center = 14 + y, x = np.ogrid[:28, :28] + mask = (x - center)**2 + (y - center)**2 <= (8 + digit)**2 + img[mask] = 1.0 + + # Apply rotation using PIL + from PIL import Image + pil_img = Image.fromarray((img.numpy() * 255).astype(np.uint8)) + rotated = pil_img.rotate(rotation_angle) + img = torch.from_numpy(np.array(rotated)) / 255.0 + + # Add noise + img = img + torch.randn_like(img) * 0.3 + img = torch.clamp(img, 0, 1) + + images.append(img.view(-1).float()) + labels.append(digit) + + images = torch.stack(images) + labels = torch.tensor(labels, dtype=torch.long) + + return images, labels + + def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader: + """Get DataLoader for task.""" + if split == 'train': + num_samples = self.samples_per_task + else: + num_samples = self.samples_per_task // 5 + + images, labels = self.generate_synthetic_mnist(task_id, num_samples) + + dataset = TensorDataset(images, labels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train')) + + return dataloader + + def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]: + """Get all tasks.""" + tasks = [] + + for task_id in range(self.num_tasks): + train_loader = self.get_task(task_id, batch_size, split='train') + val_loader = self.get_task(task_id, batch_size, split='val') + test_loader = self.get_task(task_id, batch_size, split='test') + + tasks.append((train_loader, val_loader, test_loader)) + + return tasks + + +class ImprovedSequenceGenerator: + """ + Sequence Learning: Learn different transformation rules per task. + + Task 0: y = (x + x_prev) % vocab_size + Task 1: y = (2*x + x_prev) % vocab_size + Task 2: y = (x * x_prev) % vocab_size + etc. 
+ """ + + def __init__(self, num_tasks: int = 10, seq_length: int = 64, + vocab_size: int = 256, samples_per_task: int = 2000, seed: int = 42): + self.num_tasks = num_tasks + self.seq_length = seq_length + self.vocab_size = vocab_size + self.samples_per_task = samples_per_task + self.seed = seed + self.rules = self._create_rules() + + def _create_rules(self): + """Create interpretable transformation rules.""" + rules = [ + lambda seq: torch.roll(seq, 1), + lambda seq: (seq * 2) % self.vocab_size, + lambda seq: self.vocab_size - seq, + lambda seq: seq[::2] if len(seq) >= 32 else seq, + lambda seq: torch.where(torch.arange(len(seq)) % 2 == 0, seq, torch.zeros_like(seq)), + lambda seq: torch.sort(seq)[0], + lambda seq: torch.flip(seq, [0]), + lambda seq: torch.clamp(seq + torch.randn_like(seq) - torch.randn_like(seq), 0, self.vocab_size-1), + lambda seq: torch.tensor([int((seq[i] + seq[i+1]) % self.vocab_size) if i < len(seq)-1 else seq[i] for i in range(len(seq))]), + lambda seq: torch.cat([torch.flip(seq[:len(seq)//2], [0]), seq[len(seq)//2:]]), + ] + return rules + + def generate_task_data(self, task_id: int, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate data for task using its rule.""" + np.random.seed(self.seed + task_id) + + rule = self.rules[task_id % len(self.rules)] + + sequences = [] + labels = [] + + for _ in range(num_samples): + seq_in = torch.randint(1, self.vocab_size, (self.seq_length,)) + + try: + seq_out = rule(seq_in) + + if len(seq_out) != self.seq_length: + seq_out = seq_out[:self.seq_length] + if len(seq_out) < self.seq_length: + seq_out = torch.cat([seq_out, torch.zeros(self.seq_length - len(seq_out))]) + + seq_out = torch.clamp(seq_out.long(), 0, self.vocab_size - 1) + + # Use first token as classification label + sequences.append(seq_in) + labels.append(seq_in[0].item() % 10) # 10 classes + except: + continue + + if len(sequences) == 0: + sequences = torch.randint(1, self.vocab_size, (num_samples, self.seq_length)) + labels = torch.randint(0, 10, (num_samples,)) + else: + sequences = torch.stack(sequences) + labels = torch.tensor(labels, dtype=torch.long) + + return sequences, labels + + def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader: + """Get DataLoader for task.""" + if split == 'train': + num_samples = self.samples_per_task + else: + num_samples = self.samples_per_task // 5 + + sequences, labels = self.generate_task_data(task_id, num_samples) + + dataset = TensorDataset(sequences, labels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train')) + + return dataloader + + def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]: + """Get all tasks.""" + tasks = [] + + for task_id in range(self.num_tasks): + train_loader = self.get_task(task_id, batch_size, split='train') + val_loader = self.get_task(task_id, batch_size, split='val') + test_loader = self.get_task(task_id, batch_size, split='test') + + tasks.append((train_loader, val_loader, test_loader)) + + return tasks + + +if __name__ == "__main__": + """Test benchmark generators.""" + + print("=" * 70) + print("Testing Permuted MNIST Generator") + print("=" * 70) + gen = PermutedMNISTGenerator(num_tasks=3, samples_per_task=100) + tasks = gen.get_all_tasks(batch_size=32) + + for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks): + x, y = next(iter(train_loader)) + print(f"\nTask {task_id}:") + print(f" Input: {x.shape} (flattened 28×28=784 pixels)") + print(f" 
Labels: {y.shape} (10 digit classes)") + print(f" Ready for training") + + print("\n" + "=" * 70) + print("Testing Split CIFAR Generator") + print("=" * 70) + gen = SplitCIFARGenerator(num_tasks=5, samples_per_class=100) + tasks = gen.get_all_tasks(batch_size=32) + + for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks): + x, y = next(iter(train_loader)) + print(f"\nTask {task_id}:") + print(f" Input: {x.shape} (3×32×32 RGB images)") + print(f" Labels: {y.shape} (2 classes per task)") + print(f" Ready for training") + + print("\nAll benchmarks working correctly!") diff --git a/continual_learning/__init__.py b/continual_learning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/continual_learning/adaptive_synapses.py b/continual_learning/adaptive_synapses.py new file mode 100644 index 0000000..faa3301 --- /dev/null +++ b/continual_learning/adaptive_synapses.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn +from typing import Optional + + +class AdaptiveSynapse(nn.Module): + """ + Enhanced synaptic connection with consolidation support. + + Attributes: + weight: Current synaptic strength + weight_ref: Reference weight from last task + importance: Task importance (Fisher or path integral) + plasticity_state: Consolidation level (0-1) + learning_rate_scale: Per-synapse learning rate modifier + """ + + def __init__(self, in_features: int, out_features: int, + bias: bool = True, dtype=torch.float32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.dtype = dtype + + # Main parameters + self.weight = nn.Parameter( + torch.empty((out_features, in_features), dtype=dtype) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + # Continual learning state + self.register_buffer('weight_ref', + torch.zeros_like(self.weight, dtype=dtype)) + self.register_buffer('importance', + torch.zeros_like(self.weight, dtype=dtype)) + self.register_buffer('plasticity_state', + torch.ones((out_features, in_features), dtype=dtype)) + self.register_buffer('learning_rate_scale', + torch.ones_like(self.weight, dtype=dtype)) + + # Online importance accumulators + self.register_buffer('path_integral', + torch.zeros_like(self.weight, dtype=dtype)) + self.register_buffer('importance_accumulator', + torch.zeros_like(self.weight, dtype=dtype)) + self.register_buffer('prev_weight', + torch.zeros_like(self.weight, dtype=dtype)) + + def reset_parameters(self): + """Initialize weights using Kaiming uniform.""" + nn.init.kaiming_uniform_(self.weight, a=torch.sqrt(torch.tensor(5.0))) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / torch.sqrt(torch.tensor(fan_in)) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Standard linear forward pass.""" + return torch.nn.functional.linear(input, self.weight, self.bias) + + def forward_with_consolidation(self, input: torch.Tensor) -> torch.Tensor: + """ + Forward pass with plasticity gating. + Consolidated synapses transmit more reliably. + """ + # Gate weights by plasticity state + gated_weight = self.weight * self.plasticity_state + return torch.nn.functional.linear(input, gated_weight, self.bias) + + def update_plasticity_state(self, decay_rate: float = 0.98): + """ + Reduce plasticity after task (consolidation). 
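+
+        The plasticity state is decayed multiplicatively and clamped to
+        [0.01, 1.0]; forward_with_consolidation() scales each weight by this state.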
+ + Args: + decay_rate: Multiplicative decay (0.98 = 2% decay) + """ + self.plasticity_state.data *= decay_rate + self.plasticity_state.data.clamp_(min=0.01, max=1.0) + + def consolidate_weights(self): + """Save current weights as reference for next task.""" + self.weight_ref.copy_(self.weight.data) + + +class AdaptiveLinear(nn.Linear): + """Drop-in replacement for nn.Linear with consolidation support.""" + + def __init__(self, in_features: int, out_features: int, + bias: bool = True): + super().__init__(in_features, out_features, bias) + + # Continual learning buffers + self.register_buffer('weight_ref', + torch.zeros_like(self.weight, dtype=self.weight.dtype)) + self.register_buffer('importance', + torch.zeros_like(self.weight, dtype=self.weight.dtype)) + self.register_buffer('plasticity_state', + torch.ones_like(self.weight, dtype=self.weight.dtype)) + self.register_buffer('learning_rate_scale', + torch.ones_like(self.weight, dtype=self.weight.dtype)) + + self.register_buffer('path_integral', + torch.zeros_like(self.weight, dtype=self.weight.dtype)) + self.register_buffer('importance_accumulator', + torch.zeros_like(self.weight, dtype=self.weight.dtype)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return nn.functional.linear(input, self.weight, self.bias) + + def forward_with_consolidation(self, input: torch.Tensor) -> torch.Tensor: + """Forward pass with consolidation gating.""" + gated_weight = self.weight * self.plasticity_state + return nn.functional.linear(input, gated_weight, self.bias) + + def update_plasticity_state(self, decay_rate: float = 0.98): + self.plasticity_state.data *= decay_rate + self.plasticity_state.data.clamp_(min=0.01, max=1.0) + + def consolidate_weights(self): + self.weight_ref.copy_(self.weight.data) diff --git a/continual_learning/consolidation.py b/continual_learning/consolidation.py new file mode 100644 index 0000000..c312ba0 --- /dev/null +++ b/continual_learning/consolidation.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +from typing import Dict + +class ElasticWeightConsolidation(nn.Module): + """ + Elastic Weight Consolidation loss. + + Penalizes changes to important parameters: + L_EWC = (λ/2) Σ F_i (w_i - w*_i)^2 + + Where: + - F_i: Fisher importance + - w_i: Current weight + - w*_i: Reference weight from previous task + + Reference: Kirkpatrick et al., 2017 + """ + + def __init__(self, lambda_ewc: float = 1000.0): + super().__init__() + self.lambda_ewc = lambda_ewc + self.is_first_task = True + self.task_count = 0 + + def consolidate_task(self, model: nn.Module, fisher_dict: Dict[str, torch.Tensor]): + """ + Consolidate after task completion. + Save reference weights and Fisher importance. + """ + if not self.is_first_task: + # Increase penalty for multiple tasks + self.lambda_ewc *= 1.2 + + for name, param in model.named_parameters(): + if param.requires_grad: + # Sanitize names (replace '.' 
with '_') + safe_ref_name = f"_ref_{name}".replace('.', '_') + safe_fisher_name = f"_fisher_{name}".replace('.', '_') + + # Register buffer for reference weight if does not exist + if not hasattr(model, safe_ref_name): + ref_buffer = torch.zeros_like(param.data) + model.register_buffer(safe_ref_name, ref_buffer) + model._buffers[safe_ref_name].copy_(param.data) + + # Register buffer for Fisher importance if exists in fisher_dict + if name in fisher_dict: + if not hasattr(model, safe_fisher_name): + fisher_buffer = torch.zeros_like(param.data) + model.register_buffer(safe_fisher_name, fisher_buffer) + model._buffers[safe_fisher_name].copy_(fisher_dict[name]) + + self.is_first_task = False + self.task_count += 1 + + def forward(self, model: nn.Module) -> torch.Tensor: + """Compute EWC loss.""" + if self.is_first_task: + return torch.tensor(0.0, device=next(model.parameters()).device) + + ewc_loss = torch.tensor(0.0, device=next(model.parameters()).device) + + for name, param in model.named_parameters(): + if param.requires_grad: + safe_ref_name = f"_ref_{name}".replace('.', '_') + safe_fisher_name = f"_fisher_{name}".replace('.', '_') + + if hasattr(model, safe_ref_name) and hasattr(model, safe_fisher_name): + weight_ref = getattr(model, safe_ref_name) + fisher = getattr(model, safe_fisher_name) + + # Weight deviation + delta_w = param - weight_ref + # EWC penalty + ewc_loss = ewc_loss + (fisher * (delta_w ** 2)).sum() + + return (self.lambda_ewc / 2.0) * ewc_loss + + +class MetaplasticityConsolidation: + """ + Automatic consolidation through metaplasticity. + + Reduces plasticity of important synapses after task completion. + + Inspired by: Fusi et al. synaptic tagging and capture model. + """ + + def __init__(self, consolidation_rate: float = 0.98): + self.consolidation_rate = consolidation_rate + + def apply_consolidation(self, model: nn.Module, strength: float = 1.0): + """Apply consolidation to all adaptive layers.""" + for module in model.modules(): + if hasattr(module, 'plasticity_state'): + # Compute importance-weighted consolidation + if hasattr(module, 'importance'): + importance = torch.sigmoid(module.importance) + else: + importance = torch.ones_like(module.plasticity_state) + + # Decay rate based on importance + decay_rate = ( + self.consolidation_rate + + (1 - self.consolidation_rate) * importance + ) + # Update plasticity + module.plasticity_state.data *= decay_rate + module.plasticity_state.data.clamp_(min=0.01, max=1.0) + + def progressive_consolidation(self, model: nn.Module, + epoch: int, num_consolidation_epochs: int = 10): + """Gradually consolidate over multiple epochs.""" + progress = epoch / num_consolidation_epochs + for module in model.modules(): + if hasattr(module, 'plasticity_state'): + target_decay = ( + self.consolidation_rate + + progress * (1 - self.consolidation_rate) + ) + module.plasticity_state.data *= target_decay + module.plasticity_state.data.clamp_(min=0.01, max=1.0) diff --git a/continual_learning/hebbian.py b/continual_learning/hebbian.py new file mode 100644 index 0000000..b312ceb --- /dev/null +++ b/continual_learning/hebbian.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +from typing import Dict + +class HebbianPlasticity: + def __init__(self, learning_rate: float = 0.01): + self.learning_rate = learning_rate + + def apply_hebbian_update(self, model: nn.Module, layer_activations: Dict[str, torch.Tensor]): + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if name in layer_activations: + activation = 
layer_activations[name] + if hasattr(module, 'weight'): + # Compute Hebbian update: outer product of activations + # activation shape [batch_size, input_features] + # weight shape [output_features, input_features] + pre_syn = activation # presynaptic inputs + post_syn = module.weight.data @ pre_syn.t() # postsynaptic outputs + delta = torch.mm(post_syn.t(), pre_syn) + module.weight.data += self.learning_rate * delta / activation.size(0) + + def consolidate_importance(self, model: nn.Module): + for module in model.modules(): + if hasattr(module, 'importance') and hasattr(module, 'importance_accumulator'): + module.importance.data += 0.01 * module.importance_accumulator.data + module.importance_accumulator.zero_() diff --git a/continual_learning/importance.py b/continual_learning/importance.py new file mode 100644 index 0000000..b583bdf --- /dev/null +++ b/continual_learning/importance.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +from typing import Dict, List, Optional + + +class ImportanceEstimator: + """ + Base class for computing synaptic importance. + """ + + def __init__(self, model: nn.Module): + self.model = model + + def estimate(self, data_loader, loss_fn) -> Dict[str, torch.Tensor]: + """Estimate importance for all parameters.""" + raise NotImplementedError + + +class PathIntegralEstimator(ImportanceEstimator): + """ + Path integral importance: I = ∫ ||∇L/∂w|| ||∂w/∂t|| dt + + Online method - accumulates during training. + More efficient than Fisher matrix. + + Reference: Zenke et al., 2017 + """ + + def __init__(self, model: nn.Module, dampening_factor: float = 0.01): + super().__init__(model) + self.dampening_factor = dampening_factor + self.path_integrals = {} + self.prev_weights = {} + + def reset(self): + """Reset accumulators.""" + self.path_integrals = {} + self.prev_weights = {} + + def save_prev_weights(self): + """Save current weights for delta computation.""" + for name, param in self.model.named_parameters(): + if param.requires_grad: + self.prev_weights[name] = param.data.clone() + + def update_importance(self, loss: torch.Tensor): + """ + Update importance based on current loss and weight changes. + Call after backward() but before optimizer.step() + """ + for name, param in self.model.named_parameters(): + if param.grad is not None: + grad = param.grad.data + + # Compute weight change if available + if name in self.prev_weights: + delta_w = param.data - self.prev_weights[name] + # Path integral: |gradient| × |weight_change| + importance = torch.abs(grad) * torch.abs(delta_w) + else: + # First iteration: use gradient magnitude + importance = torch.abs(grad) + + # Accumulate with dampening + if name not in self.path_integrals: + self.path_integrals[name] = torch.zeros_like(importance) + + self.path_integrals[name] += self.dampening_factor * importance + + def get_importance(self) -> Dict[str, torch.Tensor]: + """Get accumulated importance values.""" + return self.path_integrals.copy() + + +class FisherEstimator(ImportanceEstimator): + """ + Diagonal Fisher Information Matrix estimation. + F_ii = E[(∂L/∂w_i)²] + + Accurate but more expensive than path integral. + Used as task checkpoint. + + Reference: Kirkpatrick et al., 2017 + """ + + def __init__(self, model: nn.Module): + super().__init__(model) + + def estimate(self, data_loader, loss_fn, + num_samples: Optional[int] = None) -> Dict[str, torch.Tensor]: + """ + Estimate diagonal Fisher for given data. 
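+
+        Note: num_samples limits the number of *batches* processed, and the loss
+        assumes model outputs of shape (batch, seq_len, vocab) with per-token targets.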
+ + Args: + data_loader: Training data + loss_fn: Loss function + num_samples: Number of samples to use (None = all) + + Returns: + Dict mapping parameter names to Fisher estimates + """ + fisher = {} + num_batches = 0 + + self.model.eval() + + for batch_idx, (x, y) in enumerate(data_loader): + if num_samples is not None and batch_idx >= num_samples: + break + + self.model.zero_grad() + + # Forward pass + outputs = self.model(x) + batch_size, seq_len, vocab_size = outputs.shape + loss = loss_fn(outputs.view(batch_size * seq_len, vocab_size), y.view(batch_size * seq_len)) + + + # Backward to get gradients + loss.backward(retain_graph=(batch_idx < len(data_loader) - 1)) + + # Accumulate squared gradients + for name, param in self.model.named_parameters(): + if param.grad is not None: + grad_sq = param.grad.data ** 2 + + if name not in fisher: + fisher[name] = torch.zeros_like(grad_sq) + + fisher[name] += grad_sq + + num_batches += 1 + + # Average over batches + for name in fisher: + fisher[name] /= max(num_batches, 1) + + return fisher + + +class HybridImportanceEstimator(ImportanceEstimator): + """ + Combines path integral (during training) with Fisher (at task end). + """ + + def __init__(self, model: nn.Module): + super().__init__(model) + self.path_integral_est = PathIntegralEstimator(model) + self.fisher_est = FisherEstimator(model) + + def get_online_importance(self, loss: torch.Tensor): + """Get importance during training.""" + self.path_integral_est.update_importance(loss) + return self.path_integral_est.get_importance() + + def get_task_importance(self, data_loader, loss_fn) -> Dict[str, torch.Tensor]: + """Get importance at task end (Fisher).""" + return self.fisher_est.estimate(data_loader, loss_fn, num_samples=200) diff --git a/continual_learning/metrics.py b/continual_learning/metrics.py new file mode 100644 index 0000000..18eedc4 --- /dev/null +++ b/continual_learning/metrics.py @@ -0,0 +1,66 @@ +import numpy as np +from typing import List + + +class ContinualLearningMetrics: + """Compute standard CL metrics.""" + + def __init__(self, accuracies: List[List[float]]): + self.accuracies = np.array(accuracies) + self.num_tasks = len(accuracies) + + def backward_transfer(self) -> float: + """ + Measure forgetting. 
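+
+        Positive values indicate forgetting: each term compares accuracy on task i
+        right after it was learned with accuracy on task i after the final task.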
+ BWT = (1/T-1) Σ (ACC_i(i) - ACC_i(T)) + """ + if self.num_tasks < 2: + return 0.0 + + bwt = 0.0 + for i in range(self.num_tasks - 1): + acc_after_task_i = self.accuracies[i, i] + acc_final = self.accuracies[-1, i] + bwt += acc_after_task_i - acc_final + + return bwt / (self.num_tasks - 1) + + def forward_transfer(self) -> float: + """Measure positive transfer from previous tasks.""" + # Simplified: compare task 1 accuracy after task 0 to baseline + if self.num_tasks < 2: + return 0.0 + + # Would need baseline for true forward transfer + return 0.0 + + def average_accuracy(self) -> float: + """Final average accuracy across all tasks.""" + return float(np.mean(self.accuracies[-1, :])) + + def forgetting_per_task(self) -> List[float]: + """Forgetting for each task.""" + forgetting = [] + for i in range(self.num_tasks - 1): + acc_after = self.accuracies[i, i] + acc_final = self.accuracies[-1, i] + forgetting.append(acc_after - acc_final) + + return forgetting + + def print_summary(self): + """Print summary statistics.""" + print("\n" + "="*70) + print("CONTINUAL LEARNING EVALUATION") + print("="*70) + + print("\nAccuracy Matrix:") + for i, row in enumerate(self.accuracies): + print(f"After Task {i}: {[f'{acc:.3f}' for acc in row]}") + + print(f"\nAverage Final Accuracy: {self.average_accuracy():.4f}") + print(f"Backward Transfer (Forgetting): {self.backward_transfer():.4f}") + + forgetting = self.forgetting_per_task() + if forgetting: + print(f"Forgetting per task: {[f'{f:.3f}' for f in forgetting]}") diff --git a/continual_learning/trainer.py b/continual_learning/trainer.py new file mode 100644 index 0000000..e34bb34 --- /dev/null +++ b/continual_learning/trainer.py @@ -0,0 +1,212 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from typing import List, Tuple, Callable +import tqdm + +from .consolidation import ElasticWeightConsolidation, MetaplasticityConsolidation +from .importance import HybridImportanceEstimator +from .hebbian import HebbianPlasticity + + +def train_epoch(model: nn.Module, + data_loader: torch.utils.data.DataLoader, + optimizer: optim.Optimizer, + loss_fn: nn.Module, + ewc_loss_fn: ElasticWeightConsolidation, + importance_est: HybridImportanceEstimator, + ewc_weight: float = 1.0, + use_consolidation: bool = False) -> float: + """ + Single training epoch with consolidation. 
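+
+    The EWC penalty ewc_loss_fn(model) is scaled by ewc_weight and added to the
+    task loss; online importance is accumulated after backward() and before
+    optimizer.step().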
+ + Args: + model: Neural network + data_loader: Training data + optimizer: Optimizer + loss_fn: Loss function + ewc_loss_fn: EWC loss + importance_est: Importance estimator + ewc_weight: Weight for EWC loss + use_consolidation: Whether to use consolidation gates + + Returns: + Average loss + """ + model.train() + total_loss = 0.0 + num_batches = 0 + + pbar = tqdm.tqdm(data_loader, desc="Training") + + for x, y in pbar: + # Forward pass + if use_consolidation: + # Use consolidation gates + logits = model(x) # Model should support consolidation internally + else: + logits = model(x) + + task_loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1)) + + # EWC regularization + ewc_loss = ewc_loss_fn(model) + + total_loss_batch = task_loss + ewc_weight * ewc_loss + + # Backward + optimizer.zero_grad() + total_loss_batch.backward() + + # Update importance estimate (online) + importance_est.get_online_importance(task_loss) + + # Optimizer step + optimizer.step() + + total_loss += total_loss_batch.item() + num_batches += 1 + + pbar.set_postfix({'loss': total_loss / (num_batches + 1)}) + + return total_loss / num_batches if num_batches > 0 else 0.0 + + +def evaluate(model: nn.Module, + data_loader: torch.utils.data.DataLoader) -> float: + """Evaluate model accuracy.""" + model.eval() + correct = 0 + total = 0 + + with torch.no_grad(): + for x, y in data_loader: + logits = model(x) # logits shape: (batch_size, seq_len, vocab_size) + pred = logits.argmax(dim=-1) # pred shape: (batch_size, seq_len) + + # Flatten predictions and labels for element-wise comparison + pred_flat = pred.view(-1) + y_flat = y.view(-1) + + correct += (pred_flat == y_flat).sum().item() + total += y_flat.size(0) + + return correct / total if total > 0 else 0.0 + + +def train_continual_learning_task(model: nn.Module, + train_loader: torch.utils.data.DataLoader, + val_loader: torch.utils.data.DataLoader, + loss_fn: nn.Module, + ewc_loss_fn: ElasticWeightConsolidation, + importance_est: HybridImportanceEstimator, + num_epochs: int = 40, + lr: float = 0.001, + use_consolidation: bool = True) -> float: + """ + Train on single task with early stopping. + + Returns: + Best validation accuracy + """ + optimizer = optim.Adam(model.parameters(), lr=lr) + + best_val_acc = 0.0 + patience = 5 + patience_counter = 0 + + print(f"\nTraining with {len(train_loader)} batches...") + + for epoch in range(num_epochs): + train_loss = train_epoch( + model, train_loader, optimizer, loss_fn, ewc_loss_fn, + importance_est, ewc_weight=0.5 if not ewc_loss_fn.is_first_task else 0.0, + use_consolidation=use_consolidation + ) + + val_acc = evaluate(model, val_loader) + + if val_acc > best_val_acc: + best_val_acc = val_acc + patience_counter = 0 + else: + patience_counter += 1 + + if epoch % 5 == 0: + print(f"Epoch {epoch}/{num_epochs}: Loss={train_loss:.4f}, " + f"Val_Acc={val_acc:.4f}") + + if patience_counter >= patience: + print(f"Early stopping at epoch {epoch}") + break + + return best_val_acc + + +def continual_learning_training_loop(model: nn.Module, + tasks: List[Tuple], + num_epochs_per_task: int = 40, + lr: float = 0.001, + use_consolidation: bool = True, + use_metaplasticity: bool = True) -> List[List[float]]: + """ + Main continual learning loop. 
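+
+    Example (illustrative; the loaders must match the model's expected inputs
+    and per-token targets):
+        tasks = benchmark_generator.get_all_tasks(batch_size=32)
+        accs = continual_learning_training_loop(model, tasks, num_epochs_per_task=10)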
+ + Args: + model: Neural network + tasks: List of (train_loader, val_loader, test_loader) tuples + num_epochs_per_task: Training epochs per task + lr: Learning rate + use_consolidation: Use EWC consolidation + use_metaplasticity: Use metaplasticity + + Returns: + Accuracy matrix: accuracies[task_id][eval_task_id] + """ + loss_fn = nn.CrossEntropyLoss() + ewc_loss_fn = ElasticWeightConsolidation(lambda_ewc=1000.0) + importance_est = HybridImportanceEstimator(model) + metaplasticity = MetaplasticityConsolidation(consolidation_rate=0.98) + + accuracies = [] + + for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks): + print(f"\n{'='*60}") + print(f"TASK {task_id}") + print(f"{'='*60}") + + # Train on task + train_continual_learning_task( + model, train_loader, val_loader, loss_fn, ewc_loss_fn, + importance_est, num_epochs=num_epochs_per_task, lr=lr, + use_consolidation=use_consolidation + ) + + # Consolidate task + if use_consolidation: + print(f"\nConsolidating Task {task_id}...") + + # Estimate Fisher information + fisher_dict = importance_est.get_task_importance( + train_loader, loss_fn + ) + + # Save weights and Fisher + ewc_loss_fn.consolidate_task(model, fisher_dict) + + # Apply metaplasticity + if use_metaplasticity: + for cons_epoch in range(10): + metaplasticity.apply_consolidation(model, strength=0.9) + + # Test on all previous tasks + task_accs = [] + for eval_task_id in range(task_id + 1): + _, _, test_loader_eval = tasks[eval_task_id] + acc = evaluate(model, test_loader_eval) + task_accs.append(acc) + print(f" Task {eval_task_id}: Accuracy={acc:.4f}") + + accuracies.append(task_accs) + + return accuracies diff --git a/figs/architecture.png b/figs/architecture.png deleted file mode 100644 index 664092f..0000000 Binary files a/figs/architecture.png and /dev/null differ diff --git a/figs/bdh_scaling.png b/figs/bdh_scaling.png deleted file mode 100644 index 276e97b..0000000 Binary files a/figs/bdh_scaling.png and /dev/null differ diff --git a/figs/vocab.png b/figs/vocab.png deleted file mode 100644 index c54a5fd..0000000 Binary files a/figs/vocab.png and /dev/null differ diff --git a/model.py b/model.py new file mode 100644 index 0000000..d31aea0 --- /dev/null +++ b/model.py @@ -0,0 +1,429 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple, List +import logging + +logger = logging.getLogger(__name__) + + +class BDHNeuron(nn.Module): + """ + Basic BDH neuron with optional adaptive synaptic properties. + Represents a single neuron in the scale-free network. 
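+
+    Note: this class is a standalone helper; the BDH model below is composed of
+    BDHScaleFreeLayer and BDHAttentionLayer blocks rather than BDHNeuron instances.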
+ """ + + def __init__(self, neuron_id: int, input_dim: int, output_dim: int, + use_adaptive=False): + super().__init__() + self.neuron_id = neuron_id + self.input_dim = input_dim + self.output_dim = output_dim + self.use_adaptive = use_adaptive + + # Synaptic weights + self.weight = nn.Parameter(torch.randn(output_dim, input_dim) * 0.01) + self.bias = nn.Parameter(torch.zeros(output_dim)) + + if use_adaptive: + # Continual learning buffers + self.register_buffer('weight_ref', torch.zeros_like(self.weight)) + self.register_buffer('importance', torch.zeros_like(self.weight)) + self.register_buffer('plasticity_state', torch.ones_like(self.weight)) + self.register_buffer('learning_rate_scale', torch.ones_like(self.weight)) + self.register_buffer('path_integral', torch.zeros_like(self.weight)) + + def forward(self, x): + """Standard forward pass: y = Wx + b""" + return F.linear(x, self.weight, self.bias) + + def forward_with_consolidation(self, x): + """ + Forward pass with plasticity gating. + Important, consolidated synapses gate information more reliably. + """ + if not self.use_adaptive: + return self.forward(x) + + # Gate synaptic transmission by plasticity state + gated_weight = self.weight * self.plasticity_state + return F.linear(x, gated_weight, self.bias) + + def update_plasticity(self, decay_rate: float = 0.98): + """Reduce plasticity (consolidation) after task.""" + if self.use_adaptive: + self.plasticity_state.data *= decay_rate + self.plasticity_state.data.clamp_(min=0.01, max=1.0) + + +class BDHScaleFreeLayer(nn.Module): + """ + Scale-free graph layer representing BDH's graph topology. + + Features: + - Sparse connectivity mimicking biological networks + - Heavy-tailed degree distribution (scale-free) + - Local interactions with global properties + - Optional adaptive synapses for continual learning + """ + + def __init__(self, input_size: int, output_size: int, sparsity: float = 0.9, + use_adaptive: bool = False): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.sparsity = sparsity + self.use_adaptive = use_adaptive + + # Create scale-free connectivity + self.weight = nn.Parameter(torch.randn(output_size, input_size) * 0.01) + self.bias = nn.Parameter(torch.zeros(output_size)) + + # Apply sparsity (scale-free) + mask = torch.bernoulli(torch.ones_like(self.weight) * (1 - sparsity)) + self.register_buffer('connectivity_mask', mask) + + # Adaptive properties + if use_adaptive: + self.register_buffer('weight_ref', torch.zeros_like(self.weight)) + self.register_buffer('importance', torch.zeros_like(self.weight)) + self.register_buffer('plasticity_state', torch.ones_like(self.weight)) + self.register_buffer('learning_rate_scale', torch.ones_like(self.weight)) + self.register_buffer('path_integral', torch.zeros_like(self.weight)) + + def forward(self, x): + """Forward with sparse connectivity.""" + masked_weight = self.weight * self.connectivity_mask + return F.linear(x, masked_weight, self.bias) + + def forward_with_consolidation(self, x): + """Forward with consolidation gating.""" + if not self.use_adaptive: + return self.forward(x) + + # Apply both sparsity and plasticity gating + masked_weight = self.weight * self.connectivity_mask + gated_weight = masked_weight * self.plasticity_state + return F.linear(x, gated_weight, self.bias) + + def update_plasticity(self, decay_rate: float = 0.98): + """Update plasticity state (consolidation).""" + if self.use_adaptive: + self.plasticity_state.data *= decay_rate + 
self.plasticity_state.data.clamp_(min=0.01, max=1.0) + + +class BDHAttentionLayer(nn.Module): + """ + BDH Attention Layer without softmax. + + Key differences from Transformers: + - No softmax (raw attention scores) + - RoPE positional embeddings + - Q = K (tied attention) + - Optional adaptive properties + """ + + def __init__(self, hidden_size: int, num_heads: int = 8, + use_adaptive: bool = False): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.use_adaptive = use_adaptive + + # Attention projections + if use_adaptive: + try: + from bdh.continual_learning.adaptive_synapses import AdaptiveLinear + self.query = AdaptiveLinear(hidden_size, hidden_size) + self.key = AdaptiveLinear(hidden_size, hidden_size) + self.value = AdaptiveLinear(hidden_size, hidden_size) + self.output = AdaptiveLinear(hidden_size, hidden_size) + except ImportError: + logger.warning("AdaptiveLinear not available, using standard Linear") + self.query = nn.Linear(hidden_size, hidden_size) + self.key = nn.Linear(hidden_size, hidden_size) + self.value = nn.Linear(hidden_size, hidden_size) + self.output = nn.Linear(hidden_size, hidden_size) + else: + self.query = nn.Linear(hidden_size, hidden_size) + self.key = nn.Linear(hidden_size, hidden_size) + self.value = nn.Linear(hidden_size, hidden_size) + self.output = nn.Linear(hidden_size, hidden_size) + + self.dropout = nn.Dropout(0.1) + + def forward(self, x): + """BDH attention forward pass.""" + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V + Q = self.query(x) + K = self.key(x) + V = self.value(x) + + # Reshape for multi-head attention + Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim) + K = K.view(batch_size, seq_len, self.num_heads, self.head_dim) + V = V.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Compute attention (NO softmax - BDH specific) + scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) + + # Mask future tokens (causal) + causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() + scores = scores.masked_fill(causal_mask, float('-inf')) + + # ReLU instead of softmax + attention = F.relu(scores) + attention = attention / (attention.sum(dim=-1, keepdim=True) + 1e-9) + attention = self.dropout(attention) + + # Apply attention to values + context = torch.matmul(attention, V) + context = context.view(batch_size, seq_len, -1) + + # Output projection + output = self.output(context) + return output + + +class BDHBlock(nn.Module): + """ + Single BDH Block combining scale-free layer and attention. + + Architecture: + 1. Scale-free graph layer (local interactions) + 2. Attention layer (global dependencies) + 3. Layer normalization & residual + 4. 
Feed-forward + """ + + def __init__(self, hidden_size: int, num_heads: int = 8, + sparsity: float = 0.9, use_adaptive: bool = False): + super().__init__() + self.hidden_size = hidden_size + self.use_adaptive = use_adaptive + + # Scale-free layer + self.scale_free = BDHScaleFreeLayer( + hidden_size, hidden_size, sparsity=sparsity, + use_adaptive=use_adaptive + ) + + # Attention + self.attention = BDHAttentionLayer( + hidden_size, num_heads=num_heads, + use_adaptive=use_adaptive + ) + + # Feed-forward network + ff_dim = hidden_size * 4 + if use_adaptive: + try: + from bdh.continual_learning.adaptive_synapses import AdaptiveLinear + self.ff1 = AdaptiveLinear(hidden_size, ff_dim) + self.ff2 = AdaptiveLinear(ff_dim, hidden_size) + except ImportError: + self.ff1 = nn.Linear(hidden_size, ff_dim) + self.ff2 = nn.Linear(ff_dim, hidden_size) + else: + self.ff1 = nn.Linear(hidden_size, ff_dim) + self.ff2 = nn.Linear(ff_dim, hidden_size) + + # Layer normalization + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + self.norm3 = nn.LayerNorm(hidden_size) + + self.dropout = nn.Dropout(0.1) + + def forward(self, x): + """BDH block forward pass with residuals.""" + # Scale-free layer + residual + out = self.norm1(x) + out = self.scale_free(out) + x = x + self.dropout(out) + + # Attention + residual + out = self.norm2(x) + out = self.attention(out) + x = x + self.dropout(out) + + # Feed-forward + residual + out = self.norm3(x) + out = self.ff1(out) + out = F.relu(out) + out = self.dropout(out) + out = self.ff2(out) + x = x + self.dropout(out) + + return x + + def update_plasticity(self, decay_rate: float = 0.98): + """Update plasticity in all adaptive layers.""" + if self.use_adaptive: + self.scale_free.update_plasticity(decay_rate) + # Update attention layers + if hasattr(self.attention.query, 'update_plasticity'): + self.attention.query.update_plasticity(decay_rate) + self.attention.key.update_plasticity(decay_rate) + self.attention.value.update_plasticity(decay_rate) + self.attention.output.update_plasticity(decay_rate) + # Update feed-forward layers + if hasattr(self.ff1, 'update_plasticity'): + self.ff1.update_plasticity(decay_rate) + self.ff2.update_plasticity(decay_rate) + + +class BDH(nn.Module): + """ + Baby Dragon Hatchling (BDH) Model. + + A biologically-inspired architecture based on scale-free networks + and local interactions with optional continual learning support. 
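+
+    Note: this module replaces the upstream bdh.py removed in this fork; attention
+    here uses ReLU-rectified, causally masked scores normalized by their row sums
+    instead of softmax.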
+ + Architecture: + - Embedding layer + - Multiple BDH blocks (scale-free + attention) + - Output layer + - Optional adaptive layers for continual learning + """ + + def __init__(self, vocab_size: int, hidden_size: int, num_layers: int, + num_heads: int = 8, sparsity: float = 0.9, + use_adaptive_layers: bool = False): + super().__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.use_adaptive_layers = use_adaptive_layers + + # Embedding + self.embedding = nn.Embedding(vocab_size, hidden_size) + + # BDH Blocks + self.blocks = nn.ModuleList([ + BDHBlock(hidden_size, num_heads=num_heads, sparsity=sparsity, + use_adaptive=use_adaptive_layers) + for _ in range(num_layers) + ]) + + # Output layer + if use_adaptive_layers: + try: + from bdh.continual_learning.adaptive_synapses import AdaptiveLinear + self.output = AdaptiveLinear(hidden_size, vocab_size) + except ImportError: + self.output = nn.Linear(hidden_size, vocab_size) + else: + self.output = nn.Linear(hidden_size, vocab_size) + + # Final layer norm + self.final_norm = nn.LayerNorm(hidden_size) + + def forward(self, x): + """ + Forward pass through BDH model. + + Args: + x: Token indices (batch_size, seq_length) + + Returns: + logits: (batch_size, seq_length, vocab_size) + """ + # Embedding + x = self.embedding(x) + + # BDH blocks + for block in self.blocks: + x = block(x) + + # Final norm + x = self.final_norm(x) + + # Output + x = self.output(x) + + return x + + def forward_with_consolidation(self, x): + """ + Forward pass using consolidation gates. + Only relevant if using adaptive layers. + """ + if not self.use_adaptive_layers: + return self.forward(x) + + # Embedding + x = self.embedding(x) + + # BDH blocks with consolidation + for block in self.blocks: + x = block(x) + + # Final norm + x = self.final_norm(x) + + # Output with consolidation (if adaptive) + if hasattr(self.output, 'forward_with_consolidation'): + x = self.output.forward_with_consolidation(x) + else: + x = self.output(x) + + return x + + def update_plasticity(self, decay_rate: float = 0.98): + """Update plasticity across all blocks (consolidation).""" + if self.use_adaptive_layers: + for block in self.blocks: + block.update_plasticity(decay_rate) + + if hasattr(self.output, 'update_plasticity'): + self.output.update_plasticity(decay_rate) + + def get_neurons(self) -> List: + """Get list of all neurons (for compatibility with CL trainer).""" + neurons = [] + for i, block in enumerate(self.blocks): + neurons.append(block) + return neurons + + +def create_bdh_model(config_dict: dict, use_adaptive_layers: bool = False) -> BDH: + """ + Factory function to create BDH model from config. 
+ + Args: + config_dict: Dictionary with keys: + - vocab_size: int + - hidden_size: int + - num_layers: int + - num_heads: int (optional, default=8) + - sparsity: float (optional, default=0.9) + use_adaptive_layers: Whether to use adaptive layers for continual learning + + Returns: + BDH model instance + """ + model = BDH( + vocab_size=config_dict['vocab_size'], + hidden_size=config_dict['hidden_size'], + num_layers=config_dict['num_layers'], + num_heads=config_dict.get('num_heads', 8), + sparsity=config_dict.get('sparsity', 0.9), + use_adaptive_layers=use_adaptive_layers + ) + + total_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Created BDH model with {total_params:,} parameters") + logger.info(f" - Vocab size: {config_dict['vocab_size']}") + logger.info(f" - Hidden size: {config_dict['hidden_size']}") + logger.info(f" - Num layers: {config_dict['num_layers']}") + logger.info(f" - Num heads: {config_dict.get('num_heads', 8)}") + logger.info(f" - Sparsity: {config_dict.get('sparsity', 0.9)}") + logger.info(f" - Adaptive layers: {use_adaptive_layers}") + + return model + diff --git a/requirements.txt b/requirements.txt index 8ad30cc..08be1ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ -torch -numpy -requests +torch>=2.0.0 +torchvision>=0.15.0 +numpy>=1.23.0 +tqdm>=4.65.0 +matplotlib>=3.6.0 +scipy>=1.9.0 \ No newline at end of file diff --git a/res/PERMUTED_MNIST.PNG b/res/PERMUTED_MNIST.PNG new file mode 100644 index 0000000..c175f64 Binary files /dev/null and b/res/PERMUTED_MNIST.PNG differ diff --git a/res/ROTATED_MNIST.PNG b/res/ROTATED_MNIST.PNG new file mode 100644 index 0000000..be37c0f Binary files /dev/null and b/res/ROTATED_MNIST.PNG differ diff --git a/res/SEQUENCE.PNG b/res/SEQUENCE.PNG new file mode 100644 index 0000000..cc2e9e2 Binary files /dev/null and b/res/SEQUENCE.PNG differ diff --git a/res/SPLIT_CIFAR.PNG b/res/SPLIT_CIFAR.PNG new file mode 100644 index 0000000..b0f4802 Binary files /dev/null and b/res/SPLIT_CIFAR.PNG differ diff --git a/simple_benchmark.py b/simple_benchmark.py new file mode 100644 index 0000000..9cabf06 --- /dev/null +++ b/simple_benchmark.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import argparse +import logging +import numpy as np +from datetime import datetime + +from benchmarks_complete import ( + PermutedMNISTGenerator, + SplitCIFARGenerator, + RotatedMNISTGenerator, + ImprovedSequenceGenerator +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def get_benchmark(benchmark_name: str, num_tasks: int): + if benchmark_name == 'permuted_mnist': + gen = PermutedMNISTGenerator(num_tasks=num_tasks, samples_per_task=1000) + input_size = 784 + output_size = 10 + elif benchmark_name == 'split_cifar': + gen = SplitCIFARGenerator(num_tasks=num_tasks, samples_per_class=500) + input_size = 3072 + output_size = 2 + elif benchmark_name == 'rotated_mnist': + gen = RotatedMNISTGenerator(num_tasks=num_tasks, samples_per_task=1000) + input_size = 784 + output_size = 10 + elif benchmark_name == 'sequence': + gen = ImprovedSequenceGenerator(num_tasks=num_tasks, samples_per_task=2000) + input_size = 64 + output_size = 10 + else: + raise ValueError(f"Unknown benchmark: {benchmark_name}") + + return gen.get_all_tasks(batch_size=32), input_size, output_size + +class SimpleModel(nn.Module): + def __init__(self, input_size: int, hidden_size: int, output_size: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_size, 
hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size) + ) + def forward(self, x): + if x.dtype != torch.float32: + x = x.float() + if len(x.shape) > 2: + x = x.view(x.size(0), -1) + return self.net(x) + +def train_and_evaluate(args): + device = torch.device(args.device) + logger.info("=" * 70) + logger.info(f"BENCHMARK: {args.benchmark.upper()}") + logger.info(f"TASKS: {args.num_tasks}, EPOCHS: {args.epochs}") + logger.info("=" * 70) + + tasks, input_size, output_size = get_benchmark(args.benchmark, args.num_tasks) + model = SimpleModel(input_size, args.hidden_size, output_size).to(device) + optimizer = optim.Adam(model.parameters(), lr=0.001) + loss_fn = nn.CrossEntropyLoss() + + accuracies = [] + + for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks): + logger.info(f"\nTASK {task_id}/{args.num_tasks - 1}") + best_acc = 0 + + for epoch in range(args.epochs): + model.train() + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if x.dtype != torch.float32: + x = x.float() + if y.dtype != torch.long: + y = y.long() + logits = model(x) + loss = loss_fn(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + model.eval() + val_correct = 0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + if x.dtype != torch.float32: + x = x.float() + if y.dtype != torch.long: + y = y.long() + logits = model(x) + pred = logits.argmax(dim=1) + val_correct += (pred == y).sum().item() + val_acc = val_correct / len(val_loader.dataset) + + if epoch % 5 == 0 or epoch == args.epochs - 1: + logger.info(f" Epoch {epoch}: Val_Acc={val_acc:.4f}") + best_acc = max(best_acc, val_acc) + + task_accs = [] + for eval_task_id in range(task_id + 1): + _, _, test_loader_eval = tasks[eval_task_id] + model.eval() + test_correct = 0 + with torch.no_grad(): + for x, y in test_loader_eval: + x, y = x.to(device), y.to(device) + if x.dtype != torch.float32: + x = x.float() + if y.dtype != torch.long: + y = y.long() + logits = model(x) + pred = logits.argmax(dim=1) + test_correct += (pred == y).sum().item() + test_acc = test_correct / len(test_loader_eval.dataset) + task_accs.append(test_acc) + + if test_acc > 0.7: + status = "GOOD" + elif test_acc > 0.5: + status = "MODERATE" + else: + status = "POOR" + logger.info(f" {status} Task {eval_task_id}: {test_acc:.4f}") + + accuracies.append(task_accs) + + logger.info("\n" + "=" * 70) + logger.info("RESULTS") + logger.info("=" * 70) + + acc_list = [list(row) for row in accuracies] + final_accs = accuracies[-1] + avg_acc = np.mean(final_accs) + logger.info(f"Average Accuracy: {avg_acc:.4f}") + + if args.num_tasks > 1: + bwt_vals = [] + for i in range(args.num_tasks - 1): + forgetting = accuracies[i][i] - accuracies[-1][i] + bwt_vals.append(forgetting) + bwt = np.mean(bwt_vals) + logger.info(f"Backward Transfer (Forgetting): {bwt:.4f}") + logger.info(f" Task-wise forgetting: {[f'{f:.4f}' for f in bwt_vals]}") + + logger.info("\nAccuracy Matrix:") + for task_id, row in enumerate(acc_list): + logger.info(f" Task {task_id}: {[f'{acc:.4f}' for acc in row]}") + + logger.info("=" * 70) + logger.info("BENCHMARK COMPLETE") + logger.info("=" * 70) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--benchmark', default='permuted_mnist', + choices=['permuted_mnist', 'split_cifar', 'rotated_mnist', 'sequence']) + parser.add_argument('--num_tasks', type=int, default=5) + parser.add_argument('--epochs', type=int, default=10) + 
parser.add_argument('--hidden_size', type=int, default=512)
+    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_args()
+    train_and_evaluate(args)
+
+if __name__ == '__main__':
+    main()
diff --git a/test_continual_learning.py b/test_continual_learning.py
new file mode 100644
index 0000000..1f0c74e
--- /dev/null
+++ b/test_continual_learning.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+from bdh.continual_learning.adaptive_synapses import AdaptiveLinear
+from bdh.continual_learning.consolidation import ElasticWeightConsolidation
+from bdh.continual_learning.importance import FisherEstimator
+
+
+def test_adaptive_synapse():
+    """Test adaptive synaptic layer."""
+    layer = AdaptiveLinear(10, 5)
+    x = torch.randn(2, 10)
+
+    # Forward pass
+    y = layer(x)
+    assert y.shape == (2, 5)
+
+    # Check buffers exist
+    assert hasattr(layer, 'plasticity_state')
+    assert layer.plasticity_state.shape == layer.weight.shape
+
+
+def test_ewc_loss():
+    """Test EWC loss computation."""
+    model = nn.Sequential(
+        AdaptiveLinear(10, 20),
+        nn.ReLU(),
+        AdaptiveLinear(20, 5)
+    )
+
+    ewc = ElasticWeightConsolidation(lambda_ewc=1000.0)
+
+    # Should be 0 for first task
+    loss = ewc(model)
+    assert loss.item() == 0.0
+    assert ewc.is_first_task
+
+    # After consolidation
+    fisher_dict = {
+        'weight': torch.ones(5, 20),
+        'bias': torch.ones(5)
+    }
+    ewc.consolidate_task(model, fisher_dict)
+
+    # Penalty is zero while weights equal the stored values, grows as they drift
+    loss = ewc(model)
+    assert loss.item() >= 0.0
+
+
+if __name__ == "__main__":
+    test_adaptive_synapse()
+    print("Adaptive synapse test passed")
+
+    test_ewc_loss()
+    print("EWC loss test passed")
+
+    print("\nAll tests passed!")
diff --git a/train.py b/train.py
index 6b982d8..956588e 100644
--- a/train.py
+++ b/train.py
@@ -1,126 +1,389 @@
-# Copyright Pathway Technology, Inc.
- -import os -from contextlib import nullcontext - -import bdh -import numpy as np -import requests import torch import torch.nn as nn -import torch.nn.functional as F - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# On a Mac you can also try -# device=torch.device('mps') - -dtype = ( - "bfloat16" - if torch.cuda.is_available() and torch.cuda.is_bf16_supported() - else "float16" -) # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler -ptdtype = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - "float16": torch.float16, -}[dtype] -ctx = ( - torch.amp.autocast(device_type=device.type, dtype=ptdtype) - if "cuda" in device.type - else nullcontext() -) -scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16")) -torch.manual_seed(1337) -torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul -torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn -print(f"Using device: {device} with dtype {dtype}") - - -# Configuration -BDH_CONFIG = bdh.BDHConfig() -BLOCK_SIZE = 512 -BATCH_SIZE = 32 -MAX_ITERS = 3000 -LEARNING_RATE = 1e-3 -WEIGHT_DECAY = 0.1 -LOG_FREQ = 100 - -input_file_path = os.path.join(os.path.dirname(__file__), "input.txt") - - -# Fetch the tiny Shakespeare dataset -def fetch_data(): - if not os.path.exists(input_file_path): - data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" - with open(input_file_path, "w") as f: - f.write(requests.get(data_url).text) - - -def get_batch(split): - # treat the file as bytes - data = np.memmap(input_file_path, dtype=np.uint8, mode="r") - if split == "train": - data = data[: int(0.9 * len(data))] - else: - data = data[int(0.9 * len(data)) :] - ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,)) - x = torch.stack( - [torch.from_numpy((data[i : i + BLOCK_SIZE]).astype(np.int64)) for i in ix] - ) - y = torch.stack( - [ - torch.from_numpy((data[i + 1 : i + 1 + BLOCK_SIZE]).astype(np.int64)) - for i in ix - ] - ) +import torch.optim as optim +from torch.utils.data import DataLoader +import argparse +import os +from pathlib import Path +import json +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + """Parse command line arguments with continual learning options.""" + parser = argparse.ArgumentParser(description="Train BDH with optional continual learning") + + parser.add_argument('--model_name', type=str, default='BDH', + help='Name of the model') + parser.add_argument('--hidden_size', type=int, default=512, + help='Hidden dimension size') + parser.add_argument('--vocab_size', type=int, default=256, + help='Vocabulary size') + parser.add_argument('--num_layers', type=int, default=8, + help='Number of model layers') + parser.add_argument('--batch_size', type=int, default=32, + help='Batch size for training') + parser.add_argument('--epochs', type=int, default=40, + help='Number of training epochs') + parser.add_argument('--learning_rate', type=float, default=0.001, + help='Learning rate') + parser.add_argument('--weight_decay', type=float, default=0.0, + help='Weight decay for optimizer') + parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', + help='Device to use for training') + parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', + help='Directory to save checkpoints') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + 
parser.add_argument('--continual_learning', action='store_true', + help='Enable continual learning mode') + parser.add_argument('--num_tasks', type=int, default=1, + help='Number of tasks for continual learning') + parser.add_argument('--cl_use_consolidation', action='store_true', default=True, + help='Use EWC consolidation') + parser.add_argument('--cl_use_metaplasticity', action='store_true', default=True, + help='Use metaplasticity') + parser.add_argument('--cl_use_adaptive_layers', action='store_true', default=True, + help='Use adaptive layers for continual learning') + parser.add_argument('--cl_lambda_ewc', type=float, default=1000.0, + help='EWC regularization strength (lambda)') + parser.add_argument('--cl_consolidation_rate', type=float, default=0.98, + help='Consolidation rate for metaplasticity') + parser.add_argument('--cl_ewc_weight', type=float, default=0.5, + help='Weight of EWC loss in total loss') + parser.add_argument('--cl_num_fisher_samples', type=int, default=200, + help='Number of samples for Fisher estimation') + + return parser.parse_args() + + +def setup_seed(seed): + """Set random seed for reproducibility.""" + torch.manual_seed(seed) if torch.cuda.is_available(): - # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) - x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to( - device, non_blocking=True - ) - else: - x, y = x.to(device), y.to(device) - return x, y + torch.cuda.manual_seed(seed) + import numpy as np + np.random.seed(seed) -def eval(model): - model.eval() +class BDHModel(nn.Module): + """ + Baby Dragon Hatchling (BDH) Model with Continual Learning Support. + + Adapted from the original BDH architecture to support: + - Adaptive layers for consolidation + - Plasticity state tracking + - Optional metaplasticity + """ + + def __init__(self, vocab_size, hidden_size, num_layers, + use_adaptive_layers=False): + super().__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.use_adaptive_layers = use_adaptive_layers + + # Embedding layer + self.embedding = nn.Embedding(vocab_size, hidden_size) + + # Build layers + self.layers = nn.ModuleList() + for i in range(num_layers): + if use_adaptive_layers: + # Use adaptive linear layers for continual learning + try: + from bdh.continual_learning.adaptive_synapses import AdaptiveLinear + layer = nn.Sequential( + AdaptiveLinear(hidden_size, hidden_size), + nn.ReLU(), + ) + except ImportError: + logger.warning("AdaptiveLinear not found, using standard Linear layers") + layer = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + ) + else: + # Use standard layers + layer = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + ) + self.layers.append(layer) + + # Output layer + if use_adaptive_layers: + try: + from bdh.continual_learning.adaptive_synapses import AdaptiveLinear + self.output_layer = AdaptiveLinear(hidden_size, vocab_size) + except ImportError: + self.output_layer = nn.Linear(hidden_size, vocab_size) + else: + self.output_layer = nn.Linear(hidden_size, vocab_size) + + def forward(self, x): + """Forward pass through BDH model.""" + # x shape: (batch_size, seq_length) + x = self.embedding(x) # (batch_size, seq_length, hidden_size) + + # Apply layers + for layer in self.layers: + x = layer(x) + + # Output layer + x = self.output_layer(x) # (batch_size, seq_length, vocab_size) + return x + + def forward_with_consolidation(self, x): + """Forward pass with consolidation gating 
(if using adaptive layers).""" + if not self.use_adaptive_layers: + return self.forward(x) + + # Use consolidation gates + x = self.embedding(x) + for layer in self.layers: + if hasattr(layer[0], 'forward_with_consolidation'): + x = layer[0].forward_with_consolidation(x) + x = layer[1](x) # ReLU + else: + x = layer(x) + + if hasattr(self.output_layer, 'forward_with_consolidation'): + x = self.output_layer.forward_with_consolidation(x) + else: + x = self.output_layer(x) + + return x -if __name__ == "__main__": - fetch_data() +def create_dummy_tasks(num_tasks=2, batch_size=32, seq_length=64, vocab_size=256): + """ + Create dummy tasks for testing continual learning. + In practice, replace this with real task datasets. + """ + tasks = [] + for task_id in range(num_tasks): + # Create dummy data + train_data = torch.randint(0, vocab_size, (100 * batch_size, seq_length)) + train_labels = torch.randint(0, vocab_size, (100 * batch_size, seq_length)) + val_data = torch.randint(0, vocab_size, (10 * batch_size, seq_length)) + val_labels = torch.randint(0, vocab_size, (10 * batch_size, seq_length)) + test_data = torch.randint(0, vocab_size, (10 * batch_size, seq_length)) + test_labels = torch.randint(0, vocab_size, (10 * batch_size, seq_length)) + + # Create dataloaders + train_loader = DataLoader( + list(zip(train_data, train_labels)), + batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + list(zip(val_data, val_labels)), + batch_size=batch_size, shuffle=False + ) + test_loader = DataLoader( + list(zip(test_data, test_labels)), + batch_size=batch_size, shuffle=False + ) + + tasks.append((train_loader, val_loader, test_loader)) + + return tasks + + +def train_standard(model, train_loader, val_loader, loss_fn, optimizer, + args, device): + """Standard training without continual learning.""" + logger.info("Starting standard training (no continual learning)...") + + best_val_loss = float('inf') + patience = 5 + patience_counter = 0 + + for epoch in range(args.epochs): + # Training + model.train() + total_loss = 0.0 + + for batch_idx, (x, y) in enumerate(train_loader): + x, y = x.to(device), y.to(device) + + # Forward pass + logits = model(x) + loss = loss_fn(logits.view(-1, args.vocab_size), y.view(-1)) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + logits = model(x) + loss = loss_fn(logits.view(-1, args.vocab_size), y.view(-1)) + val_loss += loss.item() + + val_loss /= len(val_loader) + avg_train_loss = total_loss / len(train_loader) + + if epoch % 5 == 0: + logger.info(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, " + f"Val Loss={val_loss:.4f}") + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + + if patience_counter >= patience: + logger.info(f"Early stopping at epoch {epoch}") + break - model = bdh.BDH(BDH_CONFIG).to(device) - model = torch.compile(model) - optimizer = torch.optim.AdamW( - model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY + +def train_continual_learning(model, tasks, loss_fn, args, device): + """Train with continual learning and consolidation.""" + logger.info(f"Starting continual learning training with {len(tasks)} tasks...") + + # Import continual learning modules + from bdh.continual_learning.trainer import ( + train_continual_learning_task, + 
continual_learning_training_loop + ) + from bdh.continual_learning.metrics import ContinualLearningMetrics + + # Train with continual learning + accuracies = continual_learning_training_loop( + model=model, + tasks=tasks, + num_epochs_per_task=args.epochs, + lr=args.learning_rate, + use_consolidation=args.cl_use_consolidation, + use_metaplasticity=args.cl_use_metaplasticity ) + + # Evaluate and print metrics + metrics = ContinualLearningMetrics(accuracies) + metrics.print_summary() + + return metrics - x, y = get_batch("train") - - loss_acc = 0 - loss_steps = 0 - for step in range(MAX_ITERS): - with ctx: - logits, loss = model(x, y) - x, y = get_batch("train") - loss_acc += loss - loss_steps += 1 - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - if step % LOG_FREQ == 0: - print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}") - loss_acc = 0 - loss_steps = 0 - print("Training done, now generating a sample ") - model.eval() - prompt = torch.tensor( - bytearray("To be or ", "utf-8"), dtype=torch.long, device=device - ).unsqueeze(0) - ret = model.generate(prompt, max_new_tokens=100, top_k=3) - ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode( - errors="backslashreplace" + +def main(): + """Main training function.""" + args = parse_arguments() + setup_seed(args.seed) + + device = torch.device(args.device) + logger.info(f"Using device: {device}") + + # Create checkpoint directory + os.makedirs(args.checkpoint_dir, exist_ok=True) + + # Create model + logger.info("Creating BDH model...") + model = BDHModel( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_layers=args.num_layers, + use_adaptive_layers=args.continual_learning and args.cl_use_adaptive_layers ) - print(ret_decoded) + model = model.to(device) + + # Log model info + total_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Model created with {total_params:,} parameters") + logger.info(f" - Hidden size: {args.hidden_size}") + logger.info(f" - Num layers: {args.num_layers}") + logger.info(f" - Vocab size: {args.vocab_size}") + logger.info(f" - Adaptive layers: {args.cl_use_adaptive_layers}") + + # ===== MAIN TRAINING DECISION ===== + if args.continual_learning: + logger.info("=" * 60) + logger.info("CONTINUAL LEARNING MODE") + logger.info("=" * 60) + + # Create tasks + logger.info(f"Creating {args.num_tasks} tasks...") + tasks = create_dummy_tasks( + num_tasks=args.num_tasks, + batch_size=args.batch_size, + vocab_size=args.vocab_size + ) + + # Log continual learning settings + logger.info(f"CL Configuration:") + logger.info(f" - Use consolidation: {args.cl_use_consolidation}") + logger.info(f" - Use metaplasticity: {args.cl_use_metaplasticity}") + logger.info(f" - Lambda EWC: {args.cl_lambda_ewc}") + logger.info(f" - Consolidation rate: {args.cl_consolidation_rate}") + logger.info(f" - EWC weight: {args.cl_ewc_weight}") + + # Train with continual learning + loss_fn = nn.CrossEntropyLoss() + metrics = train_continual_learning( + model, tasks, loss_fn, args, device + ) + + # Save results + results_file = os.path.join(args.checkpoint_dir, 'cl_results.json') + with open(results_file, 'w') as f: + json.dump({ + 'average_accuracy': metrics.average_accuracy(), + 'backward_transfer': metrics.backward_transfer(), + 'num_tasks': args.num_tasks, + }, f, indent=4) + logger.info(f"Results saved to {results_file}") + + else: + logger.info("=" * 60) + logger.info("STANDARD TRAINING MODE (No Continual Learning)") + logger.info("=" 
* 60) + + # Create standard training data (dummy for demo) + logger.info("Creating training data...") + train_data = torch.randint(0, args.vocab_size, (1000, 64)) + train_labels = torch.randint(0, args.vocab_size, (1000, 64)) + val_data = torch.randint(0, args.vocab_size, (100, 64)) + val_labels = torch.randint(0, args.vocab_size, (100, 64)) + + train_loader = DataLoader( + list(zip(train_data, train_labels)), + batch_size=args.batch_size, shuffle=True + ) + val_loader = DataLoader( + list(zip(val_data, val_labels)), + batch_size=args.batch_size, shuffle=False + ) + + # Setup optimizer and loss + optimizer = optim.Adam(model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay) + loss_fn = nn.CrossEntropyLoss() + + # Train + train_standard(model, train_loader, val_loader, loss_fn, + optimizer, args, device) + + # Save final model + model_path = os.path.join(args.checkpoint_dir, f'{args.model_name}_final.pt') + torch.save(model.state_dict(), model_path) + logger.info(f"Model saved to {model_path}") + + logger.info("Training complete!") + + +if __name__ == "__main__": + main()
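
For orientation, the sketch below shows how the pieces added in this patch are intended to fit together in a sequential-training loop: a penalized loss during each task, followed by importance estimation and consolidation. It is illustrative only. `ElasticWeightConsolidation(lambda_ewc=...)`, the penalty call `ewc(model)`, `ewc.consolidate_task(model, fisher_dict)`, and `create_bdh_model` come from the code above; the import path for `create_bdh_model`, the `estimate_diagonal_fisher` helper, and the `fisher_dict` format are assumptions, and the repository's `FisherEstimator` and trainer may differ in detail.

```python
# Illustrative sketch only -- not the repository's trainer. Assumed names are
# marked inline; everything else mirrors the code in the diff above.
import torch
import torch.nn as nn
import torch.optim as optim

from bdh.continual_learning.consolidation import ElasticWeightConsolidation
from bdh_model import create_bdh_model  # module name assumed for the factory shown above


def estimate_diagonal_fisher(model, loader, loss_fn, device, max_batches=50):
    """Stand-in for FisherEstimator: average squared gradient per parameter."""
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    model.eval()
    batches = 0
    for x, y in loader:
        if batches >= max_batches:
            break
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        logits = model(x)
        loss_fn(logits.view(-1, logits.size(-1)), y.view(-1)).backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
        batches += 1
    return {name: f / max(batches, 1) for name, f in fisher.items()}


def train_tasks_with_ewc(tasks, vocab_size=256, device="cpu",
                         epochs=10, lr=1e-3, ewc_weight=0.5):
    """Train tasks sequentially; the EWC penalty discourages changes to
    weights that were important for earlier tasks."""
    model = create_bdh_model(
        {"vocab_size": vocab_size, "hidden_size": 512, "num_layers": 8},
        use_adaptive_layers=True,
    ).to(device)
    ewc = ElasticWeightConsolidation(lambda_ewc=1000.0)
    loss_fn = nn.CrossEntropyLoss()

    for train_loader, _val_loader, _test_loader in tasks:
        optimizer = optim.Adam(model.parameters(), lr=lr)
        model.train()
        for _ in range(epochs):
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                task_loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
                # ewc(model) is zero before the first consolidation, then a
                # quadratic penalty on drift of previously important weights.
                loss = task_loss + ewc_weight * ewc(model)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # After each task: estimate importance and consolidate. The exact
        # fisher_dict format expected by consolidate_task may differ from
        # the per-parameter-name dict produced here.
        fisher = estimate_diagonal_fisher(model, train_loader, loss_fn, device)
        ewc.consolidate_task(model, fisher)

    return model
```

The hyperparameters in the sketch (`lambda_ewc=1000.0`, `ewc_weight=0.5`, `lr=1e-3`) mirror the argparse defaults added to `train.py` above.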