From 3355698fb0065472444ecdbfcce5a7d22515cf1f Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 28 Oct 2025 08:51:05 -0400 Subject: [PATCH 01/18] Initial support for the Nanochat model and its evaluation benchmark (core_eval.py). --- configs/Nanochat/synthetic_micro.toml | 49 +++++ docs/nanochat_integration_checklist.md | 55 +++++ docs/third_party.md | 19 ++ examples/nanochat/README.md | 49 +++++ examples/nanochat/pyproject.toml | 7 + plato/datasources/nanochat.py | 262 +++++++++++++++++++++++ plato/datasources/registry.py | 2 + plato/evaluators/__init__.py | 3 + plato/evaluators/nanochat_core.py | 217 +++++++++++++++++++ plato/models/nanochat.py | 135 ++++++++++++ plato/models/registry.py | 2 + plato/processors/nanochat_tokenizer.py | 133 ++++++++++++ plato/trainers/nanochat.py | 285 +++++++++++++++++++++++++ plato/trainers/registry.py | 2 + plato/utils/third_party.py | 42 ++++ pyproject.toml | 13 ++ tests/test_nanochat_integration.py | 152 +++++++++++++ 17 files changed, 1427 insertions(+) create mode 100644 configs/Nanochat/synthetic_micro.toml create mode 100644 docs/nanochat_integration_checklist.md create mode 100644 docs/third_party.md create mode 100644 examples/nanochat/README.md create mode 100644 examples/nanochat/pyproject.toml create mode 100644 plato/datasources/nanochat.py create mode 100644 plato/evaluators/__init__.py create mode 100644 plato/evaluators/nanochat_core.py create mode 100644 plato/models/nanochat.py create mode 100644 plato/processors/nanochat_tokenizer.py create mode 100644 plato/trainers/nanochat.py create mode 100644 plato/utils/third_party.py create mode 100644 tests/test_nanochat_integration.py diff --git a/configs/Nanochat/synthetic_micro.toml b/configs/Nanochat/synthetic_micro.toml new file mode 100644 index 000000000..b22941f53 --- /dev/null +++ b/configs/Nanochat/synthetic_micro.toml @@ -0,0 +1,49 @@ +[clients] + +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/synthetic" +model_path = "models/nanochat/synthetic" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "synthetic" +max_train_batches = 4 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 64 +device = "cpu" +vocab_size = 512 +synthetic_seed = 123 + +[trainer] +type = "nanochat" +rounds = 1 +epochs = 1 +batch_size = 2 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 128 +vocab_size = 512 +n_layer = 2 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, train_loss" diff --git a/docs/nanochat_integration_checklist.md b/docs/nanochat_integration_checklist.md new file mode 100644 index 000000000..5fa411526 --- /dev/null +++ b/docs/nanochat_integration_checklist.md @@ -0,0 +1,55 @@ +# Nanochat Integration Checklist + +This checklist coordinates the work to incorporate the Nanochat stack into Plato. Owners are placeholder roles until specific engineers are assigned. + +## Third-Party Snapshot +- **Owner:** Infrastructure +- **Deliverables:** Vendor `runtime/third_party/nanochat` snapshot with commit hash in docs; add provenance blurb to `docs/third_party.md`. +- **Dependencies:** None. + +## Model Registry +- **Owner:** Modeling +- **Deliverables:** Implement `plato/models/nanochat.py` mirroring Nanochat GPT config, register entry in `plato/models/registry.py`, supply weight-loading utilities. +- **Status:** In progress – factory module and registry wiring landed. +- **Dependencies:** Third-party snapshot. + +## Tokenizer & Processor +- **Owner:** Infrastructure +- **Deliverables:** Package Rust BPE via Maturin optional extra; wrap as `plato/processors/nanochat_tokenizer.py` with lazy import and fallbacks; document build steps in README. +- **Status:** Prototype processor and optional dependency group landed; CI build integration remains TODO. +- **Dependencies:** Third-party snapshot, build tooling prototype. + +## Datasource +- **Owner:** Data +- **Deliverables:** Create `plato/datasources/nanochat.py` handling dataset acquisition and sharding; register in datasource registry; store license metadata. +- **Status:** In progress – streaming dataset with synthetic fallback available. +- **Dependencies:** Tokenizer availability. + +## Trainer & Algorithm +- **Owner:** Training +- **Deliverables:** Port Nanochat engine into `plato/trainers/nanochat.py`; add algorithm glue if federated coordination diverges; ensure checkpoint compatibility. +- **Status:** In progress – composable trainer wrapper with Nanochat-specific optimiser/loader strategies in place. +- **Dependencies:** Model registry entry, datasource. + +## Evaluation Strategy +- **Owner:** Evaluation +- **Deliverables:** Translate `nanochat/core_eval.py` into reusable evaluator hooked into Plato testing strategy; add pytest coverage with synthetic data. +- **Status:** CORE evaluation adapter hooked into trainer testing strategy; follow-up coverage to use real eval bundles outstanding. +- **Dependencies:** Model, tokenizer. + +## Configuration & Examples +- **Owner:** Product +- **Deliverables:** Author `configs/Nanochat/*.toml` scenarios and `examples/nanochat/` workspace; include reference scripts and documentation. +- **Status:** Synthetic micro config and workspace README published; larger-scale scenarios pending. +- **Dependencies:** Model, datasource, trainer. + +## Documentation & Release +- **Owner:** Docs +- **Deliverables:** Publish `docs/models/nanochat.md`, extend root README tables, add integration notes and changelog entry; outline hardware requirements. +- **Dependencies:** All prior tracks. + +## Validation +- **Owner:** QA +- **Deliverables:** Expand CI to compile tokenizer, run smoke train/eval, and enforce import order checks; record expected metrics in evaluation baselines. +- **Status:** Initial pytest smoke checks for tokenizer/trainer added; CI enablement still pending. +- **Dependencies:** Evaluation strategy, trainer. diff --git a/docs/third_party.md b/docs/third_party.md new file mode 100644 index 000000000..304d62b4e --- /dev/null +++ b/docs/third_party.md @@ -0,0 +1,19 @@ +# Third-Party Assets + +This page records external projects that are vendored into the Plato repository to support specific integrations. Please update the relevant entry whenever the upstream source, commit hash, or licensing information changes. + +## Nanochat +- **Upstream:** [karpathy/nanochat](https://github.com/karpathy/nanochat) +- **Vendored location:** `runtime/third_party/nanochat` +- **Snapshot commit:** `c75fe54aa7c1fa881701c246f9427bcbe4eee5a4` (captured 2025-03-04) +- **License:** MIT (included in `runtime/third_party/nanochat/LICENSE`) + +### Updating the Snapshot +1. `cd runtime/third_party/nanochat` +2. `git fetch origin && git checkout ` +3. Review upstream changes and confirm compatibility with Plato. +4. Record the new commit hash and date in this document, and call out notable changes in the integration checklist before landing. + +### Notes +- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./runtime/third_party/nanochat maturin develop --release` succeeds before pushing updates. +- Keep the vendored tree free of local modifications unless backporting fixes; prefer upstream contributions when feasible. diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md new file mode 100644 index 000000000..5b772c855 --- /dev/null +++ b/examples/nanochat/README.md @@ -0,0 +1,49 @@ +# Nanochat Integration Workspace + +This workspace hosts Nanochat-focused experiments within Plato. + +## Quick Start + +1. Install dependencies (including the vendored tokenizer build requirements): + + ```bash + uv sync --extra nanochat + uv run --with ./runtime/third_party/nanochat maturin develop --release + ``` + +2. Run the synthetic smoke configuration: + + ```bash + uv run --extra nanochat python plato.py --config configs/Nanochat/synthetic_micro.toml + ``` + + This launches a single-client training round using the Nanochat trainer, synthetic + token streams, and a downsized GPT configuration for CPU debugging. + +## CORE Evaluation + +The Nanochat trainer can invoke the upstream CORE benchmark by adding the section +below to your TOML configuration: + +```toml +[evaluation] +type = "nanochat_core" +max_per_task = 128 # optional; limits evaluation samples per task +# bundle_dir = "/custom/path/to/nanochat" # defaults to ~/.cache/nanochat +``` + +Make sure the official evaluation bundle has been downloaded so the following files +exist (the default location is `~/.cache/nanochat/eval_bundle`): + +- `core.yaml` +- `eval_data/*.jsonl` +- `eval_meta_data.csv` + +The provided `configs/Nanochat/synthetic_micro.toml` can be extended with the +`[evaluation]` block once those assets are present. + +## Roadmap + +- Integrate real Nanochat tokenized datasets and publish download helpers. +- Add baseline evaluation scripts leveraging `nanochat/core_eval.py`. +- Capture reproducible metrics and hardware notes for larger-scale runs. diff --git a/examples/nanochat/pyproject.toml b/examples/nanochat/pyproject.toml new file mode 100644 index 000000000..c0e19f09c --- /dev/null +++ b/examples/nanochat/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "plato-nanochat-examples" +version = "0.1.0" +description = "Nanochat integration examples for Plato." +readme = "README.md" +requires-python = ">=3.10" +dependencies = [] diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py new file mode 100644 index 000000000..28c805943 --- /dev/null +++ b/plato/datasources/nanochat.py @@ -0,0 +1,262 @@ +""" +Streaming datasource backed by the vendored Nanochat project. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +try: + import torch + from torch.utils.data import IterableDataset +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat datasource requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.config import Config +from plato.datasources.base import DataSource as BaseDataSource +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +DEFAULT_VOCAB_SIZE = 50304 +DEFAULT_SEQUENCE_LENGTH = 2048 + + +def _resolve_base_dir(base_dir: str | Path | None) -> Path | None: + if base_dir is None: + return None + return Path(base_dir).expanduser().resolve() + + +def _parquet_available(base_dir: Path | None) -> bool: + try: + ensure_nanochat_importable() + from nanochat.dataset import list_parquet_files, DATA_DIR # type: ignore[attr-defined] + except (ThirdPartyImportError, ImportError): # pragma: no cover - defensive + return False + + if base_dir is not None: + candidate_dir = base_dir / "base_data" + if not candidate_dir.exists(): + return False + parquet_dir = candidate_dir + else: + parquet_dir = Path(DATA_DIR) + + try: + return len(list_parquet_files(str(parquet_dir))) > 0 + except FileNotFoundError: + return False + + +@dataclass +class _SyntheticState: + generator: torch.Generator + + +class NanochatStreamingDataset(IterableDataset): + """Iterable dataset yielding (inputs, targets) token tensors.""" + + def __init__( + self, + *, + split: str, + batch_size: int, + sequence_length: int, + mode: str, + base_dir: Path | None, + max_batches: int | None, + tokenizer_threads: int, + tokenizer_batch_size: int, + device: str, + vocab_size: int, + synthetic_seed: int, + ): + super().__init__() + if split not in {"train", "val"}: + raise ValueError("split must be 'train' or 'val'.") + + if mode not in {"auto", "parquet", "synthetic"}: + raise ValueError("mode must be 'auto', 'parquet', or 'synthetic'.") + + self.split = split + self.batch_size = batch_size + self.sequence_length = sequence_length + self.base_dir = base_dir + self.max_batches = max_batches + self.tokenizer_threads = tokenizer_threads + self.tokenizer_batch_size = tokenizer_batch_size + self.device = device + self.vocab_size = vocab_size + self.synthetic_seed = synthetic_seed + + resolved_mode = mode + if resolved_mode == "auto": + resolved_mode = "parquet" if _parquet_available(base_dir) else "synthetic" + self.mode = resolved_mode + self._synthetic_state: _SyntheticState | None = None + + # Configure Nanochat's base directory if provided. + if self.base_dir is not None: + os.environ.setdefault("NANOCHAT_BASE_DIR", str(self.base_dir)) + + def _synthetic_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + if self._synthetic_state is None: + generator = torch.Generator() + generator.manual_seed(self.synthetic_seed) + self._synthetic_state = _SyntheticState(generator=generator) + + generator = self._synthetic_state.generator + while True: + tokens = torch.randint( + low=0, + high=self.vocab_size, + size=(self.batch_size, self.sequence_length + 1), + dtype=torch.long, + generator=generator, + ) + inputs = tokens[:, :-1].to(dtype=torch.long) + targets = tokens[:, 1:].to(dtype=torch.long) + yield inputs, targets + + def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + ensure_nanochat_importable() + from nanochat.dataloader import ( # type: ignore[attr-defined] + tokenizing_distributed_data_loader, + ) + + loader = tokenizing_distributed_data_loader( + self.batch_size, + self.sequence_length, + split=self.split, + tokenizer_threads=self.tokenizer_threads, + tokenizer_batch_size=self.tokenizer_batch_size, + device=self.device, + ) + for inputs, targets in loader: + yield inputs.to(dtype=torch.long), targets.to(dtype=torch.long) + + def __iter__(self): + iterable = ( + self._parquet_iterable() + if self.mode == "parquet" + else self._synthetic_iterable() + ) + + for batch_index, batch in enumerate(iterable): + if self.max_batches is not None and batch_index >= self.max_batches: + break + yield batch + + def __len__(self) -> int: + if self.max_batches is None: + raise TypeError("Streaming dataset does not have a finite length.") + return self.max_batches + + +class DataSource(BaseDataSource): # type: ignore[misc] + """Plato datasource exposing Nanochat token streams.""" + + def __init__( + self, + *, + batch_size: int | None = None, + sequence_length: int | None = None, + mode: str = "auto", + base_dir: str | Path | None = None, + max_train_batches: int | None = 64, + max_val_batches: int | None = 8, + tokenizer_threads: int = 4, + tokenizer_batch_size: int = 128, + device: str = "cpu", + vocab_size: int = DEFAULT_VOCAB_SIZE, + synthetic_seed: int = 42, + ): + super().__init__() + + cfg_data = getattr(Config(), "data", None) + if cfg_data is not None: + mode = getattr(cfg_data, "mode", mode) + base_dir = getattr(cfg_data, "base_dir", base_dir) + max_train_batches = getattr(cfg_data, "max_train_batches", max_train_batches) + max_val_batches = getattr(cfg_data, "max_val_batches", max_val_batches) + tokenizer_threads = getattr(cfg_data, "tokenizer_threads", tokenizer_threads) + tokenizer_batch_size = getattr( + cfg_data, "tokenizer_batch_size", tokenizer_batch_size + ) + device = getattr(cfg_data, "device", device) + vocab_size = getattr(cfg_data, "vocab_size", vocab_size) + synthetic_seed = getattr(cfg_data, "synthetic_seed", synthetic_seed) + + config = getattr(Config(), "parameters", None) + model_conf = getattr(config, "model", None) + default_seq_len = DEFAULT_SEQUENCE_LENGTH + if model_conf is not None and hasattr(model_conf, "_asdict"): + seq_len_candidate = model_conf._asdict().get("sequence_len") + if isinstance(seq_len_candidate, int) and seq_len_candidate > 0: + default_seq_len = seq_len_candidate + + resolved_sequence_len = sequence_length or default_seq_len + resolved_batch_size = batch_size or getattr( + getattr(Config(), "trainer", None), "batch_size", 1 + ) + + resolved_base_dir = _resolve_base_dir(base_dir) + dataset_mode = mode + if dataset_mode == "auto": + dataset_mode = "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + + self.trainset = NanochatStreamingDataset( + split="train", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_train_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed, + ) + self.testset = NanochatStreamingDataset( + split="val", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_val_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed + 1, + ) + self.sequence_length = resolved_sequence_len + self.batch_size = resolved_batch_size + self.mode = dataset_mode + + @staticmethod + def input_shape(): + """Return the default input shape (sequence length).""" + return (DEFAULT_SEQUENCE_LENGTH,) + + def num_train_examples(self) -> int: + dataset = self.trainset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_train_batches to report size." + ) + return dataset.max_batches * dataset.batch_size + + def num_test_examples(self) -> int: + dataset = self.testset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_val_batches to report size." + ) + return dataset.max_batches * dataset.batch_size diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index e2f101de3..1334c027c 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -11,6 +11,7 @@ feature, femnist, huggingface, + nanochat, lora, purchase, texas, @@ -27,6 +28,7 @@ "Texas": texas, "TinyImageNet": tiny_imagenet, "Feature": feature, + "Nanochat": nanochat, } registered_partitioned_datasources = {"FEMNIST": femnist} diff --git a/plato/evaluators/__init__.py b/plato/evaluators/__init__.py new file mode 100644 index 000000000..62d84e994 --- /dev/null +++ b/plato/evaluators/__init__.py @@ -0,0 +1,3 @@ +"""Evaluation helpers for Plato integrations.""" + +# Intentionally empty for now; modules register themselves when imported. diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py new file mode 100644 index 000000000..74a7c1447 --- /dev/null +++ b/plato/evaluators/nanochat_core.py @@ -0,0 +1,217 @@ +""" +Adapter utilities to run Nanochat's CORE evaluation benchmark within Plato. +""" + +from __future__ import annotations + +import csv +import json +import logging +import random +import time +from pathlib import Path +from typing import Any + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat CORE evaluation requires PyTorch. " + "Install the `nanochat` extra (includes torch)." + ) from exc + +import yaml + +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +LOGGER = logging.getLogger(__name__) + + +def _resolve_bundle_paths(bundle_dir: str | Path | None) -> tuple[Path, Path, Path]: + """Resolve the configuration, metadata, and dataset paths for CORE evaluation.""" + ensure_nanochat_importable() + from nanochat.common import get_base_dir # pylint: disable=import-error + + if bundle_dir is None: + base_path = Path(get_base_dir()) + else: + base_path = Path(bundle_dir).expanduser().resolve() + + eval_bundle_dir = base_path / "eval_bundle" + config_path = eval_bundle_dir / "core.yaml" + data_dir = eval_bundle_dir / "eval_data" + metadata_path = eval_bundle_dir / "eval_meta_data.csv" + + if not config_path.exists(): + raise FileNotFoundError( + f"CORE evaluation config not found at {config_path}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not data_dir.exists(): + raise FileNotFoundError( + f"CORE evaluation data directory not found at {data_dir}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not metadata_path.exists(): + raise FileNotFoundError( + f"CORE evaluation metadata CSV not found at {metadata_path}." + ) + + return config_path, data_dir, metadata_path + + +def _load_core_tasks(config_path: Path) -> list[dict[str, Any]]: + """Load task definitions from the CORE YAML config.""" + with config_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) + tasks = config.get("icl_tasks", []) + if not isinstance(tasks, list) or not tasks: + raise ValueError( + f"No CORE tasks defined in {config_path}. Inspect the eval bundle." + ) + return tasks + + +def _load_metadata(metadata_path: Path) -> dict[str, float]: + """Load random baseline metadata for centering accuracy.""" + baseline_map: dict[str, float] = {} + with metadata_path.open("r", encoding="utf-8") as handle: + reader = csv.DictReader(handle) + for row in reader: + label = row.get("Eval Task") + baseline = row.get("Random baseline") + if label is None or baseline is None: + continue + try: + baseline_map[label] = float(baseline) + except ValueError: + LOGGER.debug("Skipping malformed baseline row: %s", row) + if not baseline_map: + raise ValueError( + f"Random baselines missing in {metadata_path}. Required for CORE metric." + ) + return baseline_map + + +def _load_task_data(data_dir: Path, dataset_uri: str) -> list[dict[str, Any]]: + """Load task dataset rows from newline-delimited JSON.""" + path = data_dir / dataset_uri + if not path.exists(): + raise FileNotFoundError( + f"CORE dataset shard '{dataset_uri}' missing under {data_dir}." + ) + with path.open("r", encoding="utf-8") as handle: + return [json.loads(line.strip()) for line in handle if line.strip()] + + +def _resolve_tokenizer(model) -> Any: + """Obtain a tokenizer compatible with Nanochat core evaluation.""" + tokenizer = getattr(model, "nanochat_tokenizer", None) + if tokenizer is not None: + return tokenizer + + ensure_nanochat_importable() + from nanochat.tokenizer import get_tokenizer # pylint: disable=import-error + + return get_tokenizer() + + +def run_core_evaluation( + model: torch.nn.Module, + *, + tokenizer: Any | None = None, + bundle_dir: str | Path | None = None, + max_per_task: int = -1, + device: torch.device | str | None = None, +) -> dict[str, Any]: + """ + Execute the CORE benchmark for the provided model. + + Args: + model: Nanochat-style autoregressive model. + tokenizer: Optional tokenizer; falls back to nanochat.tokenizer.get_tokenizer(). + bundle_dir: Optional base directory containing `eval_bundle/`. + max_per_task: Optional cap on examples per task for quicker smoke tests (-1 = all). + device: Device to run evaluation on. Defaults to the model's current device. + + Returns: + Dictionary with `results`, `centered_results`, and `core_metric`. + """ + ensure_nanochat_importable() + from nanochat.core_eval import evaluate_task # pylint: disable=import-error + + config_path, data_dir, metadata_path = _resolve_bundle_paths(bundle_dir) + tasks = _load_core_tasks(config_path) + baselines = _load_metadata(metadata_path) + + eval_tokenizer = tokenizer or _resolve_tokenizer(model) + if eval_tokenizer is None: + raise RuntimeError( + "Nanochat CORE evaluation requires a tokenizer. " + "Either attach `model.nanochat_tokenizer` or provide one explicitly." + ) + + if device is None: + try: + first_param = next(model.parameters()) + device = first_param.device + except StopIteration: + device = torch.device("cpu") + if isinstance(device, str): + device = torch.device(device) + + model_device = device + model_was_training = model.training + model = model.to(model_device) + model.eval() + + results: dict[str, float] = {} + centered_results: dict[str, float] = {} + + for task in tasks: + label = task.get("label") + if not label: + LOGGER.debug("Skipping unnamed CORE task entry: %s", task) + continue + + task_meta = { + "task_type": task.get("icl_task_type"), + "dataset_uri": task.get("dataset_uri"), + "num_fewshot": task.get("num_fewshot", [0])[0], + "continuation_delimiter": task.get("continuation_delimiter", " "), + } + start_time = time.perf_counter() + + data = _load_task_data(data_dir, task_meta["dataset_uri"]) + shuffle_rng = random.Random(1337) + shuffle_rng.shuffle(data) + if max_per_task > 0: + data = data[:max_per_task] + + accuracy = evaluate_task(model, eval_tokenizer, data, model_device, task_meta) + baseline = baselines.get(label, 0.0) + centered = (accuracy - 0.01 * baseline) / (1.0 - 0.01 * baseline) + + results[label] = accuracy + centered_results[label] = centered + elapsed = time.perf_counter() - start_time + LOGGER.info( + "CORE task %s | accuracy %.4f | centered %.4f | %.2fs", + label, + accuracy, + centered, + elapsed, + ) + + if model_was_training: + model.train() + + if not centered_results: + raise RuntimeError("No CORE tasks were evaluated; check the eval bundle.") + + core_metric = sum(centered_results.values()) / len(centered_results) + return { + "results": results, + "centered_results": centered_results, + "core_metric": core_metric, + } diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py new file mode 100644 index 000000000..d2a4b5dc2 --- /dev/null +++ b/plato/models/nanochat.py @@ -0,0 +1,135 @@ +""" +Factory for Nanochat GPT models integrated with Plato's registry. +""" + +from __future__ import annotations + +from dataclasses import fields +from typing import Any + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat model integration requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + + +DEFAULT_MODEL_CONFIG: dict[str, int] = { + "sequence_len": 2048, + "vocab_size": 50304, + "n_layer": 12, + "n_head": 6, + "n_kv_head": 6, + "n_embd": 768, +} + + +def _import_nanochat_modules(): + ensure_nanochat_importable() + from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] + from nanochat.checkpoint_manager import ( # type: ignore[attr-defined] + load_model_from_dir, + ) + + return GPT, GPTConfig, load_model_from_dir + + +def _sanitize_kwargs(kwargs: dict[str, Any], valid_fields: set[str]) -> dict[str, Any]: + """Filter kwargs to those accepted by GPTConfig.""" + config_kwargs = DEFAULT_MODEL_CONFIG.copy() + for key, value in kwargs.items(): + if key in valid_fields: + config_kwargs[key] = value + return config_kwargs + + +def _load_from_checkpoint( + load_dir: str, + *, + device: str | torch.device = "cpu", + phase: str = "train", + model_tag: str | None = None, + step: int | None = None, +): + """Load a Nanochat checkpoint via checkpoint_manager.""" + GPT, _, load_model_from_dir = _import_nanochat_modules() + torch_device = torch.device(device) + model, tokenizer, metadata = load_model_from_dir( + load_dir, + device=torch_device, + phase=phase, + model_tag=model_tag, + step=step, + ) + # Attach helpful metadata to the model for downstream use. + setattr(model, "nanochat_tokenizer", tokenizer) + setattr(model, "nanochat_metadata", metadata) + if not isinstance(model, GPT): + raise TypeError( + "Checkpoint loader returned an unexpected model type. " + "Ensure the checkpoint directory points to Nanochat artifacts." + ) + return model + + +class Model: + """Nanochat GPT factory compatible with Plato's model registry.""" + + @staticmethod + def get(model_name: str | None = None, **kwargs: Any): + """ + Instantiate a Nanochat GPT model. + + Keyword Args: + sequence_len: Context length (tokens). + vocab_size: Token vocabulary size. + n_layer: Number of transformer blocks. + n_head: Attention heads for queries. + n_kv_head: Attention heads for keys/values (MQA/GQA). + n_embd: Hidden dimension width. + init_weights: Whether to run Nanochat's weight initialisation (default True). + load_checkpoint_dir: Optional checkpoint directory produced by Nanochat. + load_checkpoint_tag: Optional subdirectory/model tag within checkpoint dir. + load_checkpoint_step: Optional numeric step to load (defaults to latest). + device: Torch device string for checkpoint loading. + phase: "train" or "eval" when loading checkpoints. + """ + try: + GPT, GPTConfig, _ = _import_nanochat_modules() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat vendor tree not found. " + "Ensure runtime/third_party/nanochat is available." + ) from exc + + init_weights = kwargs.pop("init_weights", True) + load_dir = kwargs.pop("load_checkpoint_dir", None) + checkpoint_tag = kwargs.pop("load_checkpoint_tag", None) + checkpoint_step = kwargs.pop("load_checkpoint_step", None) + checkpoint_phase = kwargs.pop("phase", "train") + checkpoint_device = kwargs.pop("device", "cpu") + + # GPTConfig only accepts specific fields; filter unknown kwargs. + config_fields = {field.name for field in fields(GPTConfig)} + config_kwargs = _sanitize_kwargs(kwargs, config_fields) + + if load_dir: + model = _load_from_checkpoint( + load_dir, + device=checkpoint_device, + phase=checkpoint_phase, + model_tag=checkpoint_tag, + step=checkpoint_step, + ) + return model + + config = GPTConfig(**config_kwargs) + model = GPT(config) + if init_weights: + model.init_weights() + setattr(model, "nanochat_config", config_kwargs) + return model diff --git a/plato/models/registry.py b/plato/models/registry.py index e6691e3a0..4ce573e94 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -14,6 +14,7 @@ general_multilayer, huggingface, lenet5, + nanochat, multilayer, resnet, torch_hub, @@ -40,6 +41,7 @@ "torch_hub": torch_hub.Model, "huggingface": huggingface.Model, "vit": vit.Model, + "nanochat": nanochat.Model, } registered_mlx_models = {} diff --git a/plato/processors/nanochat_tokenizer.py b/plato/processors/nanochat_tokenizer.py new file mode 100644 index 000000000..e67565f5f --- /dev/null +++ b/plato/processors/nanochat_tokenizer.py @@ -0,0 +1,133 @@ +""" +Prototype tokenizer processor for Nanochat, wrapping the rustbpe+tiktoken stack. + +This module exercises the build tooling for the Rust extension while providing +an adapter that conforms to Plato's processor interface. +""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +import pickle +from pathlib import Path +from typing import Any + +from plato.processors.base import Processor + +SPECIAL_TOKENS = [ + "<|bos|>", + "<|user_start|>", + "<|user_end|>", + "<|assistant_start|>", + "<|assistant_end|>", + "<|python_start|>", + "<|python_end|>", + "<|output_start|>", + "<|output_end|>", +] + +DEFAULT_PATTERN = ( + r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]" + r"|\s+(?!\S)|\s+" +) + + +class NanochatTokenizerProcessor(Processor): + """Prototype tokenizer that can either load a saved encoding or train via rustbpe.""" + + def __init__( + self, + tokenizer_path: str | Path | None = None, + train_corpus: Iterable[str] | None = None, + vocab_size: int = 32000, + pattern: str = DEFAULT_PATTERN, + bos_token: str = "<|bos|>", + prepend_bos: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.tokenizer_path = Path(tokenizer_path) if tokenizer_path else None + self.train_corpus = train_corpus + self.vocab_size = max(vocab_size, 256 + len(SPECIAL_TOKENS)) + self.pattern = pattern + self.bos_token = bos_token + self.prepend_bos = prepend_bos + + self._encoding = self._load_encoding() + self._special_tokens = self._infer_special_tokens() + self._bos_token_id = self._special_tokens.get(self.bos_token) + + def _load_encoding(self): + if self.tokenizer_path: + return self._load_from_pickle(self.tokenizer_path) + if self.train_corpus is not None: + return self._train_from_corpus(self.train_corpus) + raise ValueError("Either tokenizer_path or train_corpus must be provided.") + + def _load_from_pickle(self, path: Path): + with path.open("rb") as handle: + encoding = pickle.load(handle) + return encoding + + def _train_from_corpus(self, corpus: Iterable[str]): + try: + import rustbpe # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "rustbpe extension is required to train a Nanochat tokenizer." + ) from exc + + try: + import tiktoken # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "tiktoken is required to construct Nanochat tokenizer encodings." + ) from exc + + tokenizer = rustbpe.Tokenizer() + tokenizer.train_from_iterator( + iter(corpus), + self.vocab_size - len(SPECIAL_TOKENS), + pattern=self.pattern, + ) + + mergeable_ranks = { + bytes(piece): rank for piece, rank in tokenizer.get_mergeable_ranks() + } + tokens_offset = len(mergeable_ranks) + special_tokens = { + token: tokens_offset + index for index, token in enumerate(SPECIAL_TOKENS) + } + + return tiktoken.Encoding( + name="nanochat-rustbpe", + pat_str=tokenizer.get_pattern(), + mergeable_ranks=mergeable_ranks, + special_tokens=special_tokens, + ) + + def _infer_special_tokens(self): + try: + encode_single_token = self._encoding.encode_single_token + special_token_set = getattr(self._encoding, "special_tokens_set", set()) + except AttributeError as exc: + raise RuntimeError("tiktoken encoding missing expected interfaces.") from exc + + mapping = {} + for token in SPECIAL_TOKENS: + if token in special_token_set: + mapping[token] = encode_single_token(token) + return mapping + + def _encode_one(self, text: str): + ids = list(self._encoding.encode_ordinary(text)) + if self.prepend_bos and self._bos_token_id is not None: + ids.insert(0, self._bos_token_id) + return ids + + def process(self, data: Any): + if isinstance(data, str): + return self._encode_one(data) + if isinstance(data, Sequence): + return [self._encode_one(item) for item in data] + raise TypeError(f"Unsupported payload type: {type(data)!r}") diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py new file mode 100644 index 000000000..74555ed22 --- /dev/null +++ b/plato/trainers/nanochat.py @@ -0,0 +1,285 @@ +""" +Trainer wiring for Nanochat models within the composable trainer framework. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Iterator, List, Sequence + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat trainer requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.config import Config +from plato.datasources.nanochat import NanochatStreamingDataset +from plato.evaluators.nanochat_core import run_core_evaluation +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies.base import ( + DataLoaderStrategy, + OptimizerStrategy, + TestingStrategy, + TrainingContext, + TrainingStepStrategy, +) +from plato.trainers.strategies.data_loader import DefaultDataLoaderStrategy +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + + +def _first_element_collate(batch: Sequence[Any]) -> Any: + """Return the first (and only) element from a DataLoader batch.""" + if not batch: + raise ValueError("Received empty batch from Nanochat dataset.") + return batch[0] + + +class NanochatDataLoaderStrategy(DataLoaderStrategy): + """Use identity collation for pre-batched Nanochat streaming datasets.""" + + def __init__(self): + self._fallback = DefaultDataLoaderStrategy() + + def create_train_loader( + self, trainset, sampler, batch_size: int, context: TrainingContext + ) -> torch.utils.data.DataLoader | Iterator: + if isinstance(trainset, NanochatStreamingDataset): + return torch.utils.data.DataLoader( + trainset, + batch_size=1, + shuffle=False, + sampler=None, + num_workers=0, + collate_fn=_first_element_collate, + ) + return self._fallback.create_train_loader(trainset, sampler, batch_size, context) + + +class NanochatTrainingStepStrategy(TrainingStepStrategy): + """Call Nanochat's integrated loss computation during training.""" + + def __init__(self, loss_reduction: str = "mean"): + self.loss_reduction = loss_reduction + + def training_step( + self, + model: torch.nn.Module, + optimizer, + examples: torch.Tensor, + labels: torch.Tensor, + loss_criterion, + context: TrainingContext, + ) -> torch.Tensor: + optimizer.zero_grad() + if labels is not None: + loss = model(examples, targets=labels, loss_reduction=self.loss_reduction) + else: + outputs = model(examples) + loss = loss_criterion(outputs, labels) + + if not isinstance(loss, torch.Tensor): + raise TypeError( + "Nanochat model forward pass must return a torch.Tensor loss." + ) + + loss.backward() + optimizer.step() + context.state["optimizer_step_completed"] = True + return loss.detach() + + +@dataclass +class _OptimizerBundle: + """Bundle multiple optimizers under a single interface.""" + + optimizers: List[torch.optim.Optimizer] + + def zero_grad(self) -> None: + for optimizer in self.optimizers: + optimizer.zero_grad() + + def step(self) -> None: + for optimizer in self.optimizers: + optimizer.step() + + def state_dict(self) -> dict[str, Any]: + return { + "optimizers": [optimizer.state_dict() for optimizer in self.optimizers], + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + for optimizer, payload in zip( + self.optimizers, state_dict.get("optimizers", []), strict=False # type: ignore[arg-type] + ): + optimizer.load_state_dict(payload) + + @property + def param_groups(self) -> list[dict[str, Any]]: + groups: list[dict[str, Any]] = [] + for optimizer in self.optimizers: + groups.extend(getattr(optimizer, "param_groups", [])) + return groups + + def params_state_update(self) -> None: + for optimizer in self.optimizers: + hook = getattr(optimizer, "params_state_update", None) + if callable(hook): + hook() + + +class NanochatOptimizerStrategy(OptimizerStrategy): + """Adapter around nanochat.gpt.GPT.setup_optimizers.""" + + def __init__( + self, + *, + unembedding_lr: float = 0.004, + embedding_lr: float = 0.2, + matrix_lr: float = 0.02, + weight_decay: float = 0.0, + ): + self.unembedding_lr = unembedding_lr + self.embedding_lr = embedding_lr + self.matrix_lr = matrix_lr + self.weight_decay = weight_decay + + def create_optimizer( + self, model: torch.nn.Module, context: TrainingContext + ) -> _OptimizerBundle: + if not hasattr(model, "setup_optimizers"): + raise AttributeError( + "Nanochat model is expected to expose setup_optimizers()." + ) + + optimizers = model.setup_optimizers( + unembedding_lr=self.unembedding_lr, + embedding_lr=self.embedding_lr, + matrix_lr=self.matrix_lr, + weight_decay=self.weight_decay, + ) + if not isinstance(optimizers, Iterable): + raise TypeError("setup_optimizers() must return an iterable of optimizers.") + + optimizer_list: list[torch.optim.Optimizer] = list(optimizers) + if not optimizer_list: + raise ValueError("setup_optimizers() returned an empty optimizer list.") + return _OptimizerBundle(optimizer_list) + + +class NanochatTestingStrategy(TestingStrategy): + """Compute average token loss over the validation iterator.""" + + def __init__(self, reduction: str = "sum"): + self.reduction = reduction + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + if not isinstance(testset, NanochatStreamingDataset): + raise TypeError( + "NanochatTestingStrategy expects a NanochatStreamingDataset instance." + ) + + model.eval() + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for inputs, targets in testset: + inputs = inputs.to(context.device) + targets = targets.to(context.device) + loss = model(inputs, targets=targets, loss_reduction=self.reduction) + total_loss += float(loss.item()) + total_tokens += targets.numel() + + model.train() + if total_tokens == 0: + return float("nan") + return total_loss / total_tokens + + +class NanochatCoreTestingStrategy(TestingStrategy): + """Evaluate the CORE benchmark and return the aggregate metric.""" + + def __init__(self, bundle_dir: str | None = None, max_per_task: int = -1): + self.bundle_dir = bundle_dir + self.max_per_task = max_per_task + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + device = context.device or next(model.parameters()).device + tokenizer = getattr(model, "nanochat_tokenizer", None) + results = run_core_evaluation( + model, + tokenizer=tokenizer, + bundle_dir=self.bundle_dir, + max_per_task=self.max_per_task, + device=device, + ) + context.state["nanochat_core_results"] = results + return float(results["core_metric"]) + + +class Trainer(ComposableTrainer): + """Composable trainer specialised for Nanochat workloads.""" + + def __init__( + self, + model=None, + callbacks=None, + *, + optimizer_params: dict[str, Any] | None = None, + loss_reduction: str = "mean", + ): + try: + ensure_nanochat_importable() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat trainer requires the vendored runtime/third_party/nanochat project." + ) from exc + + optimizer_strategy = NanochatOptimizerStrategy( + **(optimizer_params or {}), + ) + training_step_strategy = NanochatTrainingStepStrategy( + loss_reduction=loss_reduction + ) + data_loader_strategy = NanochatDataLoaderStrategy() + + evaluation_cfg = getattr(Config(), "evaluation", None) + evaluation_type = getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + if evaluation_type == "nanochat_core": + max_per_task = getattr(evaluation_cfg, "max_per_task", -1) + max_per_task_value = -1 if max_per_task is None else int(max_per_task) + testing_strategy = NanochatCoreTestingStrategy( + bundle_dir=getattr(evaluation_cfg, "bundle_dir", None), + max_per_task=max_per_task_value, + ) + else: + testing_strategy = NanochatTestingStrategy() + + super().__init__( + model=model, + callbacks=callbacks, + loss_strategy=None, + optimizer_strategy=optimizer_strategy, + training_step_strategy=training_step_strategy, + lr_scheduler_strategy=None, + model_update_strategy=None, + data_loader_strategy=data_loader_strategy, + testing_strategy=testing_strategy, + ) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index ff0db0cb9..977d04fd2 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -10,6 +10,7 @@ basic, composable, gan, + nanochat as nanochat_trainer, split_learning, ) @@ -19,6 +20,7 @@ "timm_basic": basic.TrainerWithTimmScheduler, "gan": gan.Trainer, "split_learning": split_learning.Trainer, + "nanochat": nanochat_trainer.Trainer, } diff --git a/plato/utils/third_party.py b/plato/utils/third_party.py new file mode 100644 index 000000000..9296f86fa --- /dev/null +++ b/plato/utils/third_party.py @@ -0,0 +1,42 @@ +""" +Helpers for accessing vendored third-party projects. +""" + +from __future__ import annotations + +import sys +from functools import lru_cache +from pathlib import Path + + +class ThirdPartyImportError(ImportError): + """Raised when a vendored third-party project is unavailable.""" + + +@lru_cache(maxsize=None) +def _nanochat_root() -> Path: + """Return the root directory of the vendored Nanochat project.""" + repo_root = Path(__file__).resolve().parents[2] + nanochat_root = repo_root / "runtime" / "third_party" / "nanochat" + if not nanochat_root.exists(): + raise ThirdPartyImportError( + "Nanochat is not vendored under runtime/third_party/nanochat." + ) + return nanochat_root + + +def ensure_nanochat_importable() -> Path: + """ + Ensure the vendored Nanochat package is importable. + + Returns: + Path to the Nanochat project root. + + Raises: + ThirdPartyImportError: If the vendored Nanochat tree is missing. + """ + nanochat_root = _nanochat_root() + path_str = str(nanochat_root) + if path_str not in sys.path: + sys.path.insert(0, path_str) + return nanochat_root diff --git a/pyproject.toml b/pyproject.toml index 099d06475..e12d65cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,18 @@ dp = [ mlx = [ "mlx", ] +nanochat = [ + "datasets>=2.14.0", + "numpy==1.26.4", + "psutil>=5.9.0", + "regex>=2023.8.8", + "tiktoken>=0.5.0", + "tokenizers>=0.15.0", + "torch>=2.2.0", + "wandb>=0.16.0", + "jinja2>=3.0", + "PyYAML>=6.0", +] [project.urls] Homepage = "https://github.com/TL-System/plato" @@ -89,6 +101,7 @@ members = [ "examples/detector", "examples/gradient_leakage_attacks", "tools", + "examples/nanochat", ] [dependency-groups] diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py new file mode 100644 index 000000000..4a60cae36 --- /dev/null +++ b/tests/test_nanochat_integration.py @@ -0,0 +1,152 @@ +"""Nanochat integration smoke checks (optional, require nanochat extras).""" + +from __future__ import annotations + +import importlib.util +import os + +import pytest + +from plato.config import Config, ConfigNode + + +pytestmark = pytest.mark.integration + +_RUSTBPE_AVAILABLE = importlib.util.find_spec("rustbpe") is not None + + +@pytest.mark.skipif( + not _RUSTBPE_AVAILABLE, + reason="Nanochat tokenizer tests require rustbpe extension (install nanochat extras).", +) +def test_nanochat_tokenizer_processor_round_trip(tmp_path): + """Train a tiny tokenizer via rustbpe and encode a sample string.""" + pytest.importorskip( + "tiktoken", reason="Nanochat tokenizer tests require tiktoken (nanochat extra)." + ) + from plato.processors.nanochat_tokenizer import NanochatTokenizerProcessor + + corpus = ["hello nanochat", "minimal tokenizer exercise"] + processor = NanochatTokenizerProcessor( + train_corpus=corpus, + vocab_size=300, + prepend_bos=False, + ) + encoded = processor.process("hello nanochat") + assert isinstance(encoded, list) + assert len(encoded) > 0 + + +def test_nanochat_trainer_smoke(temp_config, tmp_path): + """Run one training step with synthetic Nanochat data on CPU.""" + _ = pytest.importorskip( + "torch", reason="Nanochat trainer smoke requires torch (nanochat extra)." + ) + from plato.datasources.nanochat import DataSource as NanochatDataSource + from plato.models.nanochat import Model as NanochatModel + from plato.trainers.nanochat import Trainer as NanochatTrainer + + cfg = Config() + cfg.trainer.type = "nanochat" + cfg.trainer.model_name = "nanochat_smoke" + cfg.trainer.batch_size = 2 + cfg.trainer.rounds = 1 + cfg.trainer.epochs = 1 + + cfg.parameters.model = { + "sequence_len": 16, + "vocab_size": 512, + "n_layer": 1, + "n_head": 2, + "n_kv_head": 2, + "n_embd": 128, + } + + datasource = NanochatDataSource( + batch_size=cfg.trainer.batch_size, + sequence_length=cfg.parameters.model["sequence_len"], + mode="synthetic", + max_train_batches=2, + max_val_batches=1, + device="cpu", + vocab_size=cfg.parameters.model["vocab_size"], + synthetic_seed=123, + ) + + model = NanochatModel.get( + sequence_len=cfg.parameters.model["sequence_len"], + vocab_size=cfg.parameters.model["vocab_size"], + n_layer=cfg.parameters.model["n_layer"], + n_head=cfg.parameters.model["n_head"], + n_kv_head=cfg.parameters.model["n_kv_head"], + n_embd=cfg.parameters.model["n_embd"], + init_weights=True, + ) + + trainer = NanochatTrainer(model=model) + trainset = datasource.get_train_set() + elapsed = trainer.train(trainset, sampler=None) + + assert isinstance(elapsed, float) + assert elapsed >= 0.0 + + model_dir = Config().params["model_path"] + checkpoint_name = ( + f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" + ) + assert os.path.exists(os.path.join(model_dir, checkpoint_name)) + + +def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): + """Ensure evaluation config triggers the CORE testing strategy.""" + _ = pytest.importorskip( + "torch", reason="Nanochat trainer requires torch (nanochat extra)." + ) + from plato.models.nanochat import Model as NanochatModel + from plato.trainers.nanochat import ( + NanochatCoreTestingStrategy, + Trainer as NanochatTrainer, + ) + + monkeypatch.setattr( + "plato.evaluators.nanochat_core.run_core_evaluation", + lambda *args, **kwargs: { + "results": {}, + "centered_results": {}, + "core_metric": 0.0, + }, + ) + + cfg = Config() + cfg.trainer.type = "nanochat" + cfg.trainer.model_name = "nanochat_core" + cfg.trainer.batch_size = 1 + cfg.trainer.rounds = 1 + cfg.trainer.epochs = 1 + cfg.parameters.model = { + "sequence_len": 16, + "vocab_size": 512, + "n_layer": 1, + "n_head": 2, + "n_kv_head": 2, + "n_embd": 128, + } + cfg.evaluation = ConfigNode.from_object( + { + "type": "nanochat_core", + "max_per_task": 1, + } + ) + + model = NanochatModel.get( + sequence_len=cfg.parameters.model["sequence_len"], + vocab_size=cfg.parameters.model["vocab_size"], + n_layer=cfg.parameters.model["n_layer"], + n_head=cfg.parameters.model["n_head"], + n_kv_head=cfg.parameters.model["n_kv_head"], + n_embd=cfg.parameters.model["n_embd"], + init_weights=True, + ) + + trainer = NanochatTrainer(model=model) + assert isinstance(trainer.testing_strategy, NanochatCoreTestingStrategy) From 822a1cb35c1c413e173ebf72b336b7aaa2088d70 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 28 Oct 2025 11:03:40 -0400 Subject: [PATCH 02/18] Added support for vendoring the external Nanochat repo as a git submodule under 'external/nanochat'. --- .gitmodules | 3 +++ docs/nanochat_integration_checklist.md | 8 ++++---- docs/third_party.md | 19 +++++++++---------- examples/nanochat/README.md | 2 +- external/nanochat | 1 + plato/models/nanochat.py | 4 ++-- plato/trainers/nanochat.py | 3 ++- plato/utils/third_party.py | 6 +++--- 8 files changed, 25 insertions(+), 21 deletions(-) create mode 160000 external/nanochat diff --git a/.gitmodules b/.gitmodules index 377fc17b7..1176795af 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "plato/models/t2tvit"] path = plato/models/t2tvit url = https://github.com/yitu-opensource/T2T-ViT +[submodule "external/nanochat"] + path = external/nanochat + url = https://github.com/karpathy/nanochat.git diff --git a/docs/nanochat_integration_checklist.md b/docs/nanochat_integration_checklist.md index 5fa411526..f0b7bf9b6 100644 --- a/docs/nanochat_integration_checklist.md +++ b/docs/nanochat_integration_checklist.md @@ -2,22 +2,22 @@ This checklist coordinates the work to incorporate the Nanochat stack into Plato. Owners are placeholder roles until specific engineers are assigned. -## Third-Party Snapshot +## Third-Party Submodule - **Owner:** Infrastructure -- **Deliverables:** Vendor `runtime/third_party/nanochat` snapshot with commit hash in docs; add provenance blurb to `docs/third_party.md`. +- **Deliverables:** Maintain the `external/nanochat` git submodule; document update procedure in `docs/third_party.md`. - **Dependencies:** None. ## Model Registry - **Owner:** Modeling - **Deliverables:** Implement `plato/models/nanochat.py` mirroring Nanochat GPT config, register entry in `plato/models/registry.py`, supply weight-loading utilities. - **Status:** In progress – factory module and registry wiring landed. -- **Dependencies:** Third-party snapshot. +- **Dependencies:** Third-party submodule. ## Tokenizer & Processor - **Owner:** Infrastructure - **Deliverables:** Package Rust BPE via Maturin optional extra; wrap as `plato/processors/nanochat_tokenizer.py` with lazy import and fallbacks; document build steps in README. - **Status:** Prototype processor and optional dependency group landed; CI build integration remains TODO. -- **Dependencies:** Third-party snapshot, build tooling prototype. +- **Dependencies:** Third-party submodule, build tooling prototype. ## Datasource - **Owner:** Data diff --git a/docs/third_party.md b/docs/third_party.md index 304d62b4e..d481b6829 100644 --- a/docs/third_party.md +++ b/docs/third_party.md @@ -4,16 +4,15 @@ This page records external projects that are vendored into the Plato repository ## Nanochat - **Upstream:** [karpathy/nanochat](https://github.com/karpathy/nanochat) -- **Vendored location:** `runtime/third_party/nanochat` -- **Snapshot commit:** `c75fe54aa7c1fa881701c246f9427bcbe4eee5a4` (captured 2025-03-04) -- **License:** MIT (included in `runtime/third_party/nanochat/LICENSE`) +- **Location:** `external/nanochat` (git submodule) +- **License:** MIT (included in `external/nanochat/LICENSE`) -### Updating the Snapshot -1. `cd runtime/third_party/nanochat` -2. `git fetch origin && git checkout ` -3. Review upstream changes and confirm compatibility with Plato. -4. Record the new commit hash and date in this document, and call out notable changes in the integration checklist before landing. +### Updating the Submodule +1. `git submodule update --remote external/nanochat` +2. Inspect upstream changes for compatibility with Plato. +3. Commit the submodule pointer update and note any required integration work in the checklist. ### Notes -- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./runtime/third_party/nanochat maturin develop --release` succeeds before pushing updates. -- Keep the vendored tree free of local modifications unless backporting fixes; prefer upstream contributions when feasible. +- After cloning Plato, run `git submodule update --init --recursive` to populate all external dependencies. +- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./external/nanochat maturin develop --release` succeeds before pushing updates. +- Avoid local modifications inside the submodule; contribute fixes upstream when possible. diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md index 5b772c855..16ae42fc5 100644 --- a/examples/nanochat/README.md +++ b/examples/nanochat/README.md @@ -8,7 +8,7 @@ This workspace hosts Nanochat-focused experiments within Plato. ```bash uv sync --extra nanochat - uv run --with ./runtime/third_party/nanochat maturin develop --release + uv run --with ./external/nanochat maturin develop --release ``` 2. Run the synthetic smoke configuration: diff --git a/external/nanochat b/external/nanochat new file mode 160000 index 000000000..c75fe54aa --- /dev/null +++ b/external/nanochat @@ -0,0 +1 @@ +Subproject commit c75fe54aa7c1fa881701c246f9427bcbe4eee5a4 diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index d2a4b5dc2..290111353 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -102,8 +102,8 @@ def get(model_name: str | None = None, **kwargs: Any): GPT, GPTConfig, _ = _import_nanochat_modules() except ThirdPartyImportError as exc: # pragma: no cover - defensive branch raise ImportError( - "Nanochat vendor tree not found. " - "Ensure runtime/third_party/nanochat is available." + "Nanochat submodule not found. " + "Run `git submodule update --init --recursive` to populate external/nanochat." ) from exc init_weights = kwargs.pop("init_weights", True) diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py index 74555ed22..5415cb6bb 100644 --- a/plato/trainers/nanochat.py +++ b/plato/trainers/nanochat.py @@ -249,7 +249,8 @@ def __init__( ensure_nanochat_importable() except ThirdPartyImportError as exc: # pragma: no cover - defensive branch raise ImportError( - "Nanochat trainer requires the vendored runtime/third_party/nanochat project." + "Nanochat trainer requires the external/nanochat submodule. " + "Run `git submodule update --init --recursive`." ) from exc optimizer_strategy = NanochatOptimizerStrategy( diff --git a/plato/utils/third_party.py b/plato/utils/third_party.py index 9296f86fa..890764e13 100644 --- a/plato/utils/third_party.py +++ b/plato/utils/third_party.py @@ -15,12 +15,12 @@ class ThirdPartyImportError(ImportError): @lru_cache(maxsize=None) def _nanochat_root() -> Path: - """Return the root directory of the vendored Nanochat project.""" + """Return the root directory of the Nanochat submodule.""" repo_root = Path(__file__).resolve().parents[2] - nanochat_root = repo_root / "runtime" / "third_party" / "nanochat" + nanochat_root = repo_root / "external" / "nanochat" if not nanochat_root.exists(): raise ThirdPartyImportError( - "Nanochat is not vendored under runtime/third_party/nanochat." + "Nanochat submodule missing. Run `git submodule update --init --recursive`." ) return nanochat_root From e039b2c73a9d4b76dfae404ef3dae765fd8b8cbf Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 28 Oct 2025 11:07:41 -0400 Subject: [PATCH 03/18] ruff check --fix & ruff format. --- .../fedunlearning/fedunlearning_server.py | 4 +--- plato/datasources/nanochat.py | 17 +++++++++++++---- plato/datasources/registry.py | 2 +- plato/models/nanochat.py | 3 +-- plato/models/registry.py | 2 +- plato/processors/nanochat_tokenizer.py | 6 ++++-- plato/trainers/nanochat.py | 12 +++++++++--- plato/trainers/registry.py | 4 +++- tests/test_nanochat_integration.py | 7 +++---- 9 files changed, 36 insertions(+), 21 deletions(-) diff --git a/examples/unlearning/fedunlearning/fedunlearning_server.py b/examples/unlearning/fedunlearning/fedunlearning_server.py index 8938b6549..6d6ade576 100644 --- a/examples/unlearning/fedunlearning/fedunlearning_server.py +++ b/examples/unlearning/fedunlearning/fedunlearning_server.py @@ -43,9 +43,7 @@ async def aggregate_deltas(self, updates, deltas_received, context): if not filtered_pairs: if self._fallback_to_original: - return await super().aggregate_deltas( - updates, deltas_received, context - ) + return await super().aggregate_deltas(updates, deltas_received, context) zero_delta = self._zero_delta( context, deltas_received[0] if deltas_received else None diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py index 28c805943..3a0f9201a 100644 --- a/plato/datasources/nanochat.py +++ b/plato/datasources/nanochat.py @@ -35,7 +35,10 @@ def _resolve_base_dir(base_dir: str | Path | None) -> Path | None: def _parquet_available(base_dir: Path | None) -> bool: try: ensure_nanochat_importable() - from nanochat.dataset import list_parquet_files, DATA_DIR # type: ignore[attr-defined] + from nanochat.dataset import ( # type: ignore[attr-defined] + DATA_DIR, + list_parquet_files, + ) except (ThirdPartyImportError, ImportError): # pragma: no cover - defensive return False @@ -182,9 +185,13 @@ def __init__( if cfg_data is not None: mode = getattr(cfg_data, "mode", mode) base_dir = getattr(cfg_data, "base_dir", base_dir) - max_train_batches = getattr(cfg_data, "max_train_batches", max_train_batches) + max_train_batches = getattr( + cfg_data, "max_train_batches", max_train_batches + ) max_val_batches = getattr(cfg_data, "max_val_batches", max_val_batches) - tokenizer_threads = getattr(cfg_data, "tokenizer_threads", tokenizer_threads) + tokenizer_threads = getattr( + cfg_data, "tokenizer_threads", tokenizer_threads + ) tokenizer_batch_size = getattr( cfg_data, "tokenizer_batch_size", tokenizer_batch_size ) @@ -208,7 +215,9 @@ def __init__( resolved_base_dir = _resolve_base_dir(base_dir) dataset_mode = mode if dataset_mode == "auto": - dataset_mode = "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + dataset_mode = ( + "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + ) self.trainset = NanochatStreamingDataset( split="train", diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index 1334c027c..d4bda0c8d 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -11,8 +11,8 @@ feature, femnist, huggingface, - nanochat, lora, + nanochat, purchase, texas, tiny_imagenet, diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index 290111353..afd8d4a59 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -17,7 +17,6 @@ from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable - DEFAULT_MODEL_CONFIG: dict[str, int] = { "sequence_len": 2048, "vocab_size": 50304, @@ -30,10 +29,10 @@ def _import_nanochat_modules(): ensure_nanochat_importable() - from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] from nanochat.checkpoint_manager import ( # type: ignore[attr-defined] load_model_from_dir, ) + from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] return GPT, GPTConfig, load_model_from_dir diff --git a/plato/models/registry.py b/plato/models/registry.py index 4ce573e94..382a2d064 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -14,8 +14,8 @@ general_multilayer, huggingface, lenet5, - nanochat, multilayer, + nanochat, resnet, torch_hub, vgg, diff --git a/plato/processors/nanochat_tokenizer.py b/plato/processors/nanochat_tokenizer.py index e67565f5f..7460ae910 100644 --- a/plato/processors/nanochat_tokenizer.py +++ b/plato/processors/nanochat_tokenizer.py @@ -7,8 +7,8 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence import pickle +from collections.abc import Iterable, Sequence from pathlib import Path from typing import Any @@ -111,7 +111,9 @@ def _infer_special_tokens(self): encode_single_token = self._encoding.encode_single_token special_token_set = getattr(self._encoding, "special_tokens_set", set()) except AttributeError as exc: - raise RuntimeError("tiktoken encoding missing expected interfaces.") from exc + raise RuntimeError( + "tiktoken encoding missing expected interfaces." + ) from exc mapping = {} for token in SPECIAL_TOKENS: diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py index 5415cb6bb..aa025f41a 100644 --- a/plato/trainers/nanochat.py +++ b/plato/trainers/nanochat.py @@ -55,7 +55,9 @@ def create_train_loader( num_workers=0, collate_fn=_first_element_collate, ) - return self._fallback.create_train_loader(trainset, sampler, batch_size, context) + return self._fallback.create_train_loader( + trainset, sampler, batch_size, context + ) class NanochatTrainingStepStrategy(TrainingStepStrategy): @@ -112,7 +114,9 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: for optimizer, payload in zip( - self.optimizers, state_dict.get("optimizers", []), strict=False # type: ignore[arg-type] + self.optimizers, + state_dict.get("optimizers", []), + strict=False, # type: ignore[arg-type] ): optimizer.load_state_dict(payload) @@ -262,7 +266,9 @@ def __init__( data_loader_strategy = NanochatDataLoaderStrategy() evaluation_cfg = getattr(Config(), "evaluation", None) - evaluation_type = getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + evaluation_type = ( + getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + ) if evaluation_type == "nanochat_core": max_per_task = getattr(evaluation_cfg, "max_per_task", -1) max_per_task_value = -1 if max_per_task is None else int(max_per_task) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index 977d04fd2..ef47a287e 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -10,9 +10,11 @@ basic, composable, gan, - nanochat as nanochat_trainer, split_learning, ) +from plato.trainers import ( + nanochat as nanochat_trainer, +) registered_trainers = { "composable": composable.ComposableTrainer, diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py index 4a60cae36..4af61fe8a 100644 --- a/tests/test_nanochat_integration.py +++ b/tests/test_nanochat_integration.py @@ -9,7 +9,6 @@ from plato.config import Config, ConfigNode - pytestmark = pytest.mark.integration _RUSTBPE_AVAILABLE = importlib.util.find_spec("rustbpe") is not None @@ -91,9 +90,7 @@ def test_nanochat_trainer_smoke(temp_config, tmp_path): assert elapsed >= 0.0 model_dir = Config().params["model_path"] - checkpoint_name = ( - f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" - ) + checkpoint_name = f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" assert os.path.exists(os.path.join(model_dir, checkpoint_name)) @@ -105,6 +102,8 @@ def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): from plato.models.nanochat import Model as NanochatModel from plato.trainers.nanochat import ( NanochatCoreTestingStrategy, + ) + from plato.trainers.nanochat import ( Trainer as NanochatTrainer, ) From ac5bebaa932d29936367c0a8d9ea844663a6fa47 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 16:54:51 -0400 Subject: [PATCH 04/18] Added benchmark configuration ([evaluation]) support in config.py. --- plato/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plato/config.py b/plato/config.py index ec7d840a0..56f35cfa8 100644 --- a/plato/config.py +++ b/plato/config.py @@ -153,6 +153,7 @@ class Config: clients: Any server: Any data: Any + evaluation: Any trainer: Any algorithm: Any results: Any @@ -342,6 +343,10 @@ def __new__(cls): Config.params["base_path"], "data" ) + # User-defined evaluation configuration + if hasattr(config, "evaluation"): + Config.evaluation = config.evaluation + # Pretrained models if hasattr(Config().server, "model_path"): Config.params["model_path"] = os.path.join( From 4501781bc99aadb038946218bbab4e91c2f1246d Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 16:56:15 -0400 Subject: [PATCH 05/18] Added test to verify that [evaluation] configuration is properly loaded. --- tests/test_config_loader.py | 52 +++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index c22b0875c..01952518a 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -157,3 +157,55 @@ def test_config_base_path_used_without_cli_override(tmp_path: Path, monkeypatch) if hasattr(Config, "args"): delattr(Config, "args") Config._cli_overrides = {} + + +def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): + """Test that [evaluation] configuration is properly loaded.""" + config_base = tmp_path / "runtime" + config_path = tmp_path / "config.toml" + + config_data = { + "clients": {"type": "simple", "total_clients": 2, "per_round": 1}, + "server": {"address": "127.0.0.1", "port": 8000}, + "data": {"datasource": "MNIST"}, + "trainer": {"type": "basic", "rounds": 1, "epochs": 1, "batch_size": 10}, + "algorithm": {"type": "fedavg"}, + "evaluation": { + "type": "nanochat_core", + "max_per_task": 128, + "bundle_dir": "/custom/path/to/nanochat", + }, + } + + toml_writer.dump(config_data, config_path) + + monkeypatch.delenv("config_file", raising=False) + monkeypatch.setattr( + sys, + "argv", + [ + sys.argv[0], + "--config", + str(config_path), + "--base", + str(config_base), + ], + ) + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + + config = Config() + + assert hasattr(config, "evaluation") + assert config.evaluation.type == "nanochat_core" + assert config.evaluation.max_per_task == 128 + assert config.evaluation.bundle_dir == "/custom/path/to/nanochat" + assert config_base.is_dir() + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} From a8efbeacf91ea7db0951ee227a71d2ec980441c0 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 17:52:47 -0400 Subject: [PATCH 06/18] Fixed tensor contiguity issue in datasource. - Resolved a RuntimeError caused by non-contiguous tensors during view operations (in nanochat - gpt.py): "view size is not compatible with input tensor's size and stride...". Replaced .view() with .reshape() --- plato/datasources/nanochat.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py index 3a0f9201a..9718184cc 100644 --- a/plato/datasources/nanochat.py +++ b/plato/datasources/nanochat.py @@ -122,8 +122,8 @@ def _synthetic_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: dtype=torch.long, generator=generator, ) - inputs = tokens[:, :-1].to(dtype=torch.long) - targets = tokens[:, 1:].to(dtype=torch.long) + inputs = tokens[:, :-1].contiguous() + targets = tokens[:, 1:].contiguous() yield inputs, targets def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: @@ -141,7 +141,10 @@ def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: device=self.device, ) for inputs, targets in loader: - yield inputs.to(dtype=torch.long), targets.to(dtype=torch.long) + yield ( + inputs.to(dtype=torch.long).contiguous(), + targets.to(dtype=torch.long).contiguous(), + ) def __iter__(self): iterable = ( From f0bc22d2fe7313da0382e4ce4e42de77e21492d8 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 17:54:12 -0400 Subject: [PATCH 07/18] Fixed KeyError: 'train_loss'. - Resolved an issue where the configuration requested 'train_loss' in the results, but the server's get_logged_items() did not include it. --- plato/clients/strategies/defaults.py | 13 +++++++++++++ plato/servers/fedavg.py | 18 +++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index 7b104c163..a7a54c86d 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -383,6 +383,18 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: except TypeError: num_samples = 0 + # Extract train_loss from trainer's run_history if available + train_loss = None + if ( + context.trainer is not None + and hasattr(context.trainer, "run_history") + and context.trainer.run_history is not None + ): + try: + train_loss = context.trainer.run_history.get_latest_metric("train_loss") + except (AttributeError, KeyError, IndexError): + train_loss = None + report = SimpleNamespace( client_id=context.client_id, num_samples=num_samples, @@ -390,6 +402,7 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: training_time=training_time, comm_time=time.time(), update_response=False, + train_loss=train_loss, ) return report, weights diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index ee636cb49..1f5205f80 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -273,7 +273,7 @@ def clients_processed(self) -> None: def get_logged_items(self) -> dict: """Get items to be logged by the LogProgressCallback class in a .csv file.""" - return { + logged = { "round": self.current_round, "accuracy": self.accuracy, "accuracy_std": self.accuracy_std, @@ -291,6 +291,22 @@ def get_logged_items(self) -> dict: "comm_overhead": self.comm_overhead, } + # Add train_loss if available from client reports + if self.updates and hasattr(self.updates[0].report, "train_loss"): + # Compute weighted average of train_loss across clients + total_samples = sum(update.report.num_samples for update in self.updates) + if total_samples > 0: + weighted_loss = sum( + update.report.train_loss * update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) + logged["train_loss"] = weighted_loss / total_samples + else: + logged["train_loss"] = 0.0 + + return logged + @staticmethod def get_accuracy_mean_std(updates): """Compute the accuracy mean and standard deviation across clients.""" From 48100806fd1c8804dcb5f3540ab57a7c04f8683c Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:21:56 -0400 Subject: [PATCH 08/18] Fixed train_loss aggregation in FedAvg server to handle None values. --- plato/servers/fedavg.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 1f5205f80..e9f9c0fde 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -294,7 +294,11 @@ def get_logged_items(self) -> dict: # Add train_loss if available from client reports if self.updates and hasattr(self.updates[0].report, "train_loss"): # Compute weighted average of train_loss across clients - total_samples = sum(update.report.num_samples for update in self.updates) + total_samples = sum( + update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) if total_samples > 0: weighted_loss = sum( update.report.train_loss * update.report.num_samples @@ -303,7 +307,7 @@ def get_logged_items(self) -> dict: ) logged["train_loss"] = weighted_loss / total_samples else: - logged["train_loss"] = 0.0 + logged["train_loss"] = None return logged From eb736ebde6cb967a389fe5fc1f70304f2f6d6455 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:23:48 -0400 Subject: [PATCH 09/18] Added evaluation configs for nanochat CORE metric. --- configs/Nanochat/synthetic_micro.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/configs/Nanochat/synthetic_micro.toml b/configs/Nanochat/synthetic_micro.toml index b22941f53..abc3a6452 100644 --- a/configs/Nanochat/synthetic_micro.toml +++ b/configs/Nanochat/synthetic_micro.toml @@ -26,6 +26,11 @@ device = "cpu" vocab_size = 512 synthetic_seed = 123 +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" # Optional, defaults to nanochat base dir or Plato's data directory +max_per_task = 16 # Optional, -1 means run all examples + [trainer] type = "nanochat" rounds = 1 @@ -46,4 +51,4 @@ n_kv_head = 4 n_embd = 256 [results] -types = "round, elapsed_time, train_loss" +types = "round, elapsed_time, core_metric, train_loss" From d9fe94a949cbc16f02725006fb8cca138182b013 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:44:50 -0400 Subject: [PATCH 10/18] Added automatic download of nanochat CORE evaluation bundle. --- plato/evaluators/nanochat_core.py | 168 +++++++++++++++++++++++++++--- 1 file changed, 151 insertions(+), 17 deletions(-) diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py index 74a7c1447..edac1a898 100644 --- a/plato/evaluators/nanochat_core.py +++ b/plato/evaluators/nanochat_core.py @@ -4,13 +4,20 @@ from __future__ import annotations +import contextlib import csv +import gzip import json import logging +import os import random +import sys +import tarfile import time +import zipfile from pathlib import Path from typing import Any +from urllib.parse import urlparse try: import torch @@ -20,42 +27,169 @@ "Install the `nanochat` extra (includes torch)." ) from exc +import requests import yaml +from plato.config import Config from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable LOGGER = logging.getLogger(__name__) +# URL for the CORE evaluation bundle +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + + +@contextlib.contextmanager +def _download_guard(data_path: str): + """Serialize dataset downloads to avoid concurrent corruption.""" + os.makedirs(data_path, exist_ok=True) + lock_file = os.path.join(data_path, ".download.lock") + lock_fd = None + waited = False + + try: + while True: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR) + break + except FileExistsError: + if not waited: + LOGGER.info( + "Another process is preparing the dataset at %s. Waiting.", + data_path, + ) + waited = True + time.sleep(1) + yield + finally: + if lock_fd is not None: + os.close(lock_fd) + try: + os.remove(lock_file) + except FileNotFoundError: + pass + + +def _download_eval_bundle(url: str, data_path: str) -> None: + """Download the CORE evaluation bundle from a URL if not already available.""" + url_parse = urlparse(url) + file_name = os.path.join(data_path, url_parse.path.split("/")[-1]) + os.makedirs(data_path, exist_ok=True) + sentinel = Path(f"{file_name}.complete") + + if sentinel.exists(): + return + + with _download_guard(data_path): + if sentinel.exists(): + return + + LOGGER.info("Downloading CORE evaluation bundle from %s.", url) + + res = requests.get(url, stream=True, timeout=60) + total_size = int(res.headers.get("Content-Length", 0)) + downloaded_size = 0 + + with open(file_name, "wb+") as file: + for chunk in res.iter_content(chunk_size=1024): + if not chunk: + continue + downloaded_size += len(chunk) + file.write(chunk) + file.flush() + if total_size: + sys.stdout.write(f"\r{100 * downloaded_size / total_size:.1f}%") + sys.stdout.flush() + if total_size: + sys.stdout.write("\n") + + # Unzip the compressed file just downloaded + LOGGER.info("Decompressing the CORE evaluation bundle.") + name, suffix = os.path.splitext(file_name) + + if file_name.endswith("tar.gz"): + with tarfile.open(file_name, "r:gz") as tar: + tar.extractall(data_path) + os.remove(file_name) + elif suffix == ".zip": + LOGGER.info("Extracting %s to %s.", file_name, data_path) + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(data_path) + os.remove(file_name) + elif suffix == ".gz": + with gzip.open(file_name, "rb") as zipped_file: + with open(name, "wb") as unzipped_file: + unzipped_file.write(zipped_file.read()) + os.remove(file_name) + else: + LOGGER.warning("Unknown compressed file type for %s.", file_name) + + sentinel.touch() + LOGGER.info("CORE evaluation bundle downloaded and extracted successfully.") + def _resolve_bundle_paths(bundle_dir: str | Path | None) -> tuple[Path, Path, Path]: """Resolve the configuration, metadata, and dataset paths for CORE evaluation.""" ensure_nanochat_importable() - from nanochat.common import get_base_dir # pylint: disable=import-error - if bundle_dir is None: - base_path = Path(get_base_dir()) + def _get_default_base_path() -> Path: + """Get the default base path, trying nanochat first, then Plato's data directory.""" + try: + from nanochat.common import get_base_dir # pylint: disable=import-error + + return Path(get_base_dir()) + except (ImportError, OSError, PermissionError): + plato_data_path = Config().params.get("data_path", "./runtime/data") + path = Path(plato_data_path) / "nanochat" + LOGGER.info("Using Plato data directory for CORE bundle: %s", path) + return path + + # Determine base path + if bundle_dir is not None: + try: + base_path = Path(bundle_dir).expanduser().resolve() + LOGGER.info("Using bundle_dir from config: %s", base_path) + except (OSError, PermissionError, ValueError) as exc: + LOGGER.warning( + "Cannot use bundle_dir '%s': %s. Using default location.", + bundle_dir, + exc, + ) + base_path = _get_default_base_path() else: - base_path = Path(bundle_dir).expanduser().resolve() + base_path = _get_default_base_path() + + # Ensure base path exists + base_path.mkdir(parents=True, exist_ok=True) eval_bundle_dir = base_path / "eval_bundle" config_path = eval_bundle_dir / "core.yaml" data_dir = eval_bundle_dir / "eval_data" metadata_path = eval_bundle_dir / "eval_meta_data.csv" - if not config_path.exists(): - raise FileNotFoundError( - f"CORE evaluation config not found at {config_path}. " - "Ensure the Nanochat eval bundle is downloaded." - ) - if not data_dir.exists(): - raise FileNotFoundError( - f"CORE evaluation data directory not found at {data_dir}. " - "Ensure the Nanochat eval bundle is downloaded." - ) - if not metadata_path.exists(): - raise FileNotFoundError( - f"CORE evaluation metadata CSV not found at {metadata_path}." + # Check if evaluation bundle exists, download if missing + if not config_path.exists() or not data_dir.exists() or not metadata_path.exists(): + LOGGER.info( + "CORE evaluation bundle not found at %s. Downloading automatically...", + base_path, ) + _download_eval_bundle(EVAL_BUNDLE_URL, str(base_path)) + + # Verify download succeeded + if not config_path.exists(): + raise FileNotFoundError( + f"CORE evaluation config not found at {config_path}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not data_dir.exists(): + raise FileNotFoundError( + f"CORE evaluation data directory not found at {data_dir}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not metadata_path.exists(): + raise FileNotFoundError( + f"CORE evaluation metadata CSV not found at {metadata_path}." + ) return config_path, data_dir, metadata_path From 6f349500331933beeb5866407af8e1ee3f9f347f Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:50:01 -0400 Subject: [PATCH 11/18] Using tokenizer's vocab_size to match between model and tokenizer. - To avoid vocabulary size mismatch between model and tokenizer during CORE evaluation. --- plato/models/nanochat.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index afd8d4a59..429d034b2 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -6,6 +6,7 @@ from dataclasses import fields from typing import Any +import logging try: import torch @@ -126,9 +127,47 @@ def get(model_name: str | None = None, **kwargs: Any): ) return model + # Model vocab_size MUST match tokenizer vocab_size to avoid IndexError + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + actual_vocab_size = tokenizer.get_vocab_size() + + # Override vocab_size with tokenizer's actual vocab_size + if "vocab_size" in config_kwargs: + configured_vocab = config_kwargs["vocab_size"] + if configured_vocab != actual_vocab_size: + logging.warning( + f"[Nanochat Model] Config specifies vocab_size={configured_vocab}, " + f"but tokenizer has vocab_size={actual_vocab_size}. " + f"Using tokenizer's vocab_size={actual_vocab_size} to match tokenizer." + ) + config_kwargs["vocab_size"] = actual_vocab_size + + logging.info( + f"[Nanochat Model] Using vocab_size={actual_vocab_size} from tokenizer" + ) + except Exception as e: + logging.warning( + f"[Nanochat Model] Could not auto-detect vocab_size from tokenizer: {e}. " + f"Using configured or default vocab_size." + ) + config = GPTConfig(**config_kwargs) model = GPT(config) if init_weights: model.init_weights() + + # This allows CORE evaluation and other components to access the tokenizer + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + setattr(model, "nanochat_tokenizer", tokenizer) + except Exception: + pass + # Set max_seq_len for CORE evaluation truncation + setattr(model, "max_seq_len", config.sequence_len) setattr(model, "nanochat_config", config_kwargs) return model From 35f25ebcc7b6f063d329b608a8df29679e2f0008 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:55:24 -0400 Subject: [PATCH 12/18] Added outputs for Nanochat CORE evaluation in FedAvg server. --- plato/servers/fedavg.py | 12 ++++++++++++ tests/test_config_loader.py | 1 - 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index e9f9c0fde..fe001cb34 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -252,6 +252,14 @@ async def _process_reports(self): trainer = self.require_trainer() self.accuracy = trainer.test(self.testset, self.testset_sampler) + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", self.accuracy) + if hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( @@ -309,6 +317,10 @@ def get_logged_items(self) -> dict: else: logged["train_loss"] = None + # Add core_metric if Nanochat CORE evaluation was performed + if hasattr(self, "_core_metric"): + logged["core_metric"] = self._core_metric + return logged @staticmethod diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 01952518a..2b1bb9b75 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -203,7 +203,6 @@ def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): assert config.evaluation.type == "nanochat_core" assert config.evaluation.max_per_task == 128 assert config.evaluation.bundle_dir == "/custom/path/to/nanochat" - assert config_base.is_dir() Config._instance = None if hasattr(Config, "args"): From e4ae761265364126ff6c0f67c8359a8e702de685 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 31 Oct 2025 11:08:39 -0400 Subject: [PATCH 13/18] Added specific logging output for CORE benchmark metrics. - Updated log message from "global accuracy" to "Average Centered CORE benchmark metric" - Used ruff to format code --- cleanup.py | 8 ++------ plato/servers/fedavg.py | 12 +++++++++++- plato/servers/fedavg_cs.py | 34 +++++++++++++++++++++++++++++++-- plato/servers/split_learning.py | 27 ++++++++++++++++++++++---- 4 files changed, 68 insertions(+), 13 deletions(-) diff --git a/cleanup.py b/cleanup.py index 68c658019..f03f5f3cf 100644 --- a/cleanup.py +++ b/cleanup.py @@ -146,18 +146,14 @@ def main() -> None: continue cleared = clean_directory(runtime_dir) - print( - f"Failed to delete {runtime_dir}; cleared {cleared} items instead." - ) + print(f"Failed to delete {runtime_dir}; cleared {cleared} items instead.") fallback_dirs += 1 fallback_items += cleared if runtime_total == 0: print("No runtime directories found.") else: - print( - f"Removed {runtime_removed} of {runtime_total} runtime directories." - ) + print(f"Removed {runtime_removed} of {runtime_total} runtime directories.") if fallback_dirs: print( f"Cleared {fallback_items} items in " diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index fe001cb34..36fee944b 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -260,7 +260,16 @@ async def _process_reports(self): core_results = trainer.context.state["nanochat_core_results"] self._core_metric = core_results.get("core_metric", self.accuracy) - if hasattr(Config().trainer, "target_perplexity"): + # If CORE benchmark was run via a Nanochat testing strategy, report the specialized CORE metric instead of the generic 'Global model accuracy' label. + core_metric = getattr(self, "_core_metric", None) + + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -284,6 +293,7 @@ def get_logged_items(self) -> dict: logged = { "round": self.current_round, "accuracy": self.accuracy, + "core_metric": getattr(self, "_core_metric", None), "accuracy_std": self.accuracy_std, "elapsed_time": self.wall_time - self.initial_wall_time, "processing_time": max( diff --git a/plato/servers/fedavg_cs.py b/plato/servers/fedavg_cs.py index fceff363a..eaba3caf6 100644 --- a/plato/servers/fedavg_cs.py +++ b/plato/servers/fedavg_cs.py @@ -253,7 +253,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -274,7 +289,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" diff --git a/plato/servers/split_learning.py b/plato/servers/split_learning.py index 83cf2804a..49fdc2d4c 100644 --- a/plato/servers/split_learning.py +++ b/plato/servers/split_learning.py @@ -99,11 +99,30 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): self.test_accuracy = trainer.test(self.testset, self.testset_sampler) - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + core_metric = core_results.get("core_metric", None) + if core_metric is not None: + logging.warning( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) ) - ) self.phase = "prompt" # Change client in next round self.next_client = True From 432fe50d3ffea35a7ee653b603439230c592fb73 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Fri, 7 Nov 2025 23:41:41 +0000 Subject: [PATCH 14/18] Typed the Nanochat datasource/optimizer plumbing and enforced valid CORE metadata so ty check is clean again. --- plato/datasources/nanochat.py | 4 ++-- plato/evaluators/nanochat_core.py | 11 +++++++++-- plato/trainers/nanochat.py | 8 ++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py index 9718184cc..7e53c71c6 100644 --- a/plato/datasources/nanochat.py +++ b/plato/datasources/nanochat.py @@ -222,7 +222,7 @@ def __init__( "parquet" if _parquet_available(resolved_base_dir) else "synthetic" ) - self.trainset = NanochatStreamingDataset( + self.trainset: NanochatStreamingDataset = NanochatStreamingDataset( split="train", batch_size=resolved_batch_size, sequence_length=resolved_sequence_len, @@ -235,7 +235,7 @@ def __init__( vocab_size=vocab_size, synthetic_seed=synthetic_seed, ) - self.testset = NanochatStreamingDataset( + self.testset: NanochatStreamingDataset = NanochatStreamingDataset( split="val", batch_size=resolved_batch_size, sequence_length=resolved_sequence_len, diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py index edac1a898..97a49e9f2 100644 --- a/plato/evaluators/nanochat_core.py +++ b/plato/evaluators/nanochat_core.py @@ -308,15 +308,22 @@ def run_core_evaluation( LOGGER.debug("Skipping unnamed CORE task entry: %s", task) continue + dataset_uri = task.get("dataset_uri") + if not isinstance(dataset_uri, str): + LOGGER.debug( + "Skipping CORE task %s due to missing dataset_uri metadata.", task + ) + continue + task_meta = { "task_type": task.get("icl_task_type"), - "dataset_uri": task.get("dataset_uri"), + "dataset_uri": dataset_uri, "num_fewshot": task.get("num_fewshot", [0])[0], "continuation_delimiter": task.get("continuation_delimiter", " "), } start_time = time.perf_counter() - data = _load_task_data(data_dir, task_meta["dataset_uri"]) + data = _load_task_data(data_dir, dataset_uri) shuffle_rng = random.Random(1337) shuffle_rng.shuffle(data) if max_per_task > 0: diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py index aa025f41a..307ccf911 100644 --- a/plato/trainers/nanochat.py +++ b/plato/trainers/nanochat.py @@ -153,12 +153,16 @@ def __init__( def create_optimizer( self, model: torch.nn.Module, context: TrainingContext ) -> _OptimizerBundle: - if not hasattr(model, "setup_optimizers"): + if not isinstance(model, torch.nn.Module): + raise TypeError("Nanochat optimizer strategy requires a torch.nn.Module.") + + setup_fn = getattr(model, "setup_optimizers", None) + if not callable(setup_fn): raise AttributeError( "Nanochat model is expected to expose setup_optimizers()." ) - optimizers = model.setup_optimizers( + optimizers = setup_fn( unembedding_lr=self.unembedding_lr, embedding_lr=self.embedding_lr, matrix_lr=self.matrix_lr, From da04815748868e07d3d917dace28c2cd7c444331 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Fri, 7 Nov 2025 23:53:13 +0000 Subject: [PATCH 15/18] All nanochat tests now pass. --- pyproject.toml | 20 ++++++++++---------- tests/test_nanochat_integration.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e12d65cbc..e1e172a6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,16 +48,16 @@ mlx = [ "mlx", ] nanochat = [ - "datasets>=2.14.0", - "numpy==1.26.4", - "psutil>=5.9.0", - "regex>=2023.8.8", - "tiktoken>=0.5.0", - "tokenizers>=0.15.0", - "torch>=2.2.0", - "wandb>=0.16.0", - "jinja2>=3.0", - "PyYAML>=6.0", + "datasets", + "numpy", + "psutil", + "regex", + "tiktoken", + "tokenizers", + "torch", + "wandb", + "jinja2", + "PyYAML", ] [project.urls] diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py index 4af61fe8a..6a9c174e8 100644 --- a/tests/test_nanochat_integration.py +++ b/tests/test_nanochat_integration.py @@ -8,6 +8,7 @@ import pytest from plato.config import Config, ConfigNode +from plato.utils.third_party import ensure_nanochat_importable pytestmark = pytest.mark.integration @@ -41,6 +42,11 @@ def test_nanochat_trainer_smoke(temp_config, tmp_path): _ = pytest.importorskip( "torch", reason="Nanochat trainer smoke requires torch (nanochat extra)." ) + ensure_nanochat_importable() + _ = pytest.importorskip( + "nanochat", + reason="Nanochat trainer smoke requires the nanochat package (nanochat extra).", + ) from plato.datasources.nanochat import DataSource as NanochatDataSource from plato.models.nanochat import Model as NanochatModel from plato.trainers.nanochat import Trainer as NanochatTrainer @@ -99,6 +105,11 @@ def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): _ = pytest.importorskip( "torch", reason="Nanochat trainer requires torch (nanochat extra)." ) + ensure_nanochat_importable() + _ = pytest.importorskip( + "nanochat", + reason="Nanochat trainer requires the nanochat package (nanochat extra).", + ) from plato.models.nanochat import Model as NanochatModel from plato.trainers.nanochat import ( NanochatCoreTestingStrategy, From 279d05ecb6eeed0d79e475b2c69cbb21c020c52c Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sat, 8 Nov 2025 01:38:01 -0500 Subject: [PATCH 16/18] Updated nanochat README with setup and troubleshooting notes. - Added instructions for initializing submodules and resolving maturin build failure. --- examples/nanochat/README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md index 16ae42fc5..1d061345b 100644 --- a/examples/nanochat/README.md +++ b/examples/nanochat/README.md @@ -4,14 +4,26 @@ This workspace hosts Nanochat-focused experiments within Plato. ## Quick Start -1. Install dependencies (including the vendored tokenizer build requirements): +1. Initialize the nanochat submodule (required for the nanochat integration): + + ```bash + git submodule update --init --recursive + ``` + +2. Install dependencies (including the vendored tokenizer build requirements): ```bash uv sync --extra nanochat uv run --with ./external/nanochat maturin develop --release ``` + **Troubleshooting:** If you encounter a `maturin failed` error with "Can't find Cargo.toml", run the maturin command from within the nanochat directory: + + ```bash + uv sync --extra nanochat + cd external/nanochat && uv run maturin develop --release && cd ../.. + ``` -2. Run the synthetic smoke configuration: +3. Run the synthetic smoke configuration: ```bash uv run --extra nanochat python plato.py --config configs/Nanochat/synthetic_micro.toml From af0bafab579b7dff754c154dad523dd55e34e8b7 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 13 Nov 2025 18:37:09 +0000 Subject: [PATCH 17/18] Added configuration file for NanoChat Parquet mode. --- configs/Nanochat/parquet_micro.toml | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 configs/Nanochat/parquet_micro.toml diff --git a/configs/Nanochat/parquet_micro.toml b/configs/Nanochat/parquet_micro.toml new file mode 100644 index 000000000..cfe09a921 --- /dev/null +++ b/configs/Nanochat/parquet_micro.toml @@ -0,0 +1,53 @@ +[clients] +type = "simple" +total_clients = 10 +per_round = 3 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/parquet" +model_path = "models/nanochat/parquet" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "parquet" +max_train_batches = 16 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 32 +device = "cuda" +vocab_size = 512 +synthetic_seed = 123 + +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" +max_per_task = 16 + +[trainer] +type = "nanochat" +rounds = 10000 +epochs = 5 +batch_size = 1 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 256 +vocab_size = 50304 +n_layer = 4 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, core_metric, train_loss" From 2b7cf3d6832dc85c967835ec4738f4049a0a7661 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 13 Nov 2025 20:37:13 +0000 Subject: [PATCH 18/18] Formatted code with Ruff and applied autofixes. --- .../model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py | 3 ++- plato/models/nanochat.py | 2 +- pyproject.toml | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py index 487547209..e8d700c97 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py @@ -22,7 +22,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F -from data import build_loader from misc.config import get_config from misc.loss_ops import AdaptiveLossSoft from misc.lr_scheduler import build_scheduler @@ -37,6 +36,8 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import AverageMeter, ModelEma, accuracy +from data import build_loader + try: from apex import amp except ImportError: diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index 429d034b2..114a891f9 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -4,9 +4,9 @@ from __future__ import annotations +import logging from dataclasses import fields from typing import Any -import logging try: import torch diff --git a/pyproject.toml b/pyproject.toml index e1e172a6c..228545350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ members = [ [dependency-groups] dev = [ "pytest", + "ruff", "ty", ]