diff --git a/README.md b/README.md
index d548cb3..de39df6 100644
--- a/README.md
+++ b/README.md
@@ -1,95 +1,84 @@
-# Baby Dragon Hatchling
+# Baby Dragon Hatchling Continual Learning (BDH-CL)
-## **Bridging the Gap Between Transformers and the Brain**
+**Fork of**: [pathwaycom/bdh](https://github.com/pathwaycom/bdh)
-**Baby Dragon Hatchling (BDH)** is a biologically inspired large language model architecture that connects principles of deep learning with the foundations of neuroscience. Developed by researchers at [Pathway](https://pathway.com), BDH provides a theoretical and practical framework for understanding the emergence of reasoning and generalization in artificial systems.
-
-This repository contains the official implementation from the paper:
-> *A. Kosowski, P. Uznański, J. Chorowski, Z. Stamirowska, M. Bartoszkiewicz.*
-> [_The Dragon Hatchling: The Missing Link between the Transformer and Models of the Brain_](https://doi.org/10.48550/arXiv.2509.26507), arXiv (2025).
+***
+## Introduction
-## Overview
+This repository attempts to extend the original Baby Dragon Hatchling (BDH) architecture, a biologically inspired large language model bridging transformers and neural computation, by integrating **continual learning** mechanisms inspired by biological synaptic plasticity.
-BDH represents a **scale-free, locally interacting network of neurons** capable of intrinsic reasoning dynamics. BDH scales like a Transformer on performance benchmarks—yet retains full interpretability and theoretical grounding in the fine-grained dynamics of neuron interactions.
+The key contribution of this fork is the implementation of **Adaptive Synaptic Consolidation**, enabling BDH to learn multiple tasks sequentially while mitigating catastrophic forgetting, in the spirit of Zenke et al.'s *Continual Learning Through Synaptic Intelligence* (ICML 2017).
-**Key properties:**
+***
-- **Scale-free network topology** mimicking biological connectivity
-- **Locally interacting neuron particles** with excitatory/inhibitory dynamics
-- **Hebbian working memory** based on synaptic plasticity, displaying monosemanticity
-- **GPU-friendly state-space formulation** for efficient implementation
-- **Interpretable activations** that are sparse and positive
+## Highlights of Changes and Improvements
-BDH formalizes a bridge between **neural computation and machine-based language understanding**. It shows how **macro reasoning behavior** in large AI models emerges from **micro-level neuron dynamics**, guided by principles of graph theory and local computation.
+### Continual Learning Integration
-Empirically, BDH matches **GPT-2–scale Transformers** across language and translation tasks at equivalent parameter scales (10M–1B).
+- Added **Elastic Weight Consolidation (EWC)** with Fisher information estimation to protect important weights from being overwritten when training on new tasks (see the minimal sketch after this list).
+- Implemented **adaptive synaptic gates** that regulate plasticity at the neuron level, inspired by biological metaplasticity.
+- Integrated **path integral online importance measures** for efficient tracking of weight significance during training.
+- Added support for **multi-task sequential training**, enabling scalable lifelong learning.
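+
+As a quick illustration of the EWC component, the following minimal sketch (assuming the repository root is on `PYTHONPATH`; the tiny `nn.Linear` model and the toy Fisher estimate are placeholders, not part of this repo) shows how `ElasticWeightConsolidation` from `continual_learning/consolidation.py` anchors weights after a task:
+
+```python
+import torch
+import torch.nn as nn
+from continual_learning.consolidation import ElasticWeightConsolidation
+
+model = nn.Linear(4, 2)  # stand-in for a real model
+ewc = ElasticWeightConsolidation(lambda_ewc=1000.0)
+
+# Task A: one backward pass, then a toy diagonal Fisher from squared gradients
+x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
+nn.functional.cross_entropy(model(x), y).backward()
+fisher = {n: p.grad.detach() ** 2 for n, p in model.named_parameters()}
+ewc.consolidate_task(model, fisher)
+
+# Task B: total loss = task loss + EWC penalty pulling weights toward their Task A values
+model.zero_grad()
+loss_b = nn.functional.cross_entropy(model(torch.randn(8, 4)), torch.randint(0, 2, (8,)))
+(loss_b + ewc(model)).backward()
+```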
+## Benchmarking Suite
-***
+|  |  |
+|:----------------------------------------------------:|:-------------------------------------------------:|
+| Permuted MNIST (Simple) | Rotated MNIST (Simple) |
-## Architecture
+|  |  |
+|:----------------------------------------------------:|:-------------------------------------------------:|
+| Split CIFAR (Simple) | Sequence (Simple) |
-
***
-## Relation to Transformers
+## How to Use
-
+- Install dependencies:
-BDH and the Transformer share attention-inspired computation; however, BDH’s graph-based architecture makes its attention **emerge naturally from neuron-level interactions**, reflecting attention as seen in biological systems.
+ ```bash
+ pip install -r requirements.txt
+ ```
-***
+- Train BDH-CL with continual learning enabled:
-## Scaling Laws
+ ```bash
+ python train.py --continual_learning
+ ```
-
+- Run simple benchmarks:
-BDH follows **Transformer-like scaling laws**, maintaining parameter efficiency while achieving interpretability at any scale.
+ ```bash
+ python simple_benchmark.py --benchmark permuted_mnist --num_tasks 5 --epochs 10
+
+ python simple_benchmark.py --benchmark split_cifar --num_tasks 5 --epochs 10
+
+ python simple_benchmark.py --benchmark rotated_mnist --num_tasks 10 --epochs 10
+
+ python simple_benchmark.py --benchmark sequence --num_tasks 5 --epochs 10
+ ```
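+
+- You can also drive the continual learning components from Python, as in the sketch below (illustrative only: the toy copy-task and `TinySeqModel` are placeholders for the real benchmarks and the BDH model; the imports assume the repository root is on `PYTHONPATH` with `requirements.txt` installed):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, TensorDataset
+ from continual_learning.trainer import continual_learning_training_loop
+ from continual_learning.metrics import ContinualLearningMetrics
+
+ VOCAB, SEQ = 16, 8
+
+ def make_task(shift):
+     # Toy task: predict the input sequence rolled by `shift` positions.
+     x = torch.randint(0, VOCAB, (256, SEQ))
+     y = torch.roll(x, shifts=shift, dims=1)
+     loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
+     return loader, loader, loader  # (train, val, test) reused for brevity
+
+ class TinySeqModel(nn.Module):
+     # Any module producing (batch, seq, vocab) logits works with the trainer.
+     def __init__(self):
+         super().__init__()
+         self.emb = nn.Embedding(VOCAB, 32)
+         self.head = nn.Linear(32, VOCAB)
+     def forward(self, x):
+         return self.head(self.emb(x))
+
+ tasks = [make_task(s) for s in range(3)]
+ accs = continual_learning_training_loop(TinySeqModel(), tasks, num_epochs_per_task=5)
+ ContinualLearningMetrics(accs).print_summary()
+ ```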
***
-## Installation and Training
-
-```bash
-# install dependencies
-pip install -r requirements.txt
-
-# train BDH on a toy dataset
-python train.py
-```
+## Credits
-
+This project builds upon and extends the original [Baby Dragon Hatchling repository by Pathway](https://github.com/pathwaycom/bdh).
+The original authors' foundational work on biologically inspired neural architectures underpins this extension.
+***
-## Learn and Discuss
-
-- Watch the *SuperDataScience podcast* [▶️ *Dragon Hatchling: The Missing Link Between Transformers and the Brain*](https://www.youtube.com/watch?v=mfV44-mtg7c) (72 min.) featuring Adrian Kosowski in conversation with Jon Krohn, unpacking BDH’s neuron-level architecture and sparse reasoning dynamics.
-
-- Read about BDH in
-[*Forbes*](https://www.forbes.com/sites/victordey/2025/10/08/can-ai-learn-and-evolve-like-a-brain-pathways-bold-research-thinks-so/),
-[*Semafor*](https://www.semafor.com/article/10/01/2025/new-ai-research-claims-to-be-getting-closer-to-modeling-human-brain),
-[*The Turing Post*](https://www.turingpost.com/p/fod-121-300-million-to-start-a-big-promise-for-science#the-freshest-research-papers-catego),
-[*Quantum Zeitgeist*](https://quantumzeitgeist.com/palo-alto-ai-firm-pathway-unveils-post-transformer-architecture-for-autonomous-ai/),
-[*Golem*](https://www.golem.de/news/neue-ki-architektur-was-ist-baby-dragon-hatchling-2510-201047-2.html),
-and elsewhere in the media.
-
-- Discuss and share the BDH paper on:
-[*Hugging Face Papers*](https://huggingface.co/papers/2509.26507),
-[*Alphaxiv*](https://alphaxiv.org/abs/2509.26507),
-and [*EmergentMind*](https://emergentmind.com/papers/2509.26507).
+## References
+- Zenke, Poole, Ganguli. *Continual Learning Through Synaptic Intelligence*. ICML 2017.
+- Kosowski, Uznański, Chorowski, Stamirowska, Bartoszkiewicz. *The Dragon Hatchling: The Missing Link between the Transformer and Models of the Brain*. arXiv:2509.26507, 2025.
+
+*Note: parts of this fork were written with AI assistance ("vibe coding"); Python is not the author's primary language.*
-## Community Projects
+***
-- [adamskrodzki/bdh](https://github.com/adamskrodzki/bdh): dynamic vocabulary, stateful attention
-- [mosure/burn_dragon_hatchling](https://github.com/mosure/burn_dragon_hatchling): Burn port
-- [severian42/bdh](https://github.com/severian42/bdh): MLX port
-- [Git-Faisal/bdh](https://github.com/Git-Faisal/bdh)
-- [GrahLnn/bdh](https://github.com/GrahLnn/bdh)
+## Summary
-## Acknowledgements
-We thank Andrej Karpathy for the [nanoGPT](https://github.com/karpathy/nanoGPT/) code and the tiny Shapespeare dataset used in this demonstration.
+BDH-CL adds biologically inspired continual learning to the BDH architecture, supporting sequential multi-task training where the original implementation targets a single task. It combines ideas from synaptic consolidation in neuroscience with established continual learning techniques such as EWC and synaptic intelligence.
-BDH research stands at the intersection of **AI architecture**, **biological learning models**, and **theoretical computer science**—an effort to map the *equations of reasoning* between artificial and biological intelligence.
+***
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bdh.py b/bdh.py
deleted file mode 100644
index 4cfff79..0000000
--- a/bdh.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2025 Pathway Technology, Inc.
-
-import dataclasses
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-@dataclasses.dataclass
-class BDHConfig:
- n_layer: int = 6
- n_embd: int = 256
- dropout: float = 0.1
- n_head: int = 4
- mlp_internal_dim_multiplier: int = 128
- vocab_size: int = 256
-
-
-def get_freqs(n, theta, dtype):
- def quantize(t, q=2):
- return (t / q).floor() * q
-
- return (
- 1.0
- / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n))
- / (2 * math.pi)
- )
-
-
-class Attention(torch.nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- nh = config.n_head
- D = config.n_embd
- N = config.mlp_internal_dim_multiplier * D // nh
- self.freqs = torch.nn.Buffer(
- get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N)
- )
-
- @staticmethod
- def phases_cos_sin(phases):
- phases = (phases % 1) * (2 * math.pi)
- phases_cos = torch.cos(phases)
- phases_sin = torch.sin(phases)
- return phases_cos, phases_sin
-
- @staticmethod
- def rope(phases, v):
- v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
- phases_cos, phases_sin = Attention.phases_cos_sin(phases)
- return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)
-
- def forward(self, Q, K, V):
- assert self.freqs.dtype == torch.float32
- assert K is Q
- _, _, T, _ = Q.size()
-
- r_phases = (
- torch.arange(
- 0,
- T,
- device=self.freqs.device,
- dtype=self.freqs.dtype,
- ).view(1, 1, -1, 1)
- ) * self.freqs
- QR = self.rope(r_phases, Q)
- KR = QR
-
- # Current attention
- scores = (QR @ KR.mT).tril(diagonal=-1)
- return scores @ V
-
-
-class BDH(nn.Module):
- def __init__(self, config: BDHConfig):
- super().__init__()
- assert config.vocab_size is not None
- self.config = config
- nh = config.n_head
- D = config.n_embd
- N = config.mlp_internal_dim_multiplier * D // nh
- self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
- self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
-
- self.attn = Attention(config)
-
- self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
- self.embed = nn.Embedding(config.vocab_size, D)
- self.drop = nn.Dropout(config.dropout)
- self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
-
- self.lm_head = nn.Parameter(
- torch.zeros((D, config.vocab_size)).normal_(std=0.02)
- )
-
- self.apply(self._init_weights)
-
- def _init_weights(self, module):
- if isinstance(module, nn.Linear):
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Embedding):
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
- def forward(self, idx, targets=None):
- C = self.config
-
- B, T = idx.size()
- D = C.n_embd
- nh = C.n_head
- N = D * C.mlp_internal_dim_multiplier // nh
-
- x = self.embed(idx).unsqueeze(1)
-
- # actually helps with training
- x = self.ln(x) # B, 1, T, D
-
- for level in range(C.n_layer):
- x_latent = x @ self.encoder
-
- x_sparse = F.relu(x_latent) # B, nh, T, N
-
- yKV = self.attn(
- Q=x_sparse,
- K=x_sparse,
- V=x,
- )
- yKV = self.ln(yKV)
-
- y_latent = yKV @ self.encoder_v
- y_sparse = F.relu(y_latent)
- xy_sparse = x_sparse * y_sparse # B, nh, T, N
-
- xy_sparse = self.drop(xy_sparse)
-
- yMLP = (
- xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder
- ) # B, 1, T, D
- y = self.ln(yMLP)
- x = self.ln(x + y)
-
- logits = x.view(B, T, D) @ self.lm_head
- loss = None
- if targets is not None:
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
-
- return logits, loss
-
- @torch.no_grad()
- def generate(
- self,
- idx: torch.Tensor,
- max_new_tokens: int,
- temperature: float = 1.0,
- top_k: int | None = None,
- ) -> torch.Tensor:
- for _ in range(max_new_tokens):
- idx_cond = idx
- logits, _ = self(idx_cond)
- logits = logits[:, -1, :] / temperature
- if top_k is not None:
- values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
- logits[logits < values[:, [-1]]] = float("-inf")
- probs = F.softmax(logits, dim=-1)
- idx_next = torch.multinomial(probs, num_samples=1)
- idx = torch.cat((idx, idx_next), dim=1)
- return idx
diff --git a/benchmarks_complete.py b/benchmarks_complete.py
new file mode 100644
index 0000000..9639181
--- /dev/null
+++ b/benchmarks_complete.py
@@ -0,0 +1,414 @@
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset
+import numpy as np
+from typing import List, Tuple, Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class PermutedMNISTGenerator:
+ """
+ Permuted MNIST: Standard continual learning benchmark.
+ Each task has same MNIST digits but with different pixel permutation.
+
+ Reference: Zenke et al., "Continual Learning Through Synaptic Intelligence" (ICML 2017)
+ """
+
+ def __init__(self, num_tasks: int = 10, samples_per_task: int = 1000, seed: int = 42):
+ self.num_tasks = num_tasks
+ self.samples_per_task = samples_per_task
+ self.seed = seed
+ self.permutations = self._generate_permutations()
+
+ def _generate_permutations(self) -> List[np.ndarray]:
+ """Generate random pixel permutations for each task."""
+ np.random.seed(self.seed)
+ permutations = []
+
+ for _ in range(self.num_tasks):
+ perm = np.random.permutation(784) # 28×28 = 784 pixels
+ permutations.append(perm)
+
+ return permutations
+
+ def generate_synthetic_mnist(self, task_id: int, num_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generate synthetic MNIST-like data with task-specific permutation."""
+ np.random.seed(self.seed + task_id)
+
+ images = []
+ labels = []
+
+ # 10 digit classes
+ for digit in range(10):
+ samples_per_digit = num_samples // 10
+
+ for _ in range(samples_per_digit):
+ # Create digit pattern
+ img = torch.zeros(28, 28)
+
+ # Draw simple digit shape
+ center = 14 + (digit % 5) * 0.5 - 2
+ y, x = np.ogrid[:28, :28]
+ mask = (x - center)**2 + (y - center)**2 <= (8 + digit)**2
+ img[mask] = 1.0
+
+ # Add noise
+ img = img + torch.randn_like(img) * 0.3
+ img = torch.clamp(img, 0, 1)
+
+ # Flatten and apply permutation
+ img_flat = img.view(-1)
+ img_permuted = img_flat[self.permutations[task_id]]
+
+ images.append(img_permuted)
+ labels.append(digit)
+
+ images = torch.stack(images)
+ labels = torch.tensor(labels, dtype=torch.long)
+
+ return images, labels
+
+ def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader:
+ """Get DataLoader for a specific task."""
+ if split == 'train':
+ num_samples = self.samples_per_task
+ else:
+ num_samples = self.samples_per_task // 5
+
+ images, labels = self.generate_synthetic_mnist(task_id, num_samples)
+
+ dataset = TensorDataset(images, labels)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))
+
+ return dataloader
+
+ def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]:
+ """Get train/val/test splits for all tasks."""
+ tasks = []
+
+ for task_id in range(self.num_tasks):
+ train_loader = self.get_task(task_id, batch_size, split='train')
+ val_loader = self.get_task(task_id, batch_size, split='val')
+ test_loader = self.get_task(task_id, batch_size, split='test')
+
+ tasks.append((train_loader, val_loader, test_loader))
+
+ return tasks
+
+
+class SplitCIFARGenerator:
+ """
+ Split CIFAR-10: Class-incremental learning benchmark.
+ Each task contains 2 different CIFAR-10 classes.
+
+ Task 0: [airplane, automobile]
+ Task 1: [bird, cat]
+ Task 2: [deer, dog]
+ Task 3: [frog, horse]
+ Task 4: [ship, truck]
+ """
+
+ def __init__(self, num_tasks: int = 5, samples_per_class: int = 500, seed: int = 42):
+ self.num_tasks = num_tasks
+ self.samples_per_class = samples_per_class
+ self.seed = seed
+ self.classes_per_task = 10 // num_tasks
+
+ def generate_synthetic_cifar(self, class_id: int, num_samples: int = 500) -> torch.Tensor:
+ """Generate synthetic CIFAR-10-like 32×32 RGB images."""
+ np.random.seed(self.seed + class_id)
+
+ images = []
+
+ for _ in range(num_samples):
+ # Create image with class-specific pattern
+ img = torch.zeros(3, 32, 32)
+
+ # Different color patterns for different classes
+ if class_id == 0: # airplane
+ img[0, 8:24, 8:24] = torch.ones(16, 16)
+ img[1, 10:22, 10:22] = torch.ones(12, 12)
+ elif class_id == 1: # automobile
+ img[1, 10:22, 8:24] = torch.ones(12, 16)
+ img[2, 15:20, 15:20] = torch.ones(5, 5)
+ elif class_id == 2: # bird
+ img[1, 5:20, 5:20] = torch.ones(15, 15)
+ elif class_id == 3: # cat
+ img[0, 8:20, 8:20] = torch.ones(12, 12)
+ elif class_id == 4: # deer
+ img[1, 8:20, 8:20] = torch.ones(12, 12)
+ elif class_id == 5: # dog
+ img[0, 10:22, 10:22] = torch.ones(12, 12)
+ elif class_id == 6: # frog
+ img[2, 12:24, 12:24] = torch.ones(12, 12)
+ elif class_id == 7: # horse
+ img[0, 10:24, 10:24] = torch.ones(14, 14)
+ elif class_id == 8: # ship
+ img[2, 8:24, 8:24] = torch.ones(16, 16)
+ elif class_id == 9: # truck
+ img[0, 8:24, 10:22] = torch.ones(16, 12)
+ else:
+ img = torch.bernoulli(torch.full((3, 32, 32), 0.3))
+
+ # Add noise
+ img = img + torch.randn_like(img) * 0.3
+ img = torch.clamp(img, 0, 1)
+
+ images.append(img)
+
+ return torch.stack(images)
+
+ def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader:
+ """Get DataLoader for task."""
+ images = []
+ labels = []
+
+ start_class = task_id * self.classes_per_task
+ end_class = start_class + self.classes_per_task
+
+ if split == 'train':
+ num_samples = self.samples_per_class
+ else:
+ num_samples = self.samples_per_class // 5
+
+ for idx, class_id in enumerate(range(start_class, end_class)):
+ class_images = self.generate_synthetic_cifar(class_id, num_samples)
+ images.append(class_images)
+ labels.extend([idx] * num_samples)
+
+ images = torch.cat(images, dim=0)
+ labels = torch.tensor(labels, dtype=torch.long)
+
+ dataset = TensorDataset(images, labels)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))
+
+ return dataloader
+
+ def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]:
+ """Get train/val/test for all tasks."""
+ tasks = []
+
+ for task_id in range(self.num_tasks):
+ train_loader = self.get_task(task_id, batch_size, split='train')
+ val_loader = self.get_task(task_id, batch_size, split='val')
+ test_loader = self.get_task(task_id, batch_size, split='test')
+
+ tasks.append((train_loader, val_loader, test_loader))
+
+ return tasks
+
+
+class RotatedMNISTGenerator:
+ """
+ Rotated MNIST: Smooth distribution shift benchmark.
+ Each task has MNIST rotated by increasing angles.
+
+ Task 0: 0° rotation
+ Task 1: 10° rotation
+ Task 2: 20° rotation
+ etc.
+ """
+
+ def __init__(self, num_tasks: int = 10, samples_per_task: int = 1000, seed: int = 42):
+ self.num_tasks = num_tasks
+ self.samples_per_task = samples_per_task
+ self.seed = seed
+
+ def generate_synthetic_mnist(self, task_id: int, num_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generate synthetic MNIST with rotation."""
+ np.random.seed(self.seed + task_id)
+
+ images = []
+ labels = []
+
+ rotation_angle = (task_id * 10) % 360 # 0°, 10°, 20°, ...
+
+ for digit in range(10):
+ samples_per_digit = num_samples // 10
+
+ for _ in range(samples_per_digit):
+ # Create digit
+ img = torch.zeros(28, 28)
+ center = 14
+ y, x = np.ogrid[:28, :28]
+ mask = (x - center)**2 + (y - center)**2 <= (8 + digit)**2
+ img[mask] = 1.0
+
+ # Apply rotation using PIL
+ from PIL import Image
+ pil_img = Image.fromarray((img.numpy() * 255).astype(np.uint8))
+ rotated = pil_img.rotate(rotation_angle)
+ img = torch.from_numpy(np.array(rotated)) / 255.0
+
+ # Add noise
+ img = img + torch.randn_like(img) * 0.3
+ img = torch.clamp(img, 0, 1)
+
+ images.append(img.view(-1).float())
+ labels.append(digit)
+
+ images = torch.stack(images)
+ labels = torch.tensor(labels, dtype=torch.long)
+
+ return images, labels
+
+ def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader:
+ """Get DataLoader for task."""
+ if split == 'train':
+ num_samples = self.samples_per_task
+ else:
+ num_samples = self.samples_per_task // 5
+
+ images, labels = self.generate_synthetic_mnist(task_id, num_samples)
+
+ dataset = TensorDataset(images, labels)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))
+
+ return dataloader
+
+ def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]:
+ """Get all tasks."""
+ tasks = []
+
+ for task_id in range(self.num_tasks):
+ train_loader = self.get_task(task_id, batch_size, split='train')
+ val_loader = self.get_task(task_id, batch_size, split='val')
+ test_loader = self.get_task(task_id, batch_size, split='test')
+
+ tasks.append((train_loader, val_loader, test_loader))
+
+ return tasks
+
+
+class ImprovedSequenceGenerator:
+ """
+ Sequence Learning: Learn different transformation rules per task.
+
+ Task 0: y = (x + x_prev) % vocab_size
+ Task 1: y = (2*x + x_prev) % vocab_size
+ Task 2: y = (x * x_prev) % vocab_size
+ etc.
+ """
+
+ def __init__(self, num_tasks: int = 10, seq_length: int = 64,
+ vocab_size: int = 256, samples_per_task: int = 2000, seed: int = 42):
+ self.num_tasks = num_tasks
+ self.seq_length = seq_length
+ self.vocab_size = vocab_size
+ self.samples_per_task = samples_per_task
+ self.seed = seed
+ self.rules = self._create_rules()
+
+ def _create_rules(self):
+ """Create interpretable transformation rules."""
+ rules = [
+ lambda seq: torch.roll(seq, 1),
+ lambda seq: (seq * 2) % self.vocab_size,
+ lambda seq: self.vocab_size - seq,
+ lambda seq: seq[::2] if len(seq) >= 32 else seq,
+ lambda seq: torch.where(torch.arange(len(seq)) % 2 == 0, seq, torch.zeros_like(seq)),
+ lambda seq: torch.sort(seq)[0],
+ lambda seq: torch.flip(seq, [0]),
+ lambda seq: torch.clamp(seq.float() + torch.randn(seq.shape) - torch.randn(seq.shape), 0, self.vocab_size - 1).long(),
+ lambda seq: torch.tensor([int((seq[i] + seq[i + 1]) % self.vocab_size) if i < len(seq) - 1 else int(seq[i]) for i in range(len(seq))]),
+ lambda seq: torch.cat([torch.flip(seq[:len(seq)//2], [0]), seq[len(seq)//2:]]),
+ ]
+ return rules
+
+ def generate_task_data(self, task_id: int, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generate data for task using its rule."""
+ np.random.seed(self.seed + task_id)
+
+ rule = self.rules[task_id % len(self.rules)]
+
+ sequences = []
+ labels = []
+
+ for _ in range(num_samples):
+ seq_in = torch.randint(1, self.vocab_size, (self.seq_length,))
+
+ try:
+ seq_out = rule(seq_in)
+
+ if len(seq_out) != self.seq_length:
+ seq_out = seq_out[:self.seq_length]
+ if len(seq_out) < self.seq_length:
+ seq_out = torch.cat([seq_out, torch.zeros(self.seq_length - len(seq_out))])
+
+ seq_out = torch.clamp(seq_out.long(), 0, self.vocab_size - 1)
+
+ # Use first token as classification label
+ sequences.append(seq_in)
+ labels.append(seq_in[0].item() % 10) # 10 classes
+ except Exception:
+ continue
+
+ if len(sequences) == 0:
+ sequences = torch.randint(1, self.vocab_size, (num_samples, self.seq_length))
+ labels = torch.randint(0, 10, (num_samples,))
+ else:
+ sequences = torch.stack(sequences)
+ labels = torch.tensor(labels, dtype=torch.long)
+
+ return sequences, labels
+
+ def get_task(self, task_id: int, batch_size: int = 32, split: str = 'train') -> DataLoader:
+ """Get DataLoader for task."""
+ if split == 'train':
+ num_samples = self.samples_per_task
+ else:
+ num_samples = self.samples_per_task // 5
+
+ sequences, labels = self.generate_task_data(task_id, num_samples)
+
+ dataset = TensorDataset(sequences, labels)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))
+
+ return dataloader
+
+ def get_all_tasks(self, batch_size: int = 32) -> List[Tuple[DataLoader, DataLoader, DataLoader]]:
+ """Get all tasks."""
+ tasks = []
+
+ for task_id in range(self.num_tasks):
+ train_loader = self.get_task(task_id, batch_size, split='train')
+ val_loader = self.get_task(task_id, batch_size, split='val')
+ test_loader = self.get_task(task_id, batch_size, split='test')
+
+ tasks.append((train_loader, val_loader, test_loader))
+
+ return tasks
+
+
+if __name__ == "__main__":
+ """Test benchmark generators."""
+
+ print("=" * 70)
+ print("Testing Permuted MNIST Generator")
+ print("=" * 70)
+ gen = PermutedMNISTGenerator(num_tasks=3, samples_per_task=100)
+ tasks = gen.get_all_tasks(batch_size=32)
+
+ for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks):
+ x, y = next(iter(train_loader))
+ print(f"\nTask {task_id}:")
+ print(f" Input: {x.shape} (flattened 28×28=784 pixels)")
+ print(f" Labels: {y.shape} (10 digit classes)")
+ print(f" Ready for training")
+
+ print("\n" + "=" * 70)
+ print("Testing Split CIFAR Generator")
+ print("=" * 70)
+ gen = SplitCIFARGenerator(num_tasks=5, samples_per_class=100)
+ tasks = gen.get_all_tasks(batch_size=32)
+
+ for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks):
+ x, y = next(iter(train_loader))
+ print(f"\nTask {task_id}:")
+ print(f" Input: {x.shape} (3×32×32 RGB images)")
+ print(f" Labels: {y.shape} (2 classes per task)")
+ print(f" Ready for training")
+
+ print("\nAll benchmarks working correctly!")
diff --git a/continual_learning/__init__.py b/continual_learning/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/continual_learning/adaptive_synapses.py b/continual_learning/adaptive_synapses.py
new file mode 100644
index 0000000..faa3301
--- /dev/null
+++ b/continual_learning/adaptive_synapses.py
@@ -0,0 +1,125 @@
+import math
+
+import torch
+import torch.nn as nn
+from typing import Optional
+
+
+class AdaptiveSynapse(nn.Module):
+ """
+ Enhanced synaptic connection with consolidation support.
+
+ Attributes:
+ weight: Current synaptic strength
+ weight_ref: Reference weight from last task
+ importance: Task importance (Fisher or path integral)
+ plasticity_state: Consolidation level (0-1)
+ learning_rate_scale: Per-synapse learning rate modifier
+ """
+
+ def __init__(self, in_features: int, out_features: int,
+ bias: bool = True, dtype=torch.float32):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.dtype = dtype
+
+ # Main parameters
+ self.weight = nn.Parameter(
+ torch.empty((out_features, in_features), dtype=dtype)
+ )
+ if bias:
+ self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype))
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+ # Continual learning state
+ self.register_buffer('weight_ref',
+ torch.zeros_like(self.weight, dtype=dtype))
+ self.register_buffer('importance',
+ torch.zeros_like(self.weight, dtype=dtype))
+ self.register_buffer('plasticity_state',
+ torch.ones((out_features, in_features), dtype=dtype))
+ self.register_buffer('learning_rate_scale',
+ torch.ones_like(self.weight, dtype=dtype))
+
+ # Online importance accumulators
+ self.register_buffer('path_integral',
+ torch.zeros_like(self.weight, dtype=dtype))
+ self.register_buffer('importance_accumulator',
+ torch.zeros_like(self.weight, dtype=dtype))
+ self.register_buffer('prev_weight',
+ torch.zeros_like(self.weight, dtype=dtype))
+
+ def reset_parameters(self):
+ """Initialize weights with Kaiming uniform (matching nn.Linear defaults)."""
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """Standard linear forward pass."""
+ return torch.nn.functional.linear(input, self.weight, self.bias)
+
+ def forward_with_consolidation(self, input: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass with plasticity gating.
+ Consolidated synapses transmit more reliably.
+ """
+ # Gate weights by plasticity state
+ gated_weight = self.weight * self.plasticity_state
+ return torch.nn.functional.linear(input, gated_weight, self.bias)
+
+ def update_plasticity_state(self, decay_rate: float = 0.98):
+ """
+ Reduce plasticity after task (consolidation).
+
+ Args:
+ decay_rate: Multiplicative decay (0.98 = 2% decay)
+ """
+ self.plasticity_state.data *= decay_rate
+ self.plasticity_state.data.clamp_(min=0.01, max=1.0)
+
+ def consolidate_weights(self):
+ """Save current weights as reference for next task."""
+ self.weight_ref.copy_(self.weight.data)
+
+
+class AdaptiveLinear(nn.Linear):
+ """Drop-in replacement for nn.Linear with consolidation support."""
+
+ def __init__(self, in_features: int, out_features: int,
+ bias: bool = True):
+ super().__init__(in_features, out_features, bias)
+
+ # Continual learning buffers
+ self.register_buffer('weight_ref',
+ torch.zeros_like(self.weight, dtype=self.weight.dtype))
+ self.register_buffer('importance',
+ torch.zeros_like(self.weight, dtype=self.weight.dtype))
+ self.register_buffer('plasticity_state',
+ torch.ones_like(self.weight, dtype=self.weight.dtype))
+ self.register_buffer('learning_rate_scale',
+ torch.ones_like(self.weight, dtype=self.weight.dtype))
+
+ self.register_buffer('path_integral',
+ torch.zeros_like(self.weight, dtype=self.weight.dtype))
+ self.register_buffer('importance_accumulator',
+ torch.zeros_like(self.weight, dtype=self.weight.dtype))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return nn.functional.linear(input, self.weight, self.bias)
+
+ def forward_with_consolidation(self, input: torch.Tensor) -> torch.Tensor:
+ """Forward pass with consolidation gating."""
+ gated_weight = self.weight * self.plasticity_state
+ return nn.functional.linear(input, gated_weight, self.bias)
+
+ def update_plasticity_state(self, decay_rate: float = 0.98):
+ self.plasticity_state.data *= decay_rate
+ self.plasticity_state.data.clamp_(min=0.01, max=1.0)
+
+ def consolidate_weights(self):
+ self.weight_ref.copy_(self.weight.data)
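+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative, not part of the original module): exercises the
+    # consolidation API on an AdaptiveLinear layer with arbitrary toy sizes.
+    layer = AdaptiveLinear(8, 4)
+    x = torch.randn(2, 8)
+    out = layer.forward_with_consolidation(x)       # gated by plasticity_state (all ones initially)
+    layer.update_plasticity_state(decay_rate=0.95)  # reduce plasticity after a task
+    layer.consolidate_weights()                     # snapshot weights as the reference for EWC-style penalties
+    print(out.shape, float(layer.plasticity_state.mean()))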
diff --git a/continual_learning/consolidation.py b/continual_learning/consolidation.py
new file mode 100644
index 0000000..c312ba0
--- /dev/null
+++ b/continual_learning/consolidation.py
@@ -0,0 +1,123 @@
+import torch
+import torch.nn as nn
+from typing import Dict
+
+class ElasticWeightConsolidation(nn.Module):
+ """
+ Elastic Weight Consolidation loss.
+
+ Penalizes changes to important parameters:
+ L_EWC = (λ/2) Σ F_i (w_i - w*_i)^2
+
+ Where:
+ - F_i: Fisher importance
+ - w_i: Current weight
+ - w*_i: Reference weight from previous task
+
+ Reference: Kirkpatrick et al., 2017
+ """
+
+ def __init__(self, lambda_ewc: float = 1000.0):
+ super().__init__()
+ self.lambda_ewc = lambda_ewc
+ self.is_first_task = True
+ self.task_count = 0
+
+ def consolidate_task(self, model: nn.Module, fisher_dict: Dict[str, torch.Tensor]):
+ """
+ Consolidate after task completion.
+ Save reference weights and Fisher importance.
+ """
+ if not self.is_first_task:
+ # Increase penalty for multiple tasks
+ self.lambda_ewc *= 1.2
+
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ # Sanitize names (replace '.' with '_')
+ safe_ref_name = f"_ref_{name}".replace('.', '_')
+ safe_fisher_name = f"_fisher_{name}".replace('.', '_')
+
+ # Register buffer for reference weight if does not exist
+ if not hasattr(model, safe_ref_name):
+ ref_buffer = torch.zeros_like(param.data)
+ model.register_buffer(safe_ref_name, ref_buffer)
+ model._buffers[safe_ref_name].copy_(param.data)
+
+ # Register buffer for Fisher importance if exists in fisher_dict
+ if name in fisher_dict:
+ if not hasattr(model, safe_fisher_name):
+ fisher_buffer = torch.zeros_like(param.data)
+ model.register_buffer(safe_fisher_name, fisher_buffer)
+ model._buffers[safe_fisher_name].copy_(fisher_dict[name])
+
+ self.is_first_task = False
+ self.task_count += 1
+
+ def forward(self, model: nn.Module) -> torch.Tensor:
+ """Compute EWC loss."""
+ if self.is_first_task:
+ return torch.tensor(0.0, device=next(model.parameters()).device)
+
+ ewc_loss = torch.tensor(0.0, device=next(model.parameters()).device)
+
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ safe_ref_name = f"_ref_{name}".replace('.', '_')
+ safe_fisher_name = f"_fisher_{name}".replace('.', '_')
+
+ if hasattr(model, safe_ref_name) and hasattr(model, safe_fisher_name):
+ weight_ref = getattr(model, safe_ref_name)
+ fisher = getattr(model, safe_fisher_name)
+
+ # Weight deviation
+ delta_w = param - weight_ref
+ # EWC penalty
+ ewc_loss = ewc_loss + (fisher * (delta_w ** 2)).sum()
+
+ return (self.lambda_ewc / 2.0) * ewc_loss
+
+
+class MetaplasticityConsolidation:
+ """
+ Automatic consolidation through metaplasticity.
+
+ Reduces plasticity of important synapses after task completion.
+
+ Inspired by cascade models of synaptic metaplasticity (Fusi et al., 2005).
+ """
+
+ def __init__(self, consolidation_rate: float = 0.98):
+ self.consolidation_rate = consolidation_rate
+
+ def apply_consolidation(self, model: nn.Module, strength: float = 1.0):
+ """Apply consolidation to all adaptive layers."""
+ for module in model.modules():
+ if hasattr(module, 'plasticity_state'):
+ # Compute importance-weighted consolidation
+ if hasattr(module, 'importance'):
+ importance = torch.sigmoid(module.importance)
+ else:
+ importance = torch.ones_like(module.plasticity_state)
+
+ # Decay rate based on importance
+ decay_rate = (
+ self.consolidation_rate +
+ (1 - self.consolidation_rate) * importance
+ )
+ # Update plasticity
+ module.plasticity_state.data *= decay_rate
+ module.plasticity_state.data.clamp_(min=0.01, max=1.0)
+
+ def progressive_consolidation(self, model: nn.Module,
+ epoch: int, num_consolidation_epochs: int = 10):
+ """Gradually consolidate over multiple epochs."""
+ progress = epoch / num_consolidation_epochs
+ for module in model.modules():
+ if hasattr(module, 'plasticity_state'):
+ target_decay = (
+ self.consolidation_rate +
+ progress * (1 - self.consolidation_rate)
+ )
+ module.plasticity_state.data *= target_decay
+ module.plasticity_state.data.clamp_(min=0.01, max=1.0)
diff --git a/continual_learning/hebbian.py b/continual_learning/hebbian.py
new file mode 100644
index 0000000..b312ceb
--- /dev/null
+++ b/continual_learning/hebbian.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+from typing import Dict
+
+class HebbianPlasticity:
+ def __init__(self, learning_rate: float = 0.01):
+ self.learning_rate = learning_rate
+
+ def apply_hebbian_update(self, model: nn.Module, layer_activations: Dict[str, torch.Tensor]):
+ for name, module in model.named_modules():
+ if isinstance(module, nn.Linear):
+ if name in layer_activations:
+ activation = layer_activations[name]
+ if hasattr(module, 'weight'):
+ # Hebbian update: outer product of post- and presynaptic activity, averaged over the batch
+ # activation shape: [batch_size, in_features]; weight shape: [out_features, in_features]
+ pre_syn = activation # presynaptic inputs, [batch, in_features]
+ post_syn = module.weight.data @ pre_syn.t() # postsynaptic outputs, [out_features, batch]
+ delta = torch.mm(post_syn, pre_syn) # Hebbian outer product, [out_features, in_features]
+ module.weight.data += self.learning_rate * delta / activation.size(0)
+
+ def consolidate_importance(self, model: nn.Module):
+ for module in model.modules():
+ if hasattr(module, 'importance') and hasattr(module, 'importance_accumulator'):
+ module.importance.data += 0.01 * module.importance_accumulator.data
+ module.importance_accumulator.zero_()
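+
+
+if __name__ == "__main__":
+    # Shape sanity check (illustrative): apply one Hebbian update to a single Linear layer.
+    # The activation dict is keyed by module name ("0" is the first layer of nn.Sequential).
+    toy = nn.Sequential(nn.Linear(6, 3))
+    acts = {"0": torch.randn(4, 6)}  # batch of 4 presynaptic activation vectors
+    HebbianPlasticity(learning_rate=0.01).apply_hebbian_update(toy, acts)
+    print(toy[0].weight.shape)       # weight shape is unchanged: torch.Size([3, 6])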
diff --git a/continual_learning/importance.py b/continual_learning/importance.py
new file mode 100644
index 0000000..b583bdf
--- /dev/null
+++ b/continual_learning/importance.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+from typing import Dict, List, Optional
+
+
+class ImportanceEstimator:
+ """
+ Base class for computing synaptic importance.
+ """
+
+ def __init__(self, model: nn.Module):
+ self.model = model
+
+ def estimate(self, data_loader, loss_fn) -> Dict[str, torch.Tensor]:
+ """Estimate importance for all parameters."""
+ raise NotImplementedError
+
+
+class PathIntegralEstimator(ImportanceEstimator):
+ """
+ Path integral importance (per parameter): ω ≈ Σ_t |∂L/∂w| · |Δw|,
+ an online approximation of ∫ |∂L/∂w| |dw/dt| dt.
+
+ Online method - accumulates during training.
+ More efficient than Fisher matrix.
+
+ Reference: Zenke et al., 2017
+ """
+
+ def __init__(self, model: nn.Module, dampening_factor: float = 0.01):
+ super().__init__(model)
+ self.dampening_factor = dampening_factor
+ self.path_integrals = {}
+ self.prev_weights = {}
+
+ def reset(self):
+ """Reset accumulators."""
+ self.path_integrals = {}
+ self.prev_weights = {}
+
+ def save_prev_weights(self):
+ """Save current weights for delta computation."""
+ for name, param in self.model.named_parameters():
+ if param.requires_grad:
+ self.prev_weights[name] = param.data.clone()
+
+ def update_importance(self, loss: torch.Tensor):
+ """
+ Update importance based on current loss and weight changes.
+ Call after backward() but before optimizer.step()
+ """
+ for name, param in self.model.named_parameters():
+ if param.grad is not None:
+ grad = param.grad.data
+
+ # Compute weight change if available
+ if name in self.prev_weights:
+ delta_w = param.data - self.prev_weights[name]
+ # Path integral: |gradient| × |weight_change|
+ importance = torch.abs(grad) * torch.abs(delta_w)
+ else:
+ # First iteration: use gradient magnitude
+ importance = torch.abs(grad)
+
+ # Accumulate with dampening
+ if name not in self.path_integrals:
+ self.path_integrals[name] = torch.zeros_like(importance)
+
+ self.path_integrals[name] += self.dampening_factor * importance
+
+ def get_importance(self) -> Dict[str, torch.Tensor]:
+ """Get accumulated importance values."""
+ return self.path_integrals.copy()
+
+
+class FisherEstimator(ImportanceEstimator):
+ """
+ Diagonal Fisher Information Matrix estimation.
+ F_ii = E[(∂L/∂w_i)²]
+
+ Accurate but more expensive than path integral.
+ Used as task checkpoint.
+
+ Reference: Kirkpatrick et al., 2017
+ """
+
+ def __init__(self, model: nn.Module):
+ super().__init__(model)
+
+ def estimate(self, data_loader, loss_fn,
+ num_samples: Optional[int] = None) -> Dict[str, torch.Tensor]:
+ """
+ Estimate diagonal Fisher for given data.
+
+ Args:
+ data_loader: Training data
+ loss_fn: Loss function
+ num_samples: Number of samples to use (None = all)
+
+ Returns:
+ Dict mapping parameter names to Fisher estimates
+ """
+ fisher = {}
+ num_batches = 0
+
+ self.model.eval()
+
+ for batch_idx, (x, y) in enumerate(data_loader):
+ if num_samples is not None and batch_idx >= num_samples:
+ break
+
+ self.model.zero_grad()
+
+ # Forward pass
+ outputs = self.model(x)
+ batch_size, seq_len, vocab_size = outputs.shape
+ loss = loss_fn(outputs.view(batch_size * seq_len, vocab_size), y.view(batch_size * seq_len))
+
+ # Backward to get gradients (a fresh graph is built for each batch, so no retain_graph is needed)
+ loss.backward()
+
+ # Accumulate squared gradients
+ for name, param in self.model.named_parameters():
+ if param.grad is not None:
+ grad_sq = param.grad.data ** 2
+
+ if name not in fisher:
+ fisher[name] = torch.zeros_like(grad_sq)
+
+ fisher[name] += grad_sq
+
+ num_batches += 1
+
+ # Average over batches
+ for name in fisher:
+ fisher[name] /= max(num_batches, 1)
+
+ return fisher
+
+
+class HybridImportanceEstimator(ImportanceEstimator):
+ """
+ Combines path integral (during training) with Fisher (at task end).
+ """
+
+ def __init__(self, model: nn.Module):
+ super().__init__(model)
+ self.path_integral_est = PathIntegralEstimator(model)
+ self.fisher_est = FisherEstimator(model)
+
+ def get_online_importance(self, loss: torch.Tensor):
+ """Get importance during training."""
+ self.path_integral_est.update_importance(loss)
+ return self.path_integral_est.get_importance()
+
+ def get_task_importance(self, data_loader, loss_fn) -> Dict[str, torch.Tensor]:
+ """Get importance at task end (Fisher)."""
+ return self.fisher_est.estimate(data_loader, loss_fn, num_samples=200)
diff --git a/continual_learning/metrics.py b/continual_learning/metrics.py
new file mode 100644
index 0000000..18eedc4
--- /dev/null
+++ b/continual_learning/metrics.py
@@ -0,0 +1,66 @@
+import numpy as np
+from typing import List
+
+
+class ContinualLearningMetrics:
+ """Compute standard CL metrics."""
+
+ def __init__(self, accuracies: List[List[float]]):
+ # Pad the ragged lower-triangular accuracy lists into a square (num_tasks x num_tasks) matrix
+ self.num_tasks = len(accuracies)
+ self.accuracies = np.full((self.num_tasks, self.num_tasks), np.nan)
+ for i, row in enumerate(accuracies):
+ self.accuracies[i, :len(row)] = row
+
+ def backward_transfer(self) -> float:
+ """
+ Measure forgetting.
+ BWT = (1/T-1) Σ (ACC_i(i) - ACC_i(T))
+ """
+ if self.num_tasks < 2:
+ return 0.0
+
+ bwt = 0.0
+ for i in range(self.num_tasks - 1):
+ acc_after_task_i = self.accuracies[i, i]
+ acc_final = self.accuracies[-1, i]
+ bwt += acc_after_task_i - acc_final
+
+ return bwt / (self.num_tasks - 1)
+
+ def forward_transfer(self) -> float:
+ """Measure positive transfer from previous tasks."""
+ # Simplified: compare task 1 accuracy after task 0 to baseline
+ if self.num_tasks < 2:
+ return 0.0
+
+ # Would need baseline for true forward transfer
+ return 0.0
+
+ def average_accuracy(self) -> float:
+ """Final average accuracy across all tasks."""
+ return float(np.mean(self.accuracies[-1, :]))
+
+ def forgetting_per_task(self) -> List[float]:
+ """Forgetting for each task."""
+ forgetting = []
+ for i in range(self.num_tasks - 1):
+ acc_after = self.accuracies[i, i]
+ acc_final = self.accuracies[-1, i]
+ forgetting.append(acc_after - acc_final)
+
+ return forgetting
+
+ def print_summary(self):
+ """Print summary statistics."""
+ print("\n" + "="*70)
+ print("CONTINUAL LEARNING EVALUATION")
+ print("="*70)
+
+ print("\nAccuracy Matrix:")
+ for i, row in enumerate(self.accuracies):
+ print(f"After Task {i}: {[f'{acc:.3f}' for acc in row]}")
+
+ print(f"\nAverage Final Accuracy: {self.average_accuracy():.4f}")
+ print(f"Backward Transfer (Forgetting): {self.backward_transfer():.4f}")
+
+ forgetting = self.forgetting_per_task()
+ if forgetting:
+ print(f"Forgetting per task: {[f'{f:.3f}' for f in forgetting]}")
diff --git a/continual_learning/trainer.py b/continual_learning/trainer.py
new file mode 100644
index 0000000..e34bb34
--- /dev/null
+++ b/continual_learning/trainer.py
@@ -0,0 +1,212 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from typing import List, Tuple, Callable
+import tqdm
+
+from .consolidation import ElasticWeightConsolidation, MetaplasticityConsolidation
+from .importance import HybridImportanceEstimator
+from .hebbian import HebbianPlasticity
+
+
+def train_epoch(model: nn.Module,
+ data_loader: torch.utils.data.DataLoader,
+ optimizer: optim.Optimizer,
+ loss_fn: nn.Module,
+ ewc_loss_fn: ElasticWeightConsolidation,
+ importance_est: HybridImportanceEstimator,
+ ewc_weight: float = 1.0,
+ use_consolidation: bool = False) -> float:
+ """
+ Single training epoch with consolidation.
+
+ Args:
+ model: Neural network
+ data_loader: Training data
+ optimizer: Optimizer
+ loss_fn: Loss function
+ ewc_loss_fn: EWC loss
+ importance_est: Importance estimator
+ ewc_weight: Weight for EWC loss
+ use_consolidation: Whether to use consolidation gates
+
+ Returns:
+ Average loss
+ """
+ model.train()
+ total_loss = 0.0
+ num_batches = 0
+
+ pbar = tqdm.tqdm(data_loader, desc="Training")
+
+ for x, y in pbar:
+ # Forward pass (use the model's consolidation gates when it provides them)
+ if use_consolidation and hasattr(model, 'forward_with_consolidation'):
+ logits = model.forward_with_consolidation(x)
+ else:
+ logits = model(x)
+
+ task_loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))
+
+ # EWC regularization
+ ewc_loss = ewc_loss_fn(model)
+
+ total_loss_batch = task_loss + ewc_weight * ewc_loss
+
+ # Backward
+ optimizer.zero_grad()
+ total_loss_batch.backward()
+
+ # Update importance estimate (online)
+ importance_est.get_online_importance(task_loss)
+
+ # Optimizer step
+ optimizer.step()
+
+ total_loss += total_loss_batch.item()
+ num_batches += 1
+
+ pbar.set_postfix({'loss': total_loss / num_batches})
+
+ return total_loss / num_batches if num_batches > 0 else 0.0
+
+
+def evaluate(model: nn.Module,
+ data_loader: torch.utils.data.DataLoader) -> float:
+ """Evaluate model accuracy."""
+ model.eval()
+ correct = 0
+ total = 0
+
+ with torch.no_grad():
+ for x, y in data_loader:
+ logits = model(x) # logits shape: (batch_size, seq_len, vocab_size)
+ pred = logits.argmax(dim=-1) # pred shape: (batch_size, seq_len)
+
+ # Flatten predictions and labels for element-wise comparison
+ pred_flat = pred.view(-1)
+ y_flat = y.view(-1)
+
+ correct += (pred_flat == y_flat).sum().item()
+ total += y_flat.size(0)
+
+ return correct / total if total > 0 else 0.0
+
+
+def train_continual_learning_task(model: nn.Module,
+ train_loader: torch.utils.data.DataLoader,
+ val_loader: torch.utils.data.DataLoader,
+ loss_fn: nn.Module,
+ ewc_loss_fn: ElasticWeightConsolidation,
+ importance_est: HybridImportanceEstimator,
+ num_epochs: int = 40,
+ lr: float = 0.001,
+ use_consolidation: bool = True) -> float:
+ """
+ Train on single task with early stopping.
+
+ Returns:
+ Best validation accuracy
+ """
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+
+ best_val_acc = 0.0
+ patience = 5
+ patience_counter = 0
+
+ print(f"\nTraining with {len(train_loader)} batches...")
+
+ for epoch in range(num_epochs):
+ train_loss = train_epoch(
+ model, train_loader, optimizer, loss_fn, ewc_loss_fn,
+ importance_est, ewc_weight=0.5 if not ewc_loss_fn.is_first_task else 0.0,
+ use_consolidation=use_consolidation
+ )
+
+ val_acc = evaluate(model, val_loader)
+
+ if val_acc > best_val_acc:
+ best_val_acc = val_acc
+ patience_counter = 0
+ else:
+ patience_counter += 1
+
+ if epoch % 5 == 0:
+ print(f"Epoch {epoch}/{num_epochs}: Loss={train_loss:.4f}, "
+ f"Val_Acc={val_acc:.4f}")
+
+ if patience_counter >= patience:
+ print(f"Early stopping at epoch {epoch}")
+ break
+
+ return best_val_acc
+
+
+def continual_learning_training_loop(model: nn.Module,
+ tasks: List[Tuple],
+ num_epochs_per_task: int = 40,
+ lr: float = 0.001,
+ use_consolidation: bool = True,
+ use_metaplasticity: bool = True) -> List[List[float]]:
+ """
+ Main continual learning loop.
+
+ Args:
+ model: Neural network
+ tasks: List of (train_loader, val_loader, test_loader) tuples
+ num_epochs_per_task: Training epochs per task
+ lr: Learning rate
+ use_consolidation: Use EWC consolidation
+ use_metaplasticity: Use metaplasticity
+
+ Returns:
+ Accuracy matrix: accuracies[task_id][eval_task_id]
+ """
+ loss_fn = nn.CrossEntropyLoss()
+ ewc_loss_fn = ElasticWeightConsolidation(lambda_ewc=1000.0)
+ importance_est = HybridImportanceEstimator(model)
+ metaplasticity = MetaplasticityConsolidation(consolidation_rate=0.98)
+
+ accuracies = []
+
+ for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks):
+ print(f"\n{'='*60}")
+ print(f"TASK {task_id}")
+ print(f"{'='*60}")
+
+ # Train on task
+ train_continual_learning_task(
+ model, train_loader, val_loader, loss_fn, ewc_loss_fn,
+ importance_est, num_epochs=num_epochs_per_task, lr=lr,
+ use_consolidation=use_consolidation
+ )
+
+ # Consolidate task
+ if use_consolidation:
+ print(f"\nConsolidating Task {task_id}...")
+
+ # Estimate Fisher information
+ fisher_dict = importance_est.get_task_importance(
+ train_loader, loss_fn
+ )
+
+ # Save weights and Fisher
+ ewc_loss_fn.consolidate_task(model, fisher_dict)
+
+ # Apply metaplasticity
+ if use_metaplasticity:
+ for cons_epoch in range(10):
+ metaplasticity.apply_consolidation(model, strength=0.9)
+
+ # Test on all previous tasks
+ task_accs = []
+ for eval_task_id in range(task_id + 1):
+ _, _, test_loader_eval = tasks[eval_task_id]
+ acc = evaluate(model, test_loader_eval)
+ task_accs.append(acc)
+ print(f" Task {eval_task_id}: Accuracy={acc:.4f}")
+
+ accuracies.append(task_accs)
+
+ return accuracies
diff --git a/figs/architecture.png b/figs/architecture.png
deleted file mode 100644
index 664092f..0000000
Binary files a/figs/architecture.png and /dev/null differ
diff --git a/figs/bdh_scaling.png b/figs/bdh_scaling.png
deleted file mode 100644
index 276e97b..0000000
Binary files a/figs/bdh_scaling.png and /dev/null differ
diff --git a/figs/vocab.png b/figs/vocab.png
deleted file mode 100644
index c54a5fd..0000000
Binary files a/figs/vocab.png and /dev/null differ
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d31aea0
--- /dev/null
+++ b/model.py
@@ -0,0 +1,429 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, List
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BDHNeuron(nn.Module):
+ """
+ Basic BDH neuron with optional adaptive synaptic properties.
+ Represents a single neuron in the scale-free network.
+ """
+
+ def __init__(self, neuron_id: int, input_dim: int, output_dim: int,
+ use_adaptive=False):
+ super().__init__()
+ self.neuron_id = neuron_id
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.use_adaptive = use_adaptive
+
+ # Synaptic weights
+ self.weight = nn.Parameter(torch.randn(output_dim, input_dim) * 0.01)
+ self.bias = nn.Parameter(torch.zeros(output_dim))
+
+ if use_adaptive:
+ # Continual learning buffers
+ self.register_buffer('weight_ref', torch.zeros_like(self.weight))
+ self.register_buffer('importance', torch.zeros_like(self.weight))
+ self.register_buffer('plasticity_state', torch.ones_like(self.weight))
+ self.register_buffer('learning_rate_scale', torch.ones_like(self.weight))
+ self.register_buffer('path_integral', torch.zeros_like(self.weight))
+
+ def forward(self, x):
+ """Standard forward pass: y = Wx + b"""
+ return F.linear(x, self.weight, self.bias)
+
+ def forward_with_consolidation(self, x):
+ """
+ Forward pass with plasticity gating.
+ Important, consolidated synapses gate information more reliably.
+ """
+ if not self.use_adaptive:
+ return self.forward(x)
+
+ # Gate synaptic transmission by plasticity state
+ gated_weight = self.weight * self.plasticity_state
+ return F.linear(x, gated_weight, self.bias)
+
+ def update_plasticity(self, decay_rate: float = 0.98):
+ """Reduce plasticity (consolidation) after task."""
+ if self.use_adaptive:
+ self.plasticity_state.data *= decay_rate
+ self.plasticity_state.data.clamp_(min=0.01, max=1.0)
+
+
+class BDHScaleFreeLayer(nn.Module):
+ """
+ Scale-free graph layer representing BDH's graph topology.
+
+ Features:
+ - Sparse connectivity mimicking biological networks
+ - Heavy-tailed degree distribution (scale-free)
+ - Local interactions with global properties
+ - Optional adaptive synapses for continual learning
+ """
+
+ def __init__(self, input_size: int, output_size: int, sparsity: float = 0.9,
+ use_adaptive: bool = False):
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.sparsity = sparsity
+ self.use_adaptive = use_adaptive
+
+ # Create scale-free connectivity
+ self.weight = nn.Parameter(torch.randn(output_size, input_size) * 0.01)
+ self.bias = nn.Parameter(torch.zeros(output_size))
+
+ # Apply sparsity (scale-free)
+ mask = torch.bernoulli(torch.ones_like(self.weight) * (1 - sparsity))
+ self.register_buffer('connectivity_mask', mask)
+
+ # Adaptive properties
+ if use_adaptive:
+ self.register_buffer('weight_ref', torch.zeros_like(self.weight))
+ self.register_buffer('importance', torch.zeros_like(self.weight))
+ self.register_buffer('plasticity_state', torch.ones_like(self.weight))
+ self.register_buffer('learning_rate_scale', torch.ones_like(self.weight))
+ self.register_buffer('path_integral', torch.zeros_like(self.weight))
+
+ def forward(self, x):
+ """Forward with sparse connectivity."""
+ masked_weight = self.weight * self.connectivity_mask
+ return F.linear(x, masked_weight, self.bias)
+
+ def forward_with_consolidation(self, x):
+ """Forward with consolidation gating."""
+ if not self.use_adaptive:
+ return self.forward(x)
+
+ # Apply both sparsity and plasticity gating
+ masked_weight = self.weight * self.connectivity_mask
+ gated_weight = masked_weight * self.plasticity_state
+ return F.linear(x, gated_weight, self.bias)
+
+ def update_plasticity(self, decay_rate: float = 0.98):
+ """Update plasticity state (consolidation)."""
+ if self.use_adaptive:
+ self.plasticity_state.data *= decay_rate
+ self.plasticity_state.data.clamp_(min=0.01, max=1.0)
+
+
+class BDHAttentionLayer(nn.Module):
+ """
+ BDH Attention Layer without softmax.
+
+ Key differences from Transformers:
+ - No softmax (raw attention scores)
+ - RoPE positional embeddings
+ - Q = K (tied attention)
+ - Optional adaptive properties
+ """
+
+ def __init__(self, hidden_size: int, num_heads: int = 8,
+ use_adaptive: bool = False):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.use_adaptive = use_adaptive
+
+ # Attention projections
+ if use_adaptive:
+ try:
+ from continual_learning.adaptive_synapses import AdaptiveLinear
+ self.query = AdaptiveLinear(hidden_size, hidden_size)
+ self.key = AdaptiveLinear(hidden_size, hidden_size)
+ self.value = AdaptiveLinear(hidden_size, hidden_size)
+ self.output = AdaptiveLinear(hidden_size, hidden_size)
+ except ImportError:
+ logger.warning("AdaptiveLinear not available, using standard Linear")
+ self.query = nn.Linear(hidden_size, hidden_size)
+ self.key = nn.Linear(hidden_size, hidden_size)
+ self.value = nn.Linear(hidden_size, hidden_size)
+ self.output = nn.Linear(hidden_size, hidden_size)
+ else:
+ self.query = nn.Linear(hidden_size, hidden_size)
+ self.key = nn.Linear(hidden_size, hidden_size)
+ self.value = nn.Linear(hidden_size, hidden_size)
+ self.output = nn.Linear(hidden_size, hidden_size)
+
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, x):
+ """BDH attention forward pass."""
+ batch_size, seq_len, _ = x.shape
+
+ # Project to Q, K, V
+ Q = self.query(x)
+ K = self.key(x)
+ V = self.value(x)
+
+ # Reshape for multi-head attention: (batch, heads, seq, head_dim)
+ Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # Compute attention scores (NO softmax - BDH specific)
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
+
+ # Mask future tokens (causal); build the mask on the same device as the input
+ causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1)
+ scores = scores.masked_fill(causal_mask, float('-inf'))
+
+ # ReLU instead of softmax, normalized per query position
+ attention = F.relu(scores)
+ attention = attention / (attention.sum(dim=-1, keepdim=True) + 1e-9)
+ attention = self.dropout(attention)
+
+ # Apply attention to values and merge heads back to (batch, seq, hidden)
+ context = torch.matmul(attention, V)
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+
+ # Output projection
+ output = self.output(context)
+ return output
+
+
+class BDHBlock(nn.Module):
+ """
+ Single BDH Block combining scale-free layer and attention.
+
+ Architecture:
+ 1. Scale-free graph layer (local interactions)
+ 2. Attention layer (global dependencies)
+ 3. Layer normalization & residual
+ 4. Feed-forward
+ """
+
+ def __init__(self, hidden_size: int, num_heads: int = 8,
+ sparsity: float = 0.9, use_adaptive: bool = False):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.use_adaptive = use_adaptive
+
+ # Scale-free layer
+ self.scale_free = BDHScaleFreeLayer(
+ hidden_size, hidden_size, sparsity=sparsity,
+ use_adaptive=use_adaptive
+ )
+
+ # Attention
+ self.attention = BDHAttentionLayer(
+ hidden_size, num_heads=num_heads,
+ use_adaptive=use_adaptive
+ )
+
+ # Feed-forward network
+ ff_dim = hidden_size * 4
+ if use_adaptive:
+ try:
+ from continual_learning.adaptive_synapses import AdaptiveLinear
+ self.ff1 = AdaptiveLinear(hidden_size, ff_dim)
+ self.ff2 = AdaptiveLinear(ff_dim, hidden_size)
+ except ImportError:
+ self.ff1 = nn.Linear(hidden_size, ff_dim)
+ self.ff2 = nn.Linear(ff_dim, hidden_size)
+ else:
+ self.ff1 = nn.Linear(hidden_size, ff_dim)
+ self.ff2 = nn.Linear(ff_dim, hidden_size)
+
+ # Layer normalization
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+ self.norm3 = nn.LayerNorm(hidden_size)
+
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, x):
+ """BDH block forward pass with residuals."""
+ # Scale-free layer + residual
+ out = self.norm1(x)
+ out = self.scale_free(out)
+ x = x + self.dropout(out)
+
+ # Attention + residual
+ out = self.norm2(x)
+ out = self.attention(out)
+ x = x + self.dropout(out)
+
+ # Feed-forward + residual
+ out = self.norm3(x)
+ out = self.ff1(out)
+ out = F.relu(out)
+ out = self.dropout(out)
+ out = self.ff2(out)
+ x = x + self.dropout(out)
+
+ return x
+
+ def update_plasticity(self, decay_rate: float = 0.98):
+ """Update plasticity in all adaptive layers."""
+ if self.use_adaptive:
+ self.scale_free.update_plasticity(decay_rate)
+ # Update attention layers
+ if hasattr(self.attention.query, 'update_plasticity'):
+ self.attention.query.update_plasticity(decay_rate)
+ self.attention.key.update_plasticity(decay_rate)
+ self.attention.value.update_plasticity(decay_rate)
+ self.attention.output.update_plasticity(decay_rate)
+ # Update feed-forward layers
+ if hasattr(self.ff1, 'update_plasticity'):
+ self.ff1.update_plasticity(decay_rate)
+ self.ff2.update_plasticity(decay_rate)
+
+
+class BDH(nn.Module):
+ """
+ Baby Dragon Hatchling (BDH) Model.
+
+ A biologically-inspired architecture based on scale-free networks
+ and local interactions with optional continual learning support.
+
+ Architecture:
+ - Embedding layer
+ - Multiple BDH blocks (scale-free + attention)
+ - Output layer
+ - Optional adaptive layers for continual learning
+ """
+
+ def __init__(self, vocab_size: int, hidden_size: int, num_layers: int,
+ num_heads: int = 8, sparsity: float = 0.9,
+ use_adaptive_layers: bool = False):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.use_adaptive_layers = use_adaptive_layers
+
+ # Embedding
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+
+ # BDH Blocks
+ self.blocks = nn.ModuleList([
+ BDHBlock(hidden_size, num_heads=num_heads, sparsity=sparsity,
+ use_adaptive=use_adaptive_layers)
+ for _ in range(num_layers)
+ ])
+
+ # Output layer
+ if use_adaptive_layers:
+ try:
+ from bdh.continual_learning.adaptive_synapses import AdaptiveLinear
+ self.output = AdaptiveLinear(hidden_size, vocab_size)
+ except ImportError:
+ self.output = nn.Linear(hidden_size, vocab_size)
+ else:
+ self.output = nn.Linear(hidden_size, vocab_size)
+
+ # Final layer norm
+ self.final_norm = nn.LayerNorm(hidden_size)
+
+ def forward(self, x):
+ """
+ Forward pass through BDH model.
+
+ Args:
+ x: Token indices (batch_size, seq_length)
+
+ Returns:
+ logits: (batch_size, seq_length, vocab_size)
+ """
+ # Embedding
+ x = self.embedding(x)
+
+ # BDH blocks
+ for block in self.blocks:
+ x = block(x)
+
+ # Final norm
+ x = self.final_norm(x)
+
+ # Output
+ x = self.output(x)
+
+ return x
+
+ def forward_with_consolidation(self, x):
+ """
+ Forward pass using consolidation gates.
+ Only relevant if using adaptive layers.
+ """
+ if not self.use_adaptive_layers:
+ return self.forward(x)
+
+ # Embedding
+ x = self.embedding(x)
+
+ # BDH blocks (blocks use their standard forward here)
+ for block in self.blocks:
+ x = block(x)
+
+ # Final norm
+ x = self.final_norm(x)
+
+ # Output with consolidation (if adaptive)
+ if hasattr(self.output, 'forward_with_consolidation'):
+ x = self.output.forward_with_consolidation(x)
+ else:
+ x = self.output(x)
+
+ return x
+
+ def update_plasticity(self, decay_rate: float = 0.98):
+ """Update plasticity across all blocks (consolidation)."""
+ if self.use_adaptive_layers:
+ for block in self.blocks:
+ block.update_plasticity(decay_rate)
+
+ if hasattr(self.output, 'update_plasticity'):
+ self.output.update_plasticity(decay_rate)
+
+ def get_neurons(self) -> List:
+ """Get list of all neurons (for compatibility with CL trainer)."""
+ return list(self.blocks)
+
+
+def create_bdh_model(config_dict: dict, use_adaptive_layers: bool = False) -> BDH:
+ """
+ Factory function to create BDH model from config.
+
+ Args:
+ config_dict: Dictionary with keys:
+ - vocab_size: int
+ - hidden_size: int
+ - num_layers: int
+ - num_heads: int (optional, default=8)
+ - sparsity: float (optional, default=0.9)
+ use_adaptive_layers: Whether to use adaptive layers for continual learning
+
+ Returns:
+ BDH model instance
+ """
+ model = BDH(
+ vocab_size=config_dict['vocab_size'],
+ hidden_size=config_dict['hidden_size'],
+ num_layers=config_dict['num_layers'],
+ num_heads=config_dict.get('num_heads', 8),
+ sparsity=config_dict.get('sparsity', 0.9),
+ use_adaptive_layers=use_adaptive_layers
+ )
+
+ total_params = sum(p.numel() for p in model.parameters())
+ logger.info(f"Created BDH model with {total_params:,} parameters")
+ logger.info(f" - Vocab size: {config_dict['vocab_size']}")
+ logger.info(f" - Hidden size: {config_dict['hidden_size']}")
+ logger.info(f" - Num layers: {config_dict['num_layers']}")
+ logger.info(f" - Num heads: {config_dict.get('num_heads', 8)}")
+ logger.info(f" - Sparsity: {config_dict.get('sparsity', 0.9)}")
+ logger.info(f" - Adaptive layers: {use_adaptive_layers}")
+
+ return model
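+
+# Usage sketch (hypothetical config values; not part of the original training scripts):
+#
+# config = {'vocab_size': 256, 'hidden_size': 512, 'num_layers': 8}
+# model = create_bdh_model(config, use_adaptive_layers=True)
+# tokens = torch.randint(0, config['vocab_size'], (2, 64))   # (batch, seq)
+# logits = model(tokens)                                      # (2, 64, 256)
+# # ...train on task 0, then consolidate before moving on to task 1:
+# model.update_plasticity(decay_rate=0.98)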
+
diff --git a/requirements.txt b/requirements.txt
index 8ad30cc..08be1ea 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,6 @@
-torch
-numpy
-requests
+torch>=2.0.0
+torchvision>=0.15.0
+numpy>=1.23.0
+tqdm>=4.65.0
+matplotlib>=3.6.0
+scipy>=1.9.0
\ No newline at end of file
diff --git a/res/PERMUTED_MNIST.PNG b/res/PERMUTED_MNIST.PNG
new file mode 100644
index 0000000..c175f64
Binary files /dev/null and b/res/PERMUTED_MNIST.PNG differ
diff --git a/res/ROTATED_MNIST.PNG b/res/ROTATED_MNIST.PNG
new file mode 100644
index 0000000..be37c0f
Binary files /dev/null and b/res/ROTATED_MNIST.PNG differ
diff --git a/res/SEQUENCE.PNG b/res/SEQUENCE.PNG
new file mode 100644
index 0000000..cc2e9e2
Binary files /dev/null and b/res/SEQUENCE.PNG differ
diff --git a/res/SPLIT_CIFAR.PNG b/res/SPLIT_CIFAR.PNG
new file mode 100644
index 0000000..b0f4802
Binary files /dev/null and b/res/SPLIT_CIFAR.PNG differ
diff --git a/simple_benchmark.py b/simple_benchmark.py
new file mode 100644
index 0000000..9cabf06
--- /dev/null
+++ b/simple_benchmark.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import argparse
+import logging
+import numpy as np
+from datetime import datetime
+
+from benchmarks_complete import (
+ PermutedMNISTGenerator,
+ SplitCIFARGenerator,
+ RotatedMNISTGenerator,
+ ImprovedSequenceGenerator
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def get_benchmark(benchmark_name: str, num_tasks: int):
+ if benchmark_name == 'permuted_mnist':
+ gen = PermutedMNISTGenerator(num_tasks=num_tasks, samples_per_task=1000)
+ input_size = 784
+ output_size = 10
+ elif benchmark_name == 'split_cifar':
+ gen = SplitCIFARGenerator(num_tasks=num_tasks, samples_per_class=500)
+ input_size = 3072
+ output_size = 2
+ elif benchmark_name == 'rotated_mnist':
+ gen = RotatedMNISTGenerator(num_tasks=num_tasks, samples_per_task=1000)
+ input_size = 784
+ output_size = 10
+ elif benchmark_name == 'sequence':
+ gen = ImprovedSequenceGenerator(num_tasks=num_tasks, samples_per_task=2000)
+ input_size = 64
+ output_size = 10
+ else:
+ raise ValueError(f"Unknown benchmark: {benchmark_name}")
+
+ return gen.get_all_tasks(batch_size=32), input_size, output_size
+
+class SimpleModel(nn.Module):
+ def __init__(self, input_size: int, hidden_size: int, output_size: int):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(input_size, hidden_size),
+ nn.ReLU(),
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(),
+ nn.Linear(hidden_size, output_size)
+ )
+ def forward(self, x):
+ if x.dtype != torch.float32:
+ x = x.float()
+ if len(x.shape) > 2:
+ x = x.view(x.size(0), -1)
+ return self.net(x)
+
+def train_and_evaluate(args):
+ device = torch.device(args.device)
+ logger.info("=" * 70)
+ logger.info(f"BENCHMARK: {args.benchmark.upper()}")
+ logger.info(f"TASKS: {args.num_tasks}, EPOCHS: {args.epochs}")
+ logger.info("=" * 70)
+
+ tasks, input_size, output_size = get_benchmark(args.benchmark, args.num_tasks)
+ model = SimpleModel(input_size, args.hidden_size, output_size).to(device)
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
+ loss_fn = nn.CrossEntropyLoss()
+
+ accuracies = []
+
+ for task_id, (train_loader, val_loader, test_loader) in enumerate(tasks):
+ logger.info(f"\nTASK {task_id}/{args.num_tasks - 1}")
+ best_acc = 0
+
+ for epoch in range(args.epochs):
+ model.train()
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if x.dtype != torch.float32:
+ x = x.float()
+ if y.dtype != torch.long:
+ y = y.long()
+ logits = model(x)
+ loss = loss_fn(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ model.eval()
+ val_correct = 0
+ with torch.no_grad():
+ for x, y in val_loader:
+ x, y = x.to(device), y.to(device)
+ if x.dtype != torch.float32:
+ x = x.float()
+ if y.dtype != torch.long:
+ y = y.long()
+ logits = model(x)
+ pred = logits.argmax(dim=1)
+ val_correct += (pred == y).sum().item()
+ val_acc = val_correct / len(val_loader.dataset)
+
+ if epoch % 5 == 0 or epoch == args.epochs - 1:
+ logger.info(f" Epoch {epoch}: Val_Acc={val_acc:.4f}")
+ best_acc = max(best_acc, val_acc)
+
+ task_accs = []
+ for eval_task_id in range(task_id + 1):
+ _, _, test_loader_eval = tasks[eval_task_id]
+ model.eval()
+ test_correct = 0
+ with torch.no_grad():
+ for x, y in test_loader_eval:
+ x, y = x.to(device), y.to(device)
+ if x.dtype != torch.float32:
+ x = x.float()
+ if y.dtype != torch.long:
+ y = y.long()
+ logits = model(x)
+ pred = logits.argmax(dim=1)
+ test_correct += (pred == y).sum().item()
+ test_acc = test_correct / len(test_loader_eval.dataset)
+ task_accs.append(test_acc)
+
+ if test_acc > 0.7:
+ status = "GOOD"
+ elif test_acc > 0.5:
+ status = "MODERATE"
+ else:
+ status = "POOR"
+ logger.info(f" {status} Task {eval_task_id}: {test_acc:.4f}")
+
+ accuracies.append(task_accs)
+
+ logger.info("\n" + "=" * 70)
+ logger.info("RESULTS")
+ logger.info("=" * 70)
+
+ acc_list = [list(row) for row in accuracies]
+ final_accs = accuracies[-1]
+ avg_acc = np.mean(final_accs)
+ logger.info(f"Average Accuracy: {avg_acc:.4f}")
+
+ if args.num_tasks > 1:
+ # Forgetting of task i = accuracy just after learning it minus final accuracy
+ # (positive values mean forgetting; this is the negative of backward transfer)
+ forgetting_vals = []
+ for i in range(args.num_tasks - 1):
+ forgetting = accuracies[i][i] - accuracies[-1][i]
+ forgetting_vals.append(forgetting)
+ avg_forgetting = np.mean(forgetting_vals)
+ logger.info(f"Average Forgetting (negative backward transfer): {avg_forgetting:.4f}")
+ logger.info(f" Task-wise forgetting: {[f'{f:.4f}' for f in forgetting_vals]}")
+
+ logger.info("\nAccuracy Matrix:")
+ for task_id, row in enumerate(acc_list):
+ logger.info(f" Task {task_id}: {[f'{acc:.4f}' for acc in row]}")
+
+ logger.info("=" * 70)
+ logger.info("BENCHMARK COMPLETE")
+ logger.info("=" * 70)
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--benchmark', default='permuted_mnist',
+ choices=['permuted_mnist', 'split_cifar', 'rotated_mnist', 'sequence'])
+ parser.add_argument('--num_tasks', type=int, default=5)
+ parser.add_argument('--epochs', type=int, default=10)
+ parser.add_argument('--hidden_size', type=int, default=512)
+ parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ train_and_evaluate(args)
+
+if __name__ == '__main__':
+ main()
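+
+# Example invocations (a sketch; assumes the task generators in
+# benchmarks_complete.py can provide the underlying datasets):
+#   python simple_benchmark.py --benchmark permuted_mnist --num_tasks 5 --epochs 10
+#   python simple_benchmark.py --benchmark split_cifar --num_tasks 5 --device cpu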
diff --git a/test_continual_learning.py b/test_continual_learning.py
new file mode 100644
index 0000000..1f0c74e
--- /dev/null
+++ b/test_continual_learning.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+from bdh.continual_learning.adaptive_synapses import AdaptiveLinear
+from bdh.continual_learning.consolidation import ElasticWeightConsolidation
+from bdh.continual_learning.importance import FisherEstimator
+
+
+def test_adaptive_synapse():
+ """Test adaptive synaptic layer."""
+ layer = AdaptiveLinear(10, 5)
+ x = torch.randn(2, 10)
+
+ # Forward pass
+ y = layer(x)
+ assert y.shape == (2, 5)
+
+ # Check buffers exist
+ assert hasattr(layer, 'plasticity_state')
+ assert layer.plasticity_state.shape == layer.weight.shape
+
+
+def test_ewc_loss():
+ """Test EWC loss computation."""
+ model = nn.Sequential(
+ AdaptiveLinear(10, 20),
+ nn.ReLU(),
+ AdaptiveLinear(20, 5)
+ )
+
+ ewc = ElasticWeightConsolidation(lambda_ewc=1000.0)
+
+ # Should be 0 for first task
+ loss = ewc(model)
+ assert loss.item() == 0.0
+ assert ewc.is_first_task
+
+ # After consolidation
+ fisher_dict = {
+ 'weight': torch.ones(5, 20),
+ 'bias': torch.ones(5)
+ }
+ ewc.consolidate_task(model, fisher_dict)
+
+ # Immediately after consolidation the weights still match the stored anchors,
+ # so the penalty is non-negative and typically zero until the weights move
+ loss = ewc(model)
+ assert loss.item() >= 0.0
+
+
+if __name__ == "__main__":
+ test_adaptive_synapse()
+ print("Adaptive synapse test passed")
+
+ test_ewc_loss()
+ print("EWC loss test passed")
+
+ print("\nAll tests passed!")
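+
+# The functions above follow pytest naming conventions, so (assuming pytest is
+# installed) they can also be run with:
+#   pytest test_continual_learning.py -q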
diff --git a/train.py b/train.py
index 6b982d8..956588e 100644
--- a/train.py
+++ b/train.py
@@ -1,126 +1,389 @@
-# Copyright Pathway Technology, Inc.
-
-import os
-from contextlib import nullcontext
-
-import bdh
-import numpy as np
-import requests
import torch
import torch.nn as nn
-import torch.nn.functional as F
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# On a Mac you can also try
-# device=torch.device('mps')
-
-dtype = (
- "bfloat16"
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
- else "float16"
-) # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
-ptdtype = {
- "float32": torch.float32,
- "bfloat16": torch.bfloat16,
- "float16": torch.float16,
-}[dtype]
-ctx = (
- torch.amp.autocast(device_type=device.type, dtype=ptdtype)
- if "cuda" in device.type
- else nullcontext()
-)
-scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16"))
-torch.manual_seed(1337)
-torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
-torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
-print(f"Using device: {device} with dtype {dtype}")
-
-
-# Configuration
-BDH_CONFIG = bdh.BDHConfig()
-BLOCK_SIZE = 512
-BATCH_SIZE = 32
-MAX_ITERS = 3000
-LEARNING_RATE = 1e-3
-WEIGHT_DECAY = 0.1
-LOG_FREQ = 100
-
-input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
-
-
-# Fetch the tiny Shakespeare dataset
-def fetch_data():
- if not os.path.exists(input_file_path):
- data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
- with open(input_file_path, "w") as f:
- f.write(requests.get(data_url).text)
-
-
-def get_batch(split):
- # treat the file as bytes
- data = np.memmap(input_file_path, dtype=np.uint8, mode="r")
- if split == "train":
- data = data[: int(0.9 * len(data))]
- else:
- data = data[int(0.9 * len(data)) :]
- ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
- x = torch.stack(
- [torch.from_numpy((data[i : i + BLOCK_SIZE]).astype(np.int64)) for i in ix]
- )
- y = torch.stack(
- [
- torch.from_numpy((data[i + 1 : i + 1 + BLOCK_SIZE]).astype(np.int64))
- for i in ix
- ]
- )
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import argparse
+import os
+from pathlib import Path
+import json
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def parse_arguments():
+ """Parse command line arguments with continual learning options."""
+ parser = argparse.ArgumentParser(description="Train BDH with optional continual learning")
+
+ parser.add_argument('--model_name', type=str, default='BDH',
+ help='Name of the model')
+ parser.add_argument('--hidden_size', type=int, default=512,
+ help='Hidden dimension size')
+ parser.add_argument('--vocab_size', type=int, default=256,
+ help='Vocabulary size')
+ parser.add_argument('--num_layers', type=int, default=8,
+ help='Number of model layers')
+ parser.add_argument('--batch_size', type=int, default=32,
+ help='Batch size for training')
+ parser.add_argument('--epochs', type=int, default=40,
+ help='Number of training epochs')
+ parser.add_argument('--learning_rate', type=float, default=0.001,
+ help='Learning rate')
+ parser.add_argument('--weight_decay', type=float, default=0.0,
+ help='Weight decay for optimizer')
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
+ help='Device to use for training')
+ parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints',
+ help='Directory to save checkpoints')
+ parser.add_argument('--seed', type=int, default=42,
+ help='Random seed')
+ parser.add_argument('--continual_learning', action='store_true',
+ help='Enable continual learning mode')
+ parser.add_argument('--num_tasks', type=int, default=1,
+ help='Number of tasks for continual learning')
+ parser.add_argument('--cl_use_consolidation', action=argparse.BooleanOptionalAction,
+ default=True, help='Use EWC consolidation (on by default; disable with --no-cl_use_consolidation)')
+ parser.add_argument('--cl_use_metaplasticity', action=argparse.BooleanOptionalAction,
+ default=True, help='Use metaplasticity (on by default; disable with --no-cl_use_metaplasticity)')
+ parser.add_argument('--cl_use_adaptive_layers', action=argparse.BooleanOptionalAction,
+ default=True, help='Use adaptive layers for continual learning (on by default)')
+ parser.add_argument('--cl_lambda_ewc', type=float, default=1000.0,
+ help='EWC regularization strength (lambda)')
+ parser.add_argument('--cl_consolidation_rate', type=float, default=0.98,
+ help='Consolidation rate for metaplasticity')
+ parser.add_argument('--cl_ewc_weight', type=float, default=0.5,
+ help='Weight of EWC loss in total loss')
+ parser.add_argument('--cl_num_fisher_samples', type=int, default=200,
+ help='Number of samples for Fisher estimation')
+
+ return parser.parse_args()
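+
+# Example invocations (sketch):
+#   python train.py                                              # standard training on dummy data
+#   python train.py --continual_learning --num_tasks 5 --epochs 20 --cl_lambda_ewc 1000.0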
+
+
+def setup_seed(seed):
+ """Set random seed for reproducibility."""
+ torch.manual_seed(seed)
if torch.cuda.is_available():
- # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
- x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(
- device, non_blocking=True
- )
- else:
- x, y = x.to(device), y.to(device)
- return x, y
+ torch.cuda.manual_seed(seed)
+ import numpy as np
+ np.random.seed(seed)
-def eval(model):
- model.eval()
+class BDHModel(nn.Module):
+ """
+ Baby Dragon Hatchling (BDH) Model with Continual Learning Support.
+
+ Adapted from the original BDH architecture to support:
+ - Adaptive layers for consolidation
+ - Plasticity state tracking
+ - Optional metaplasticity
+ """
+
+ def __init__(self, vocab_size, hidden_size, num_layers,
+ use_adaptive_layers=False):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.use_adaptive_layers = use_adaptive_layers
+
+ # Embedding layer
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+
+ # Build layers
+ self.layers = nn.ModuleList()
+ for i in range(num_layers):
+ if use_adaptive_layers:
+ # Use adaptive linear layers for continual learning
+ try:
+ from bdh.continual_learning.adaptive_synapses import AdaptiveLinear
+ layer = nn.Sequential(
+ AdaptiveLinear(hidden_size, hidden_size),
+ nn.ReLU(),
+ )
+ except ImportError:
+ logger.warning("AdaptiveLinear not found, using standard Linear layers")
+ layer = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(),
+ )
+ else:
+ # Use standard layers
+ layer = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(),
+ )
+ self.layers.append(layer)
+
+ # Output layer
+ if use_adaptive_layers:
+ try:
+ from bdh.continual_learning.adaptive_synapses import AdaptiveLinear
+ self.output_layer = AdaptiveLinear(hidden_size, vocab_size)
+ except ImportError:
+ self.output_layer = nn.Linear(hidden_size, vocab_size)
+ else:
+ self.output_layer = nn.Linear(hidden_size, vocab_size)
+
+ def forward(self, x):
+ """Forward pass through BDH model."""
+ # x shape: (batch_size, seq_length)
+ x = self.embedding(x) # (batch_size, seq_length, hidden_size)
+
+ # Apply layers
+ for layer in self.layers:
+ x = layer(x)
+
+ # Output layer
+ x = self.output_layer(x) # (batch_size, seq_length, vocab_size)
+ return x
+
+ def forward_with_consolidation(self, x):
+ """Forward pass with consolidation gating (if using adaptive layers)."""
+ if not self.use_adaptive_layers:
+ return self.forward(x)
+
+ # Use consolidation gates
+ x = self.embedding(x)
+ for layer in self.layers:
+ if hasattr(layer[0], 'forward_with_consolidation'):
+ x = layer[0].forward_with_consolidation(x)
+ x = layer[1](x) # ReLU
+ else:
+ x = layer(x)
+
+ if hasattr(self.output_layer, 'forward_with_consolidation'):
+ x = self.output_layer.forward_with_consolidation(x)
+ else:
+ x = self.output_layer(x)
+
+ return x
-if __name__ == "__main__":
- fetch_data()
+def create_dummy_tasks(num_tasks=2, batch_size=32, seq_length=64, vocab_size=256):
+ """
+ Create dummy tasks for testing continual learning.
+ In practice, replace this with real task datasets.
+ """
+ tasks = []
+ for task_id in range(num_tasks):
+ # Create dummy data
+ train_data = torch.randint(0, vocab_size, (100 * batch_size, seq_length))
+ train_labels = torch.randint(0, vocab_size, (100 * batch_size, seq_length))
+ val_data = torch.randint(0, vocab_size, (10 * batch_size, seq_length))
+ val_labels = torch.randint(0, vocab_size, (10 * batch_size, seq_length))
+ test_data = torch.randint(0, vocab_size, (10 * batch_size, seq_length))
+ test_labels = torch.randint(0, vocab_size, (10 * batch_size, seq_length))
+
+ # Create dataloaders
+ train_loader = DataLoader(
+ list(zip(train_data, train_labels)),
+ batch_size=batch_size, shuffle=True
+ )
+ val_loader = DataLoader(
+ list(zip(val_data, val_labels)),
+ batch_size=batch_size, shuffle=False
+ )
+ test_loader = DataLoader(
+ list(zip(test_data, test_labels)),
+ batch_size=batch_size, shuffle=False
+ )
+
+ tasks.append((train_loader, val_loader, test_loader))
+
+ return tasks
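+
+# Note: to train on real data, replace create_dummy_tasks with any function that
+# returns the same structure: a list of (train_loader, val_loader, test_loader)
+# tuples, one triple per task.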
+
+
+def train_standard(model, train_loader, val_loader, loss_fn, optimizer,
+ args, device):
+ """Standard training without continual learning."""
+ logger.info("Starting standard training (no continual learning)...")
+
+ best_val_loss = float('inf')
+ patience = 5
+ patience_counter = 0
+
+ for epoch in range(args.epochs):
+ # Training
+ model.train()
+ total_loss = 0.0
+
+ for batch_idx, (x, y) in enumerate(train_loader):
+ x, y = x.to(device), y.to(device)
+
+ # Forward pass
+ logits = model(x)
+ loss = loss_fn(logits.view(-1, args.vocab_size), y.view(-1))
+
+ # Backward pass
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item()
+
+ # Validation
+ model.eval()
+ val_loss = 0.0
+ with torch.no_grad():
+ for x, y in val_loader:
+ x, y = x.to(device), y.to(device)
+ logits = model(x)
+ loss = loss_fn(logits.view(-1, args.vocab_size), y.view(-1))
+ val_loss += loss.item()
+
+ val_loss /= len(val_loader)
+ avg_train_loss = total_loss / len(train_loader)
+
+ if epoch % 5 == 0:
+ logger.info(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, "
+ f"Val Loss={val_loss:.4f}")
+
+ # Early stopping
+ if val_loss < best_val_loss:
+ best_val_loss = val_loss
+ patience_counter = 0
+ else:
+ patience_counter += 1
+
+ if patience_counter >= patience:
+ logger.info(f"Early stopping at epoch {epoch}")
+ break
- model = bdh.BDH(BDH_CONFIG).to(device)
- model = torch.compile(model)
- optimizer = torch.optim.AdamW(
- model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
+
+def train_continual_learning(model, tasks, loss_fn, args, device):
+ """Train with continual learning and consolidation."""
+ logger.info(f"Starting continual learning training with {len(tasks)} tasks...")
+
+ # Import continual learning modules
+ from bdh.continual_learning.trainer import (
+ train_continual_learning_task,
+ continual_learning_training_loop
+ )
+ from bdh.continual_learning.metrics import ContinualLearningMetrics
+
+ # Train with continual learning
+ accuracies = continual_learning_training_loop(
+ model=model,
+ tasks=tasks,
+ num_epochs_per_task=args.epochs,
+ lr=args.learning_rate,
+ use_consolidation=args.cl_use_consolidation,
+ use_metaplasticity=args.cl_use_metaplasticity
)
+
+ # Evaluate and print metrics
+ metrics = ContinualLearningMetrics(accuracies)
+ metrics.print_summary()
+
+ return metrics
- x, y = get_batch("train")
-
- loss_acc = 0
- loss_steps = 0
- for step in range(MAX_ITERS):
- with ctx:
- logits, loss = model(x, y)
- x, y = get_batch("train")
- loss_acc += loss
- loss_steps += 1
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
- optimizer.zero_grad()
- if step % LOG_FREQ == 0:
- print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}")
- loss_acc = 0
- loss_steps = 0
- print("Training done, now generating a sample ")
- model.eval()
- prompt = torch.tensor(
- bytearray("To be or ", "utf-8"), dtype=torch.long, device=device
- ).unsqueeze(0)
- ret = model.generate(prompt, max_new_tokens=100, top_k=3)
- ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode(
- errors="backslashreplace"
+
+def main():
+ """Main training function."""
+ args = parse_arguments()
+ setup_seed(args.seed)
+
+ device = torch.device(args.device)
+ logger.info(f"Using device: {device}")
+
+ # Create checkpoint directory
+ os.makedirs(args.checkpoint_dir, exist_ok=True)
+
+ # Create model
+ logger.info("Creating BDH model...")
+ model = BDHModel(
+ vocab_size=args.vocab_size,
+ hidden_size=args.hidden_size,
+ num_layers=args.num_layers,
+ use_adaptive_layers=args.continual_learning and args.cl_use_adaptive_layers
)
- print(ret_decoded)
+ model = model.to(device)
+
+ # Log model info
+ total_params = sum(p.numel() for p in model.parameters())
+ logger.info(f"Model created with {total_params:,} parameters")
+ logger.info(f" - Hidden size: {args.hidden_size}")
+ logger.info(f" - Num layers: {args.num_layers}")
+ logger.info(f" - Vocab size: {args.vocab_size}")
+ logger.info(f" - Adaptive layers: {args.cl_use_adaptive_layers}")
+
+ # ===== MAIN TRAINING DECISION =====
+ if args.continual_learning:
+ logger.info("=" * 60)
+ logger.info("CONTINUAL LEARNING MODE")
+ logger.info("=" * 60)
+
+ # Create tasks
+ logger.info(f"Creating {args.num_tasks} tasks...")
+ tasks = create_dummy_tasks(
+ num_tasks=args.num_tasks,
+ batch_size=args.batch_size,
+ vocab_size=args.vocab_size
+ )
+
+ # Log continual learning settings
+ logger.info(f"CL Configuration:")
+ logger.info(f" - Use consolidation: {args.cl_use_consolidation}")
+ logger.info(f" - Use metaplasticity: {args.cl_use_metaplasticity}")
+ logger.info(f" - Lambda EWC: {args.cl_lambda_ewc}")
+ logger.info(f" - Consolidation rate: {args.cl_consolidation_rate}")
+ logger.info(f" - EWC weight: {args.cl_ewc_weight}")
+
+ # Train with continual learning
+ loss_fn = nn.CrossEntropyLoss()
+ metrics = train_continual_learning(
+ model, tasks, loss_fn, args, device
+ )
+
+ # Save results
+ results_file = os.path.join(args.checkpoint_dir, 'cl_results.json')
+ with open(results_file, 'w') as f:
+ json.dump({
+ 'average_accuracy': metrics.average_accuracy(),
+ 'backward_transfer': metrics.backward_transfer(),
+ 'num_tasks': args.num_tasks,
+ }, f, indent=4)
+ logger.info(f"Results saved to {results_file}")
+
+ else:
+ logger.info("=" * 60)
+ logger.info("STANDARD TRAINING MODE (No Continual Learning)")
+ logger.info("=" * 60)
+
+ # Create standard training data (dummy for demo)
+ logger.info("Creating training data...")
+ train_data = torch.randint(0, args.vocab_size, (1000, 64))
+ train_labels = torch.randint(0, args.vocab_size, (1000, 64))
+ val_data = torch.randint(0, args.vocab_size, (100, 64))
+ val_labels = torch.randint(0, args.vocab_size, (100, 64))
+
+ train_loader = DataLoader(
+ list(zip(train_data, train_labels)),
+ batch_size=args.batch_size, shuffle=True
+ )
+ val_loader = DataLoader(
+ list(zip(val_data, val_labels)),
+ batch_size=args.batch_size, shuffle=False
+ )
+
+ # Setup optimizer and loss
+ optimizer = optim.Adam(model.parameters(),
+ lr=args.learning_rate,
+ weight_decay=args.weight_decay)
+ loss_fn = nn.CrossEntropyLoss()
+
+ # Train
+ train_standard(model, train_loader, val_loader, loss_fn,
+ optimizer, args, device)
+
+ # Save final model
+ model_path = os.path.join(args.checkpoint_dir, f'{args.model_name}_final.pt')
+ torch.save(model.state_dict(), model_path)
+ logger.info(f"Model saved to {model_path}")
+
+ logger.info("Training complete!")
+
+
+if __name__ == "__main__":
+ main()