diff --git a/.gitmodules b/.gitmodules index 377fc17b7..1176795af 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "plato/models/t2tvit"] path = plato/models/t2tvit url = https://github.com/yitu-opensource/T2T-ViT +[submodule "external/nanochat"] + path = external/nanochat + url = https://github.com/karpathy/nanochat.git diff --git a/cleanup.py b/cleanup.py index 68c658019..f03f5f3cf 100644 --- a/cleanup.py +++ b/cleanup.py @@ -146,18 +146,14 @@ def main() -> None: continue cleared = clean_directory(runtime_dir) - print( - f"Failed to delete {runtime_dir}; cleared {cleared} items instead." - ) + print(f"Failed to delete {runtime_dir}; cleared {cleared} items instead.") fallback_dirs += 1 fallback_items += cleared if runtime_total == 0: print("No runtime directories found.") else: - print( - f"Removed {runtime_removed} of {runtime_total} runtime directories." - ) + print(f"Removed {runtime_removed} of {runtime_total} runtime directories.") if fallback_dirs: print( f"Cleared {fallback_items} items in " diff --git a/configs/Nanochat/parquet_micro.toml b/configs/Nanochat/parquet_micro.toml new file mode 100644 index 000000000..cfe09a921 --- /dev/null +++ b/configs/Nanochat/parquet_micro.toml @@ -0,0 +1,53 @@ +[clients] +type = "simple" +total_clients = 10 +per_round = 3 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/parquet" +model_path = "models/nanochat/parquet" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "parquet" +max_train_batches = 16 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 32 +device = "cuda" +vocab_size = 512 +synthetic_seed = 123 + +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" +max_per_task = 16 + +[trainer] +type = "nanochat" +rounds = 10000 +epochs = 5 +batch_size = 1 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 256 +vocab_size = 50304 +n_layer = 4 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, core_metric, train_loss" diff --git a/configs/Nanochat/synthetic_micro.toml b/configs/Nanochat/synthetic_micro.toml new file mode 100644 index 000000000..abc3a6452 --- /dev/null +++ b/configs/Nanochat/synthetic_micro.toml @@ -0,0 +1,54 @@ +[clients] + +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/synthetic" +model_path = "models/nanochat/synthetic" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "synthetic" +max_train_batches = 4 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 64 +device = "cpu" +vocab_size = 512 +synthetic_seed = 123 + +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" # Optional, defaults to nanochat base dir or Plato's data directory +max_per_task = 16 # Optional, -1 means run all examples + +[trainer] +type = "nanochat" +rounds = 1 +epochs = 1 +batch_size = 2 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 128 +vocab_size = 512 +n_layer = 2 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, core_metric, train_loss" diff --git a/docs/nanochat_integration_checklist.md b/docs/nanochat_integration_checklist.md new file mode 100644 index 
000000000..f0b7bf9b6 --- /dev/null +++ b/docs/nanochat_integration_checklist.md @@ -0,0 +1,55 @@ +# Nanochat Integration Checklist + +This checklist coordinates the work to incorporate the Nanochat stack into Plato. Owners are placeholder roles until specific engineers are assigned. + +## Third-Party Submodule +- **Owner:** Infrastructure +- **Deliverables:** Maintain the `external/nanochat` git submodule; document update procedure in `docs/third_party.md`. +- **Dependencies:** None. + +## Model Registry +- **Owner:** Modeling +- **Deliverables:** Implement `plato/models/nanochat.py` mirroring Nanochat GPT config, register entry in `plato/models/registry.py`, supply weight-loading utilities. +- **Status:** In progress – factory module and registry wiring landed. +- **Dependencies:** Third-party submodule. + +## Tokenizer & Processor +- **Owner:** Infrastructure +- **Deliverables:** Package Rust BPE via Maturin optional extra; wrap as `plato/processors/nanochat_tokenizer.py` with lazy import and fallbacks; document build steps in README. +- **Status:** Prototype processor and optional dependency group landed; CI build integration remains TODO. +- **Dependencies:** Third-party submodule, build tooling prototype. + +## Datasource +- **Owner:** Data +- **Deliverables:** Create `plato/datasources/nanochat.py` handling dataset acquisition and sharding; register in datasource registry; store license metadata. +- **Status:** In progress – streaming dataset with synthetic fallback available. +- **Dependencies:** Tokenizer availability. + +## Trainer & Algorithm +- **Owner:** Training +- **Deliverables:** Port Nanochat engine into `plato/trainers/nanochat.py`; add algorithm glue if federated coordination diverges; ensure checkpoint compatibility. +- **Status:** In progress – composable trainer wrapper with Nanochat-specific optimiser/loader strategies in place. +- **Dependencies:** Model registry entry, datasource. + +## Evaluation Strategy +- **Owner:** Evaluation +- **Deliverables:** Translate `nanochat/core_eval.py` into reusable evaluator hooked into Plato testing strategy; add pytest coverage with synthetic data. +- **Status:** CORE evaluation adapter hooked into trainer testing strategy; follow-up coverage to use real eval bundles outstanding. +- **Dependencies:** Model, tokenizer. + +## Configuration & Examples +- **Owner:** Product +- **Deliverables:** Author `configs/Nanochat/*.toml` scenarios and `examples/nanochat/` workspace; include reference scripts and documentation. +- **Status:** Synthetic micro config and workspace README published; larger-scale scenarios pending. +- **Dependencies:** Model, datasource, trainer. + +## Documentation & Release +- **Owner:** Docs +- **Deliverables:** Publish `docs/models/nanochat.md`, extend root README tables, add integration notes and changelog entry; outline hardware requirements. +- **Dependencies:** All prior tracks. + +## Validation +- **Owner:** QA +- **Deliverables:** Expand CI to compile tokenizer, run smoke train/eval, and enforce import order checks; record expected metrics in evaluation baselines. +- **Status:** Initial pytest smoke checks for tokenizer/trainer added; CI enablement still pending. +- **Dependencies:** Evaluation strategy, trainer. diff --git a/docs/third_party.md b/docs/third_party.md new file mode 100644 index 000000000..d481b6829 --- /dev/null +++ b/docs/third_party.md @@ -0,0 +1,18 @@ +# Third-Party Assets + +This page records external projects that are vendored into the Plato repository to support specific integrations. 
Please update the relevant entry whenever the upstream source, commit hash, or licensing information changes. + +## Nanochat +- **Upstream:** [karpathy/nanochat](https://github.com/karpathy/nanochat) +- **Location:** `external/nanochat` (git submodule) +- **License:** MIT (included in `external/nanochat/LICENSE`) + +### Updating the Submodule +1. `git submodule update --remote external/nanochat` +2. Inspect upstream changes for compatibility with Plato. +3. Commit the submodule pointer update and note any required integration work in the checklist. + +### Notes +- After cloning Plato, run `git submodule update --init --recursive` to populate all external dependencies. +- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./external/nanochat maturin develop --release` succeeds before pushing updates. +- Avoid local modifications inside the submodule; contribute fixes upstream when possible. diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py index 487547209..e8d700c97 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py @@ -22,7 +22,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F -from data import build_loader from misc.config import get_config from misc.loss_ops import AdaptiveLossSoft from misc.lr_scheduler import build_scheduler @@ -37,6 +36,8 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import AverageMeter, ModelEma, accuracy +from data import build_loader + try: from apex import amp except ImportError: diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md new file mode 100644 index 000000000..1d061345b --- /dev/null +++ b/examples/nanochat/README.md @@ -0,0 +1,61 @@ +# Nanochat Integration Workspace + +This workspace hosts Nanochat-focused experiments within Plato. + +## Quick Start + +1. Initialize the nanochat submodule (required for the nanochat integration): + + ```bash + git submodule update --init --recursive + ``` + +2. Install dependencies (including the vendored tokenizer build requirements): + + ```bash + uv sync --extra nanochat + uv run --with ./external/nanochat maturin develop --release + ``` + **Troubleshooting:** If you encounter a `maturin failed` error with "Can't find Cargo.toml", run the maturin command from within the nanochat directory: + + ```bash + uv sync --extra nanochat + cd external/nanochat && uv run maturin develop --release && cd ../.. + ``` + +3. Run the synthetic smoke configuration: + + ```bash + uv run --extra nanochat python plato.py --config configs/Nanochat/synthetic_micro.toml + ``` + + This launches a single-client training round using the Nanochat trainer, synthetic + token streams, and a downsized GPT configuration for CPU debugging. 
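+
+As an optional sanity check before launching the server, the synthetic stream
+can also be iterated directly. The sketch below is illustrative rather than
+part of the shipped examples; it assumes a configuration such as
+`configs/Nanochat/synthetic_micro.toml` has already been loaded, since the
+datasource consults `Config()` for overrides:
+
+```python
+# Minimal sketch: pull a few synthetic batches and verify their shapes.
+from plato.datasources.nanochat import DataSource
+
+datasource = DataSource(
+    batch_size=2,
+    sequence_length=128,
+    mode="synthetic",     # bypass parquet discovery entirely
+    max_train_batches=4,  # finite stream, so the dataset reports a length
+    vocab_size=512,
+)
+for inputs, targets in datasource.trainset:
+    # Batches arrive pre-batched as (batch_size, sequence_length) LongTensors,
+    # with targets shifted one token relative to inputs.
+    assert inputs.shape == (2, 128)
+    assert targets.shape == (2, 128)
+```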
+ +## CORE Evaluation + +The Nanochat trainer can invoke the upstream CORE benchmark by adding the section +below to your TOML configuration: + +```toml +[evaluation] +type = "nanochat_core" +max_per_task = 128 # optional; limits evaluation samples per task +# bundle_dir = "/custom/path/to/nanochat" # defaults to ~/.cache/nanochat +``` + +Make sure the official evaluation bundle has been downloaded so the following files +exist (the default location is `~/.cache/nanochat/eval_bundle`): + +- `core.yaml` +- `eval_data/*.jsonl` +- `eval_meta_data.csv` + +The provided `configs/Nanochat/synthetic_micro.toml` can be extended with the +`[evaluation]` block once those assets are present. + +## Roadmap + +- Integrate real Nanochat tokenized datasets and publish download helpers. +- Add baseline evaluation scripts leveraging `nanochat/core_eval.py`. +- Capture reproducible metrics and hardware notes for larger-scale runs. diff --git a/examples/nanochat/pyproject.toml b/examples/nanochat/pyproject.toml new file mode 100644 index 000000000..c0e19f09c --- /dev/null +++ b/examples/nanochat/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "plato-nanochat-examples" +version = "0.1.0" +description = "Nanochat integration examples for Plato." +readme = "README.md" +requires-python = ">=3.10" +dependencies = [] diff --git a/examples/unlearning/fedunlearning/fedunlearning_server.py b/examples/unlearning/fedunlearning/fedunlearning_server.py index 8938b6549..6d6ade576 100644 --- a/examples/unlearning/fedunlearning/fedunlearning_server.py +++ b/examples/unlearning/fedunlearning/fedunlearning_server.py @@ -43,9 +43,7 @@ async def aggregate_deltas(self, updates, deltas_received, context): if not filtered_pairs: if self._fallback_to_original: - return await super().aggregate_deltas( - updates, deltas_received, context - ) + return await super().aggregate_deltas(updates, deltas_received, context) zero_delta = self._zero_delta( context, deltas_received[0] if deltas_received else None diff --git a/external/nanochat b/external/nanochat new file mode 160000 index 000000000..c75fe54aa --- /dev/null +++ b/external/nanochat @@ -0,0 +1 @@ +Subproject commit c75fe54aa7c1fa881701c246f9427bcbe4eee5a4 diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index 7b104c163..a7a54c86d 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -383,6 +383,18 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: except TypeError: num_samples = 0 + # Extract train_loss from trainer's run_history if available + train_loss = None + if ( + context.trainer is not None + and hasattr(context.trainer, "run_history") + and context.trainer.run_history is not None + ): + try: + train_loss = context.trainer.run_history.get_latest_metric("train_loss") + except (AttributeError, KeyError, IndexError): + train_loss = None + report = SimpleNamespace( client_id=context.client_id, num_samples=num_samples, @@ -390,6 +402,7 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: training_time=training_time, comm_time=time.time(), update_response=False, + train_loss=train_loss, ) return report, weights diff --git a/plato/config.py b/plato/config.py index ec7d840a0..56f35cfa8 100644 --- a/plato/config.py +++ b/plato/config.py @@ -153,6 +153,7 @@ class Config: clients: Any server: Any data: Any + evaluation: Any trainer: Any algorithm: Any results: Any @@ -342,6 +343,10 @@ def __new__(cls): Config.params["base_path"], "data" ) + # User-defined evaluation 
configuration + if hasattr(config, "evaluation"): + Config.evaluation = config.evaluation + # Pretrained models if hasattr(Config().server, "model_path"): Config.params["model_path"] = os.path.join( diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py new file mode 100644 index 000000000..7e53c71c6 --- /dev/null +++ b/plato/datasources/nanochat.py @@ -0,0 +1,274 @@ +""" +Streaming datasource backed by the vendored Nanochat project. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +try: + import torch + from torch.utils.data import IterableDataset +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat datasource requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.config import Config +from plato.datasources.base import DataSource as BaseDataSource +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +DEFAULT_VOCAB_SIZE = 50304 +DEFAULT_SEQUENCE_LENGTH = 2048 + + +def _resolve_base_dir(base_dir: str | Path | None) -> Path | None: + if base_dir is None: + return None + return Path(base_dir).expanduser().resolve() + + +def _parquet_available(base_dir: Path | None) -> bool: + try: + ensure_nanochat_importable() + from nanochat.dataset import ( # type: ignore[attr-defined] + DATA_DIR, + list_parquet_files, + ) + except (ThirdPartyImportError, ImportError): # pragma: no cover - defensive + return False + + if base_dir is not None: + candidate_dir = base_dir / "base_data" + if not candidate_dir.exists(): + return False + parquet_dir = candidate_dir + else: + parquet_dir = Path(DATA_DIR) + + try: + return len(list_parquet_files(str(parquet_dir))) > 0 + except FileNotFoundError: + return False + + +@dataclass +class _SyntheticState: + generator: torch.Generator + + +class NanochatStreamingDataset(IterableDataset): + """Iterable dataset yielding (inputs, targets) token tensors.""" + + def __init__( + self, + *, + split: str, + batch_size: int, + sequence_length: int, + mode: str, + base_dir: Path | None, + max_batches: int | None, + tokenizer_threads: int, + tokenizer_batch_size: int, + device: str, + vocab_size: int, + synthetic_seed: int, + ): + super().__init__() + if split not in {"train", "val"}: + raise ValueError("split must be 'train' or 'val'.") + + if mode not in {"auto", "parquet", "synthetic"}: + raise ValueError("mode must be 'auto', 'parquet', or 'synthetic'.") + + self.split = split + self.batch_size = batch_size + self.sequence_length = sequence_length + self.base_dir = base_dir + self.max_batches = max_batches + self.tokenizer_threads = tokenizer_threads + self.tokenizer_batch_size = tokenizer_batch_size + self.device = device + self.vocab_size = vocab_size + self.synthetic_seed = synthetic_seed + + resolved_mode = mode + if resolved_mode == "auto": + resolved_mode = "parquet" if _parquet_available(base_dir) else "synthetic" + self.mode = resolved_mode + self._synthetic_state: _SyntheticState | None = None + + # Configure Nanochat's base directory if provided. 
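+ # setdefault() keeps any NANOCHAT_BASE_DIR already exported by the user; the vendored nanochat helpers are expected to consult this variable (via get_base_dir()) when resolving their base directory.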
+ if self.base_dir is not None: + os.environ.setdefault("NANOCHAT_BASE_DIR", str(self.base_dir)) + + def _synthetic_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + if self._synthetic_state is None: + generator = torch.Generator() + generator.manual_seed(self.synthetic_seed) + self._synthetic_state = _SyntheticState(generator=generator) + + generator = self._synthetic_state.generator + while True: + tokens = torch.randint( + low=0, + high=self.vocab_size, + size=(self.batch_size, self.sequence_length + 1), + dtype=torch.long, + generator=generator, + ) + inputs = tokens[:, :-1].contiguous() + targets = tokens[:, 1:].contiguous() + yield inputs, targets + + def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + ensure_nanochat_importable() + from nanochat.dataloader import ( # type: ignore[attr-defined] + tokenizing_distributed_data_loader, + ) + + loader = tokenizing_distributed_data_loader( + self.batch_size, + self.sequence_length, + split=self.split, + tokenizer_threads=self.tokenizer_threads, + tokenizer_batch_size=self.tokenizer_batch_size, + device=self.device, + ) + for inputs, targets in loader: + yield ( + inputs.to(dtype=torch.long).contiguous(), + targets.to(dtype=torch.long).contiguous(), + ) + + def __iter__(self): + iterable = ( + self._parquet_iterable() + if self.mode == "parquet" + else self._synthetic_iterable() + ) + + for batch_index, batch in enumerate(iterable): + if self.max_batches is not None and batch_index >= self.max_batches: + break + yield batch + + def __len__(self) -> int: + if self.max_batches is None: + raise TypeError("Streaming dataset does not have a finite length.") + return self.max_batches + + +class DataSource(BaseDataSource): # type: ignore[misc] + """Plato datasource exposing Nanochat token streams.""" + + def __init__( + self, + *, + batch_size: int | None = None, + sequence_length: int | None = None, + mode: str = "auto", + base_dir: str | Path | None = None, + max_train_batches: int | None = 64, + max_val_batches: int | None = 8, + tokenizer_threads: int = 4, + tokenizer_batch_size: int = 128, + device: str = "cpu", + vocab_size: int = DEFAULT_VOCAB_SIZE, + synthetic_seed: int = 42, + ): + super().__init__() + + cfg_data = getattr(Config(), "data", None) + if cfg_data is not None: + mode = getattr(cfg_data, "mode", mode) + base_dir = getattr(cfg_data, "base_dir", base_dir) + max_train_batches = getattr( + cfg_data, "max_train_batches", max_train_batches + ) + max_val_batches = getattr(cfg_data, "max_val_batches", max_val_batches) + tokenizer_threads = getattr( + cfg_data, "tokenizer_threads", tokenizer_threads + ) + tokenizer_batch_size = getattr( + cfg_data, "tokenizer_batch_size", tokenizer_batch_size + ) + device = getattr(cfg_data, "device", device) + vocab_size = getattr(cfg_data, "vocab_size", vocab_size) + synthetic_seed = getattr(cfg_data, "synthetic_seed", synthetic_seed) + + config = getattr(Config(), "parameters", None) + model_conf = getattr(config, "model", None) + default_seq_len = DEFAULT_SEQUENCE_LENGTH + if model_conf is not None and hasattr(model_conf, "_asdict"): + seq_len_candidate = model_conf._asdict().get("sequence_len") + if isinstance(seq_len_candidate, int) and seq_len_candidate > 0: + default_seq_len = seq_len_candidate + + resolved_sequence_len = sequence_length or default_seq_len + resolved_batch_size = batch_size or getattr( + getattr(Config(), "trainer", None), "batch_size", 1 + ) + + resolved_base_dir = _resolve_base_dir(base_dir) + dataset_mode = mode + if dataset_mode == 
"auto": + dataset_mode = ( + "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + ) + + self.trainset: NanochatStreamingDataset = NanochatStreamingDataset( + split="train", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_train_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed, + ) + self.testset: NanochatStreamingDataset = NanochatStreamingDataset( + split="val", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_val_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed + 1, + ) + self.sequence_length = resolved_sequence_len + self.batch_size = resolved_batch_size + self.mode = dataset_mode + + @staticmethod + def input_shape(): + """Return the default input shape (sequence length).""" + return (DEFAULT_SEQUENCE_LENGTH,) + + def num_train_examples(self) -> int: + dataset = self.trainset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_train_batches to report size." + ) + return dataset.max_batches * dataset.batch_size + + def num_test_examples(self) -> int: + dataset = self.testset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_val_batches to report size." + ) + return dataset.max_batches * dataset.batch_size diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index e2f101de3..d4bda0c8d 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -12,6 +12,7 @@ femnist, huggingface, lora, + nanochat, purchase, texas, tiny_imagenet, @@ -27,6 +28,7 @@ "Texas": texas, "TinyImageNet": tiny_imagenet, "Feature": feature, + "Nanochat": nanochat, } registered_partitioned_datasources = {"FEMNIST": femnist} diff --git a/plato/evaluators/__init__.py b/plato/evaluators/__init__.py new file mode 100644 index 000000000..62d84e994 --- /dev/null +++ b/plato/evaluators/__init__.py @@ -0,0 +1,3 @@ +"""Evaluation helpers for Plato integrations.""" + +# Intentionally empty for now; modules register themselves when imported. diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py new file mode 100644 index 000000000..97a49e9f2 --- /dev/null +++ b/plato/evaluators/nanochat_core.py @@ -0,0 +1,358 @@ +""" +Adapter utilities to run Nanochat's CORE evaluation benchmark within Plato. +""" + +from __future__ import annotations + +import contextlib +import csv +import gzip +import json +import logging +import os +import random +import sys +import tarfile +import time +import zipfile +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat CORE evaluation requires PyTorch. " + "Install the `nanochat` extra (includes torch)." 
+ ) from exc + +import requests +import yaml + +from plato.config import Config +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +LOGGER = logging.getLogger(__name__) + +# URL for the CORE evaluation bundle +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + + +@contextlib.contextmanager +def _download_guard(data_path: str): + """Serialize dataset downloads to avoid concurrent corruption.""" + os.makedirs(data_path, exist_ok=True) + lock_file = os.path.join(data_path, ".download.lock") + lock_fd = None + waited = False + + try: + while True: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR) + break + except FileExistsError: + if not waited: + LOGGER.info( + "Another process is preparing the dataset at %s. Waiting.", + data_path, + ) + waited = True + time.sleep(1) + yield + finally: + if lock_fd is not None: + os.close(lock_fd) + try: + os.remove(lock_file) + except FileNotFoundError: + pass + + +def _download_eval_bundle(url: str, data_path: str) -> None: + """Download the CORE evaluation bundle from a URL if not already available.""" + url_parse = urlparse(url) + file_name = os.path.join(data_path, url_parse.path.split("/")[-1]) + os.makedirs(data_path, exist_ok=True) + sentinel = Path(f"{file_name}.complete") + + if sentinel.exists(): + return + + with _download_guard(data_path): + if sentinel.exists(): + return + + LOGGER.info("Downloading CORE evaluation bundle from %s.", url) + + res = requests.get(url, stream=True, timeout=60) + total_size = int(res.headers.get("Content-Length", 0)) + downloaded_size = 0 + + with open(file_name, "wb+") as file: + for chunk in res.iter_content(chunk_size=1024): + if not chunk: + continue + downloaded_size += len(chunk) + file.write(chunk) + file.flush() + if total_size: + sys.stdout.write(f"\r{100 * downloaded_size / total_size:.1f}%") + sys.stdout.flush() + if total_size: + sys.stdout.write("\n") + + # Unzip the compressed file just downloaded + LOGGER.info("Decompressing the CORE evaluation bundle.") + name, suffix = os.path.splitext(file_name) + + if file_name.endswith("tar.gz"): + with tarfile.open(file_name, "r:gz") as tar: + tar.extractall(data_path) + os.remove(file_name) + elif suffix == ".zip": + LOGGER.info("Extracting %s to %s.", file_name, data_path) + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(data_path) + os.remove(file_name) + elif suffix == ".gz": + with gzip.open(file_name, "rb") as zipped_file: + with open(name, "wb") as unzipped_file: + unzipped_file.write(zipped_file.read()) + os.remove(file_name) + else: + LOGGER.warning("Unknown compressed file type for %s.", file_name) + + sentinel.touch() + LOGGER.info("CORE evaluation bundle downloaded and extracted successfully.") + + +def _resolve_bundle_paths(bundle_dir: str | Path | None) -> tuple[Path, Path, Path]: + """Resolve the configuration, metadata, and dataset paths for CORE evaluation.""" + ensure_nanochat_importable() + + def _get_default_base_path() -> Path: + """Get the default base path, trying nanochat first, then Plato's data directory.""" + try: + from nanochat.common import get_base_dir # pylint: disable=import-error + + return Path(get_base_dir()) + except (ImportError, OSError, PermissionError): + plato_data_path = Config().params.get("data_path", "./runtime/data") + path = Path(plato_data_path) / "nanochat" + LOGGER.info("Using Plato data directory for CORE bundle: %s", path) + return path + + # Determine base path + if bundle_dir is not 
None: + try: + base_path = Path(bundle_dir).expanduser().resolve() + LOGGER.info("Using bundle_dir from config: %s", base_path) + except (OSError, PermissionError, ValueError) as exc: + LOGGER.warning( + "Cannot use bundle_dir '%s': %s. Using default location.", + bundle_dir, + exc, + ) + base_path = _get_default_base_path() + else: + base_path = _get_default_base_path() + + # Ensure base path exists + base_path.mkdir(parents=True, exist_ok=True) + + eval_bundle_dir = base_path / "eval_bundle" + config_path = eval_bundle_dir / "core.yaml" + data_dir = eval_bundle_dir / "eval_data" + metadata_path = eval_bundle_dir / "eval_meta_data.csv" + + # Check if evaluation bundle exists, download if missing + if not config_path.exists() or not data_dir.exists() or not metadata_path.exists(): + LOGGER.info( + "CORE evaluation bundle not found at %s. Downloading automatically...", + base_path, + ) + _download_eval_bundle(EVAL_BUNDLE_URL, str(base_path)) + + # Verify download succeeded + if not config_path.exists(): + raise FileNotFoundError( + f"CORE evaluation config not found at {config_path}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not data_dir.exists(): + raise FileNotFoundError( + f"CORE evaluation data directory not found at {data_dir}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not metadata_path.exists(): + raise FileNotFoundError( + f"CORE evaluation metadata CSV not found at {metadata_path}." + ) + + return config_path, data_dir, metadata_path + + +def _load_core_tasks(config_path: Path) -> list[dict[str, Any]]: + """Load task definitions from the CORE YAML config.""" + with config_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) + tasks = config.get("icl_tasks", []) + if not isinstance(tasks, list) or not tasks: + raise ValueError( + f"No CORE tasks defined in {config_path}. Inspect the eval bundle." + ) + return tasks + + +def _load_metadata(metadata_path: Path) -> dict[str, float]: + """Load random baseline metadata for centering accuracy.""" + baseline_map: dict[str, float] = {} + with metadata_path.open("r", encoding="utf-8") as handle: + reader = csv.DictReader(handle) + for row in reader: + label = row.get("Eval Task") + baseline = row.get("Random baseline") + if label is None or baseline is None: + continue + try: + baseline_map[label] = float(baseline) + except ValueError: + LOGGER.debug("Skipping malformed baseline row: %s", row) + if not baseline_map: + raise ValueError( + f"Random baselines missing in {metadata_path}. Required for CORE metric." + ) + return baseline_map + + +def _load_task_data(data_dir: Path, dataset_uri: str) -> list[dict[str, Any]]: + """Load task dataset rows from newline-delimited JSON.""" + path = data_dir / dataset_uri + if not path.exists(): + raise FileNotFoundError( + f"CORE dataset shard '{dataset_uri}' missing under {data_dir}." 
+ ) + with path.open("r", encoding="utf-8") as handle: + return [json.loads(line.strip()) for line in handle if line.strip()] + + +def _resolve_tokenizer(model) -> Any: + """Obtain a tokenizer compatible with Nanochat core evaluation.""" + tokenizer = getattr(model, "nanochat_tokenizer", None) + if tokenizer is not None: + return tokenizer + + ensure_nanochat_importable() + from nanochat.tokenizer import get_tokenizer # pylint: disable=import-error + + return get_tokenizer() + + +def run_core_evaluation( + model: torch.nn.Module, + *, + tokenizer: Any | None = None, + bundle_dir: str | Path | None = None, + max_per_task: int = -1, + device: torch.device | str | None = None, +) -> dict[str, Any]: + """ + Execute the CORE benchmark for the provided model. + + Args: + model: Nanochat-style autoregressive model. + tokenizer: Optional tokenizer; falls back to nanochat.tokenizer.get_tokenizer(). + bundle_dir: Optional base directory containing `eval_bundle/`. + max_per_task: Optional cap on examples per task for quicker smoke tests (-1 = all). + device: Device to run evaluation on. Defaults to the model's current device. + + Returns: + Dictionary with `results`, `centered_results`, and `core_metric`. + """ + ensure_nanochat_importable() + from nanochat.core_eval import evaluate_task # pylint: disable=import-error + + config_path, data_dir, metadata_path = _resolve_bundle_paths(bundle_dir) + tasks = _load_core_tasks(config_path) + baselines = _load_metadata(metadata_path) + + eval_tokenizer = tokenizer or _resolve_tokenizer(model) + if eval_tokenizer is None: + raise RuntimeError( + "Nanochat CORE evaluation requires a tokenizer. " + "Either attach `model.nanochat_tokenizer` or provide one explicitly." + ) + + if device is None: + try: + first_param = next(model.parameters()) + device = first_param.device + except StopIteration: + device = torch.device("cpu") + if isinstance(device, str): + device = torch.device(device) + + model_device = device + model_was_training = model.training + model = model.to(model_device) + model.eval() + + results: dict[str, float] = {} + centered_results: dict[str, float] = {} + + for task in tasks: + label = task.get("label") + if not label: + LOGGER.debug("Skipping unnamed CORE task entry: %s", task) + continue + + dataset_uri = task.get("dataset_uri") + if not isinstance(dataset_uri, str): + LOGGER.debug( + "Skipping CORE task %s due to missing dataset_uri metadata.", task + ) + continue + + task_meta = { + "task_type": task.get("icl_task_type"), + "dataset_uri": dataset_uri, + "num_fewshot": task.get("num_fewshot", [0])[0], + "continuation_delimiter": task.get("continuation_delimiter", " "), + } + start_time = time.perf_counter() + + data = _load_task_data(data_dir, dataset_uri) + shuffle_rng = random.Random(1337) + shuffle_rng.shuffle(data) + if max_per_task > 0: + data = data[:max_per_task] + + accuracy = evaluate_task(model, eval_tokenizer, data, model_device, task_meta) + baseline = baselines.get(label, 0.0) + centered = (accuracy - 0.01 * baseline) / (1.0 - 0.01 * baseline) + + results[label] = accuracy + centered_results[label] = centered + elapsed = time.perf_counter() - start_time + LOGGER.info( + "CORE task %s | accuracy %.4f | centered %.4f | %.2fs", + label, + accuracy, + centered, + elapsed, + ) + + if model_was_training: + model.train() + + if not centered_results: + raise RuntimeError("No CORE tasks were evaluated; check the eval bundle.") + + core_metric = sum(centered_results.values()) / len(centered_results) + return { + "results": results, + 
"centered_results": centered_results, + "core_metric": core_metric, + } diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py new file mode 100644 index 000000000..114a891f9 --- /dev/null +++ b/plato/models/nanochat.py @@ -0,0 +1,173 @@ +""" +Factory for Nanochat GPT models integrated with Plato's registry. +""" + +from __future__ import annotations + +import logging +from dataclasses import fields +from typing import Any + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat model integration requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +DEFAULT_MODEL_CONFIG: dict[str, int] = { + "sequence_len": 2048, + "vocab_size": 50304, + "n_layer": 12, + "n_head": 6, + "n_kv_head": 6, + "n_embd": 768, +} + + +def _import_nanochat_modules(): + ensure_nanochat_importable() + from nanochat.checkpoint_manager import ( # type: ignore[attr-defined] + load_model_from_dir, + ) + from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] + + return GPT, GPTConfig, load_model_from_dir + + +def _sanitize_kwargs(kwargs: dict[str, Any], valid_fields: set[str]) -> dict[str, Any]: + """Filter kwargs to those accepted by GPTConfig.""" + config_kwargs = DEFAULT_MODEL_CONFIG.copy() + for key, value in kwargs.items(): + if key in valid_fields: + config_kwargs[key] = value + return config_kwargs + + +def _load_from_checkpoint( + load_dir: str, + *, + device: str | torch.device = "cpu", + phase: str = "train", + model_tag: str | None = None, + step: int | None = None, +): + """Load a Nanochat checkpoint via checkpoint_manager.""" + GPT, _, load_model_from_dir = _import_nanochat_modules() + torch_device = torch.device(device) + model, tokenizer, metadata = load_model_from_dir( + load_dir, + device=torch_device, + phase=phase, + model_tag=model_tag, + step=step, + ) + # Attach helpful metadata to the model for downstream use. + setattr(model, "nanochat_tokenizer", tokenizer) + setattr(model, "nanochat_metadata", metadata) + if not isinstance(model, GPT): + raise TypeError( + "Checkpoint loader returned an unexpected model type. " + "Ensure the checkpoint directory points to Nanochat artifacts." + ) + return model + + +class Model: + """Nanochat GPT factory compatible with Plato's model registry.""" + + @staticmethod + def get(model_name: str | None = None, **kwargs: Any): + """ + Instantiate a Nanochat GPT model. + + Keyword Args: + sequence_len: Context length (tokens). + vocab_size: Token vocabulary size. + n_layer: Number of transformer blocks. + n_head: Attention heads for queries. + n_kv_head: Attention heads for keys/values (MQA/GQA). + n_embd: Hidden dimension width. + init_weights: Whether to run Nanochat's weight initialisation (default True). + load_checkpoint_dir: Optional checkpoint directory produced by Nanochat. + load_checkpoint_tag: Optional subdirectory/model tag within checkpoint dir. + load_checkpoint_step: Optional numeric step to load (defaults to latest). + device: Torch device string for checkpoint loading. + phase: "train" or "eval" when loading checkpoints. + """ + try: + GPT, GPTConfig, _ = _import_nanochat_modules() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat submodule not found. " + "Run `git submodule update --init --recursive` to populate external/nanochat." 
+ ) from exc + + init_weights = kwargs.pop("init_weights", True) + load_dir = kwargs.pop("load_checkpoint_dir", None) + checkpoint_tag = kwargs.pop("load_checkpoint_tag", None) + checkpoint_step = kwargs.pop("load_checkpoint_step", None) + checkpoint_phase = kwargs.pop("phase", "train") + checkpoint_device = kwargs.pop("device", "cpu") + + # GPTConfig only accepts specific fields; filter unknown kwargs. + config_fields = {field.name for field in fields(GPTConfig)} + config_kwargs = _sanitize_kwargs(kwargs, config_fields) + + if load_dir: + model = _load_from_checkpoint( + load_dir, + device=checkpoint_device, + phase=checkpoint_phase, + model_tag=checkpoint_tag, + step=checkpoint_step, + ) + return model + + # Model vocab_size MUST match tokenizer vocab_size to avoid IndexError + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + actual_vocab_size = tokenizer.get_vocab_size() + + # Override vocab_size with tokenizer's actual vocab_size + if "vocab_size" in config_kwargs: + configured_vocab = config_kwargs["vocab_size"] + if configured_vocab != actual_vocab_size: + logging.warning( + f"[Nanochat Model] Config specifies vocab_size={configured_vocab}, " + f"but tokenizer has vocab_size={actual_vocab_size}. " + f"Using tokenizer's vocab_size={actual_vocab_size} to match tokenizer." + ) + config_kwargs["vocab_size"] = actual_vocab_size + + logging.info( + f"[Nanochat Model] Using vocab_size={actual_vocab_size} from tokenizer" + ) + except Exception as e: + logging.warning( + f"[Nanochat Model] Could not auto-detect vocab_size from tokenizer: {e}. " + f"Using configured or default vocab_size." + ) + + config = GPTConfig(**config_kwargs) + model = GPT(config) + if init_weights: + model.init_weights() + + # This allows CORE evaluation and other components to access the tokenizer + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + setattr(model, "nanochat_tokenizer", tokenizer) + except Exception: + pass + # Set max_seq_len for CORE evaluation truncation + setattr(model, "max_seq_len", config.sequence_len) + setattr(model, "nanochat_config", config_kwargs) + return model diff --git a/plato/models/registry.py b/plato/models/registry.py index e6691e3a0..382a2d064 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -15,6 +15,7 @@ huggingface, lenet5, multilayer, + nanochat, resnet, torch_hub, vgg, @@ -40,6 +41,7 @@ "torch_hub": torch_hub.Model, "huggingface": huggingface.Model, "vit": vit.Model, + "nanochat": nanochat.Model, } registered_mlx_models = {} diff --git a/plato/processors/nanochat_tokenizer.py b/plato/processors/nanochat_tokenizer.py new file mode 100644 index 000000000..7460ae910 --- /dev/null +++ b/plato/processors/nanochat_tokenizer.py @@ -0,0 +1,135 @@ +""" +Prototype tokenizer processor for Nanochat, wrapping the rustbpe+tiktoken stack. + +This module exercises the build tooling for the Rust extension while providing +an adapter that conforms to Plato's processor interface. 
+""" + +from __future__ import annotations + +import pickle +from collections.abc import Iterable, Sequence +from pathlib import Path +from typing import Any + +from plato.processors.base import Processor + +SPECIAL_TOKENS = [ + "<|bos|>", + "<|user_start|>", + "<|user_end|>", + "<|assistant_start|>", + "<|assistant_end|>", + "<|python_start|>", + "<|python_end|>", + "<|output_start|>", + "<|output_end|>", +] + +DEFAULT_PATTERN = ( + r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]" + r"|\s+(?!\S)|\s+" +) + + +class NanochatTokenizerProcessor(Processor): + """Prototype tokenizer that can either load a saved encoding or train via rustbpe.""" + + def __init__( + self, + tokenizer_path: str | Path | None = None, + train_corpus: Iterable[str] | None = None, + vocab_size: int = 32000, + pattern: str = DEFAULT_PATTERN, + bos_token: str = "<|bos|>", + prepend_bos: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.tokenizer_path = Path(tokenizer_path) if tokenizer_path else None + self.train_corpus = train_corpus + self.vocab_size = max(vocab_size, 256 + len(SPECIAL_TOKENS)) + self.pattern = pattern + self.bos_token = bos_token + self.prepend_bos = prepend_bos + + self._encoding = self._load_encoding() + self._special_tokens = self._infer_special_tokens() + self._bos_token_id = self._special_tokens.get(self.bos_token) + + def _load_encoding(self): + if self.tokenizer_path: + return self._load_from_pickle(self.tokenizer_path) + if self.train_corpus is not None: + return self._train_from_corpus(self.train_corpus) + raise ValueError("Either tokenizer_path or train_corpus must be provided.") + + def _load_from_pickle(self, path: Path): + with path.open("rb") as handle: + encoding = pickle.load(handle) + return encoding + + def _train_from_corpus(self, corpus: Iterable[str]): + try: + import rustbpe # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "rustbpe extension is required to train a Nanochat tokenizer." + ) from exc + + try: + import tiktoken # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "tiktoken is required to construct Nanochat tokenizer encodings." + ) from exc + + tokenizer = rustbpe.Tokenizer() + tokenizer.train_from_iterator( + iter(corpus), + self.vocab_size - len(SPECIAL_TOKENS), + pattern=self.pattern, + ) + + mergeable_ranks = { + bytes(piece): rank for piece, rank in tokenizer.get_mergeable_ranks() + } + tokens_offset = len(mergeable_ranks) + special_tokens = { + token: tokens_offset + index for index, token in enumerate(SPECIAL_TOKENS) + } + + return tiktoken.Encoding( + name="nanochat-rustbpe", + pat_str=tokenizer.get_pattern(), + mergeable_ranks=mergeable_ranks, + special_tokens=special_tokens, + ) + + def _infer_special_tokens(self): + try: + encode_single_token = self._encoding.encode_single_token + special_token_set = getattr(self._encoding, "special_tokens_set", set()) + except AttributeError as exc: + raise RuntimeError( + "tiktoken encoding missing expected interfaces." 
+ ) from exc + + mapping = {} + for token in SPECIAL_TOKENS: + if token in special_token_set: + mapping[token] = encode_single_token(token) + return mapping + + def _encode_one(self, text: str): + ids = list(self._encoding.encode_ordinary(text)) + if self.prepend_bos and self._bos_token_id is not None: + ids.insert(0, self._bos_token_id) + return ids + + def process(self, data: Any): + if isinstance(data, str): + return self._encode_one(data) + if isinstance(data, Sequence): + return [self._encode_one(item) for item in data] + raise TypeError(f"Unsupported payload type: {type(data)!r}") diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index ee636cb49..36fee944b 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -252,7 +252,24 @@ async def _process_reports(self): trainer = self.require_trainer() self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", self.accuracy) + + # If CORE benchmark was run via a Nanochat testing strategy, report the specialized CORE metric instead of the generic 'Global model accuracy' label. + core_metric = getattr(self, "_core_metric", None) + + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -273,9 +290,10 @@ def clients_processed(self) -> None: def get_logged_items(self) -> dict: """Get items to be logged by the LogProgressCallback class in a .csv file.""" - return { + logged = { "round": self.current_round, "accuracy": self.accuracy, + "core_metric": getattr(self, "_core_metric", None), "accuracy_std": self.accuracy_std, "elapsed_time": self.wall_time - self.initial_wall_time, "processing_time": max( @@ -291,6 +309,30 @@ def get_logged_items(self) -> dict: "comm_overhead": self.comm_overhead, } + # Add train_loss if available from client reports + if self.updates and hasattr(self.updates[0].report, "train_loss"): + # Compute weighted average of train_loss across clients + total_samples = sum( + update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) + if total_samples > 0: + weighted_loss = sum( + update.report.train_loss * update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) + logged["train_loss"] = weighted_loss / total_samples + else: + logged["train_loss"] = None + + # Add core_metric if Nanochat CORE evaluation was performed + if hasattr(self, "_core_metric"): + logged["core_metric"] = self._core_metric + + return logged + @staticmethod def get_accuracy_mean_std(updates): """Compute the accuracy mean and standard deviation across clients.""" diff --git a/plato/servers/fedavg_cs.py b/plato/servers/fedavg_cs.py index fceff363a..eaba3caf6 100644 --- a/plato/servers/fedavg_cs.py +++ b/plato/servers/fedavg_cs.py @@ -253,7 +253,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract 
CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -274,7 +289,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" diff --git a/plato/servers/split_learning.py b/plato/servers/split_learning.py index 83cf2804a..49fdc2d4c 100644 --- a/plato/servers/split_learning.py +++ b/plato/servers/split_learning.py @@ -99,11 +99,30 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): self.test_accuracy = trainer.test(self.testset, self.testset_sampler) - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + core_metric = core_results.get("core_metric", None) + if core_metric is not None: + logging.warning( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) ) - ) self.phase = "prompt" # Change client in next round self.next_client = True diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py new file mode 100644 index 000000000..307ccf911 --- /dev/null +++ b/plato/trainers/nanochat.py @@ -0,0 +1,296 @@ +""" +Trainer wiring for Nanochat models within the composable trainer framework. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Iterator, List, Sequence + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat trainer requires PyTorch. " + "Install torch via the project's optional dependencies." 
) from exc + +from plato.config import Config +from plato.datasources.nanochat import NanochatStreamingDataset +from plato.evaluators.nanochat_core import run_core_evaluation +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies.base import ( + DataLoaderStrategy, + OptimizerStrategy, + TestingStrategy, + TrainingContext, + TrainingStepStrategy, +) +from plato.trainers.strategies.data_loader import DefaultDataLoaderStrategy +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + + +def _first_element_collate(batch: Sequence[Any]) -> Any: + """Return the first (and only) element from a DataLoader batch.""" + if not batch: + raise ValueError("Received empty batch from Nanochat dataset.") + return batch[0] + + +class NanochatDataLoaderStrategy(DataLoaderStrategy): + """Use identity collation for pre-batched Nanochat streaming datasets.""" + + def __init__(self): + self._fallback = DefaultDataLoaderStrategy() + + def create_train_loader( + self, trainset, sampler, batch_size: int, context: TrainingContext + ) -> torch.utils.data.DataLoader | Iterator: + if isinstance(trainset, NanochatStreamingDataset): + return torch.utils.data.DataLoader( + trainset, + batch_size=1, + shuffle=False, + sampler=None, + num_workers=0, + collate_fn=_first_element_collate, + ) + return self._fallback.create_train_loader( + trainset, sampler, batch_size, context + ) + + +class NanochatTrainingStepStrategy(TrainingStepStrategy): + """Call Nanochat's integrated loss computation during training.""" + + def __init__(self, loss_reduction: str = "mean"): + self.loss_reduction = loss_reduction + + def training_step( + self, + model: torch.nn.Module, + optimizer, + examples: torch.Tensor, + labels: torch.Tensor, + loss_criterion, + context: TrainingContext, + ) -> torch.Tensor: + optimizer.zero_grad() + if labels is None: + # Nanochat's GPT computes its loss internally from (inputs, targets); + # there is no separate criterion path for unlabelled batches. + raise ValueError("Nanochat training batches must include targets.") + loss = model(examples, targets=labels, loss_reduction=self.loss_reduction) + + if not isinstance(loss, torch.Tensor): + raise TypeError( + "Nanochat model forward pass must return a torch.Tensor loss."
+ ) + + loss.backward() + optimizer.step() + context.state["optimizer_step_completed"] = True + return loss.detach() + + +@dataclass +class _OptimizerBundle: + """Bundle multiple optimizers under a single interface.""" + + optimizers: List[torch.optim.Optimizer] + + def zero_grad(self) -> None: + for optimizer in self.optimizers: + optimizer.zero_grad() + + def step(self) -> None: + for optimizer in self.optimizers: + optimizer.step() + + def state_dict(self) -> dict[str, Any]: + return { + "optimizers": [optimizer.state_dict() for optimizer in self.optimizers], + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + for optimizer, payload in zip( + self.optimizers, + state_dict.get("optimizers", []), + strict=False, # type: ignore[arg-type] + ): + optimizer.load_state_dict(payload) + + @property + def param_groups(self) -> list[dict[str, Any]]: + groups: list[dict[str, Any]] = [] + for optimizer in self.optimizers: + groups.extend(getattr(optimizer, "param_groups", [])) + return groups + + def params_state_update(self) -> None: + for optimizer in self.optimizers: + hook = getattr(optimizer, "params_state_update", None) + if callable(hook): + hook() + + +class NanochatOptimizerStrategy(OptimizerStrategy): + """Adapter around nanochat.gpt.GPT.setup_optimizers.""" + + def __init__( + self, + *, + unembedding_lr: float = 0.004, + embedding_lr: float = 0.2, + matrix_lr: float = 0.02, + weight_decay: float = 0.0, + ): + self.unembedding_lr = unembedding_lr + self.embedding_lr = embedding_lr + self.matrix_lr = matrix_lr + self.weight_decay = weight_decay + + def create_optimizer( + self, model: torch.nn.Module, context: TrainingContext + ) -> _OptimizerBundle: + if not isinstance(model, torch.nn.Module): + raise TypeError("Nanochat optimizer strategy requires a torch.nn.Module.") + + setup_fn = getattr(model, "setup_optimizers", None) + if not callable(setup_fn): + raise AttributeError( + "Nanochat model is expected to expose setup_optimizers()." + ) + + optimizers = setup_fn( + unembedding_lr=self.unembedding_lr, + embedding_lr=self.embedding_lr, + matrix_lr=self.matrix_lr, + weight_decay=self.weight_decay, + ) + if not isinstance(optimizers, Iterable): + raise TypeError("setup_optimizers() must return an iterable of optimizers.") + + optimizer_list: list[torch.optim.Optimizer] = list(optimizers) + if not optimizer_list: + raise ValueError("setup_optimizers() returned an empty optimizer list.") + return _OptimizerBundle(optimizer_list) + + +class NanochatTestingStrategy(TestingStrategy): + """Compute average token loss over the validation iterator.""" + + def __init__(self, reduction: str = "sum"): + self.reduction = reduction + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + if not isinstance(testset, NanochatStreamingDataset): + raise TypeError( + "NanochatTestingStrategy expects a NanochatStreamingDataset instance." 
+ ) + + model.eval() + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for inputs, targets in testset: + inputs = inputs.to(context.device) + targets = targets.to(context.device) + loss = model(inputs, targets=targets, loss_reduction=self.reduction) + total_loss += float(loss.item()) + total_tokens += targets.numel() + + model.train() + if total_tokens == 0: + return float("nan") + return total_loss / total_tokens + + +class NanochatCoreTestingStrategy(TestingStrategy): + """Evaluate the CORE benchmark and return the aggregate metric.""" + + def __init__(self, bundle_dir: str | None = None, max_per_task: int = -1): + self.bundle_dir = bundle_dir + self.max_per_task = max_per_task + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + device = context.device or next(model.parameters()).device + tokenizer = getattr(model, "nanochat_tokenizer", None) + results = run_core_evaluation( + model, + tokenizer=tokenizer, + bundle_dir=self.bundle_dir, + max_per_task=self.max_per_task, + device=device, + ) + context.state["nanochat_core_results"] = results + return float(results["core_metric"]) + + +class Trainer(ComposableTrainer): + """Composable trainer specialised for Nanochat workloads.""" + + def __init__( + self, + model=None, + callbacks=None, + *, + optimizer_params: dict[str, Any] | None = None, + loss_reduction: str = "mean", + ): + try: + ensure_nanochat_importable() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat trainer requires the external/nanochat submodule. " + "Run `git submodule update --init --recursive`." + ) from exc + + optimizer_strategy = NanochatOptimizerStrategy( + **(optimizer_params or {}), + ) + training_step_strategy = NanochatTrainingStepStrategy( + loss_reduction=loss_reduction + ) + data_loader_strategy = NanochatDataLoaderStrategy() + + evaluation_cfg = getattr(Config(), "evaluation", None) + evaluation_type = ( + getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + ) + if evaluation_type == "nanochat_core": + max_per_task = getattr(evaluation_cfg, "max_per_task", -1) + max_per_task_value = -1 if max_per_task is None else int(max_per_task) + testing_strategy = NanochatCoreTestingStrategy( + bundle_dir=getattr(evaluation_cfg, "bundle_dir", None), + max_per_task=max_per_task_value, + ) + else: + testing_strategy = NanochatTestingStrategy() + + super().__init__( + model=model, + callbacks=callbacks, + loss_strategy=None, + optimizer_strategy=optimizer_strategy, + training_step_strategy=training_step_strategy, + lr_scheduler_strategy=None, + model_update_strategy=None, + data_loader_strategy=data_loader_strategy, + testing_strategy=testing_strategy, + ) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index ff0db0cb9..ef47a287e 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -12,6 +12,9 @@ gan, split_learning, ) +from plato.trainers import ( + nanochat as nanochat_trainer, +) registered_trainers = { "composable": composable.ComposableTrainer, @@ -19,6 +22,7 @@ "timm_basic": basic.TrainerWithTimmScheduler, "gan": gan.Trainer, "split_learning": split_learning.Trainer, + "nanochat": nanochat_trainer.Trainer, } diff --git a/plato/utils/third_party.py b/plato/utils/third_party.py new file mode 100644 index 000000000..890764e13 --- /dev/null +++ b/plato/utils/third_party.py @@ -0,0 +1,42 @@ +""" +Helpers for accessing vendored third-party 
+
+from __future__ import annotations
+
+import sys
+from functools import lru_cache
+from pathlib import Path
+
+
+class ThirdPartyImportError(ImportError):
+    """Raised when a vendored third-party project is unavailable."""
+
+
+@lru_cache(maxsize=None)
+def _nanochat_root() -> Path:
+    """Return the root directory of the Nanochat submodule."""
+    repo_root = Path(__file__).resolve().parents[2]
+    nanochat_root = repo_root / "external" / "nanochat"
+    if not nanochat_root.exists():
+        raise ThirdPartyImportError(
+            "Nanochat submodule missing. Run `git submodule update --init --recursive`."
+        )
+    return nanochat_root
+
+
+def ensure_nanochat_importable() -> Path:
+    """
+    Ensure the vendored Nanochat package is importable.
+
+    Returns:
+        Path to the Nanochat project root.
+
+    Raises:
+        ThirdPartyImportError: If the vendored Nanochat tree is missing.
+    """
+    nanochat_root = _nanochat_root()
+    path_str = str(nanochat_root)
+    if path_str not in sys.path:
+        sys.path.insert(0, path_str)
+    return nanochat_root
diff --git a/pyproject.toml b/pyproject.toml
index 099d06475..228545350 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,18 @@ dp = [
 mlx = [
     "mlx",
 ]
+nanochat = [
+    "datasets",
+    "numpy",
+    "psutil",
+    "regex",
+    "tiktoken",
+    "tokenizers",
+    "torch",
+    "wandb",
+    "jinja2",
+    "PyYAML",
+]
 
 [project.urls]
 Homepage = "https://github.com/TL-System/plato"
@@ -89,11 +101,13 @@
 members = [
     "examples/detector",
     "examples/gradient_leakage_attacks",
     "tools",
+    "examples/nanochat",
 ]
 
 [dependency-groups]
 dev = [
     "pytest",
+    "ruff",
     "ty",
 ]
diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py
index c22b0875c..2b1bb9b75 100644
--- a/tests/test_config_loader.py
+++ b/tests/test_config_loader.py
@@ -157,3 +157,54 @@
     if hasattr(Config, "args"):
         delattr(Config, "args")
     Config._cli_overrides = {}
+
+
+def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch):
+    """Test that [evaluation] configuration is properly loaded."""
+    config_base = tmp_path / "runtime"
+    config_path = tmp_path / "config.toml"
+
+    config_data = {
+        "clients": {"type": "simple", "total_clients": 2, "per_round": 1},
+        "server": {"address": "127.0.0.1", "port": 8000},
+        "data": {"datasource": "MNIST"},
+        "trainer": {"type": "basic", "rounds": 1, "epochs": 1, "batch_size": 10},
+        "algorithm": {"type": "fedavg"},
+        "evaluation": {
+            "type": "nanochat_core",
+            "max_per_task": 128,
+            "bundle_dir": "/custom/path/to/nanochat",
+        },
+    }
+
+    toml_writer.dump(config_data, config_path)
+
+    monkeypatch.delenv("config_file", raising=False)
+    monkeypatch.setattr(
+        sys,
+        "argv",
+        [
+            sys.argv[0],
+            "--config",
+            str(config_path),
+            "--base",
+            str(config_base),
+        ],
+    )
+
+    Config._instance = None
+    if hasattr(Config, "args"):
+        delattr(Config, "args")
+    Config._cli_overrides = {}
+
+    config = Config()
+
+    assert hasattr(config, "evaluation")
+    assert config.evaluation.type == "nanochat_core"
+    assert config.evaluation.max_per_task == 128
+    assert config.evaluation.bundle_dir == "/custom/path/to/nanochat"
+
+    Config._instance = None
+    if hasattr(Config, "args"):
+        delattr(Config, "args")
+    Config._cli_overrides = {}
diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py
new file mode 100644
index 000000000..6a9c174e8
--- /dev/null
+++ b/tests/test_nanochat_integration.py
@@ -0,0 +1,162 @@
+"""Nanochat integration smoke checks (optional, require nanochat extras)."""
+
+from __future__ import annotations
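+
+# Every check below degrades to a skip rather than a failure when an optional
+# dependency (rustbpe, tiktoken, torch, or the nanochat submodule) is missing,
+# so the default suite stays green without the nanochat extras installed.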
+
+import importlib.util
+import os
+
+import pytest
+
+from plato.config import Config, ConfigNode
+from plato.utils.third_party import ensure_nanochat_importable
+
+pytestmark = pytest.mark.integration
+
+_RUSTBPE_AVAILABLE = importlib.util.find_spec("rustbpe") is not None
+
+
+@pytest.mark.skipif(
+    not _RUSTBPE_AVAILABLE,
+    reason="Nanochat tokenizer tests require rustbpe extension (install nanochat extras).",
+)
+def test_nanochat_tokenizer_processor_round_trip(tmp_path):
+    """Train a tiny tokenizer via rustbpe and encode a sample string."""
+    pytest.importorskip(
+        "tiktoken", reason="Nanochat tokenizer tests require tiktoken (nanochat extra)."
+    )
+    from plato.processors.nanochat_tokenizer import NanochatTokenizerProcessor
+
+    corpus = ["hello nanochat", "minimal tokenizer exercise"]
+    processor = NanochatTokenizerProcessor(
+        train_corpus=corpus,
+        vocab_size=300,
+        prepend_bos=False,
+    )
+    encoded = processor.process("hello nanochat")
+    assert isinstance(encoded, list)
+    assert len(encoded) > 0
+
+
+def test_nanochat_trainer_smoke(temp_config, tmp_path):
+    """Run one training step with synthetic Nanochat data on CPU."""
+    _ = pytest.importorskip(
+        "torch", reason="Nanochat trainer smoke requires torch (nanochat extra)."
+    )
+    ensure_nanochat_importable()
+    _ = pytest.importorskip(
+        "nanochat",
+        reason="Nanochat trainer smoke requires the nanochat package (nanochat extra).",
+    )
+    from plato.datasources.nanochat import DataSource as NanochatDataSource
+    from plato.models.nanochat import Model as NanochatModel
+    from plato.trainers.nanochat import Trainer as NanochatTrainer
+
+    cfg = Config()
+    cfg.trainer.type = "nanochat"
+    cfg.trainer.model_name = "nanochat_smoke"
+    cfg.trainer.batch_size = 2
+    cfg.trainer.rounds = 1
+    cfg.trainer.epochs = 1
+
+    cfg.parameters.model = {
+        "sequence_len": 16,
+        "vocab_size": 512,
+        "n_layer": 1,
+        "n_head": 2,
+        "n_kv_head": 2,
+        "n_embd": 128,
+    }
+
+    datasource = NanochatDataSource(
+        batch_size=cfg.trainer.batch_size,
+        sequence_length=cfg.parameters.model["sequence_len"],
+        mode="synthetic",
+        max_train_batches=2,
+        max_val_batches=1,
+        device="cpu",
+        vocab_size=cfg.parameters.model["vocab_size"],
+        synthetic_seed=123,
+    )
+
+    model = NanochatModel.get(
+        sequence_len=cfg.parameters.model["sequence_len"],
+        vocab_size=cfg.parameters.model["vocab_size"],
+        n_layer=cfg.parameters.model["n_layer"],
+        n_head=cfg.parameters.model["n_head"],
+        n_kv_head=cfg.parameters.model["n_kv_head"],
+        n_embd=cfg.parameters.model["n_embd"],
+        init_weights=True,
+    )
+
+    trainer = NanochatTrainer(model=model)
+    trainset = datasource.get_train_set()
+    elapsed = trainer.train(trainset, sampler=None)
+
+    assert isinstance(elapsed, float)
+    assert elapsed >= 0.0
+
+    model_dir = Config().params["model_path"]
+    checkpoint_name = f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors"
+    assert os.path.exists(os.path.join(model_dir, checkpoint_name))
+
+
+def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch):
+    """Ensure evaluation config triggers the CORE testing strategy."""
+    _ = pytest.importorskip(
+        "torch", reason="Nanochat trainer requires torch (nanochat extra)."
+    )
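+    # run_core_evaluation is monkeypatched to a stub below, so this test pins
+    # down only the strategy-selection wiring, not the CORE benchmark itself.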
+    ensure_nanochat_importable()
+    _ = pytest.importorskip(
+        "nanochat",
+        reason="Nanochat trainer requires the nanochat package (nanochat extra).",
+    )
+    from plato.models.nanochat import Model as NanochatModel
+    from plato.trainers.nanochat import (
+        NanochatCoreTestingStrategy,
+        Trainer as NanochatTrainer,
+    )
+
+    monkeypatch.setattr(
+        "plato.evaluators.nanochat_core.run_core_evaluation",
+        lambda *args, **kwargs: {
+            "results": {},
+            "centered_results": {},
+            "core_metric": 0.0,
+        },
+    )
+
+    cfg = Config()
+    cfg.trainer.type = "nanochat"
+    cfg.trainer.model_name = "nanochat_core"
+    cfg.trainer.batch_size = 1
+    cfg.trainer.rounds = 1
+    cfg.trainer.epochs = 1
+    cfg.parameters.model = {
+        "sequence_len": 16,
+        "vocab_size": 512,
+        "n_layer": 1,
+        "n_head": 2,
+        "n_kv_head": 2,
+        "n_embd": 128,
+    }
+    cfg.evaluation = ConfigNode.from_object(
+        {
+            "type": "nanochat_core",
+            "max_per_task": 1,
+        }
+    )
+
+    model = NanochatModel.get(
+        sequence_len=cfg.parameters.model["sequence_len"],
+        vocab_size=cfg.parameters.model["vocab_size"],
+        n_layer=cfg.parameters.model["n_layer"],
+        n_head=cfg.parameters.model["n_head"],
+        n_kv_head=cfg.parameters.model["n_kv_head"],
+        n_embd=cfg.parameters.model["n_embd"],
+        init_weights=True,
+    )
+
+    trainer = NanochatTrainer(model=model)
+    assert isinstance(trainer.testing_strategy, NanochatCoreTestingStrategy)
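+
+
+# A minimal end-to-end sketch, assuming the usual checkout steps (the runner
+# invocation below is illustrative; the --config/--base flags mirror those
+# exercised in test_config_loader.py above):
+#
+#   git submodule update --init --recursive
+#   pip install -e ".[nanochat]"
+#   ./run --config configs/Nanochat/synthetic_micro.toml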