From 3355698fb0065472444ecdbfcce5a7d22515cf1f Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 28 Oct 2025 08:51:05 -0400 Subject: [PATCH 01/33] Initial support for the Nanochat model and its evaluation benchmark (core_eval.py). --- configs/Nanochat/synthetic_micro.toml | 49 +++++ docs/nanochat_integration_checklist.md | 55 +++++ docs/third_party.md | 19 ++ examples/nanochat/README.md | 49 +++++ examples/nanochat/pyproject.toml | 7 + plato/datasources/nanochat.py | 262 +++++++++++++++++++++++ plato/datasources/registry.py | 2 + plato/evaluators/__init__.py | 3 + plato/evaluators/nanochat_core.py | 217 +++++++++++++++++++ plato/models/nanochat.py | 135 ++++++++++++ plato/models/registry.py | 2 + plato/processors/nanochat_tokenizer.py | 133 ++++++++++++ plato/trainers/nanochat.py | 285 +++++++++++++++++++++++++ plato/trainers/registry.py | 2 + plato/utils/third_party.py | 42 ++++ pyproject.toml | 13 ++ tests/test_nanochat_integration.py | 152 +++++++++++++ 17 files changed, 1427 insertions(+) create mode 100644 configs/Nanochat/synthetic_micro.toml create mode 100644 docs/nanochat_integration_checklist.md create mode 100644 docs/third_party.md create mode 100644 examples/nanochat/README.md create mode 100644 examples/nanochat/pyproject.toml create mode 100644 plato/datasources/nanochat.py create mode 100644 plato/evaluators/__init__.py create mode 100644 plato/evaluators/nanochat_core.py create mode 100644 plato/models/nanochat.py create mode 100644 plato/processors/nanochat_tokenizer.py create mode 100644 plato/trainers/nanochat.py create mode 100644 plato/utils/third_party.py create mode 100644 tests/test_nanochat_integration.py diff --git a/configs/Nanochat/synthetic_micro.toml b/configs/Nanochat/synthetic_micro.toml new file mode 100644 index 000000000..b22941f53 --- /dev/null +++ b/configs/Nanochat/synthetic_micro.toml @@ -0,0 +1,49 @@ +[clients] + +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/synthetic" +model_path = "models/nanochat/synthetic" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "synthetic" +max_train_batches = 4 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 64 +device = "cpu" +vocab_size = 512 +synthetic_seed = 123 + +[trainer] +type = "nanochat" +rounds = 1 +epochs = 1 +batch_size = 2 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 128 +vocab_size = 512 +n_layer = 2 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, train_loss" diff --git a/docs/nanochat_integration_checklist.md b/docs/nanochat_integration_checklist.md new file mode 100644 index 000000000..5fa411526 --- /dev/null +++ b/docs/nanochat_integration_checklist.md @@ -0,0 +1,55 @@ +# Nanochat Integration Checklist + +This checklist coordinates the work to incorporate the Nanochat stack into Plato. Owners are placeholder roles until specific engineers are assigned. + +## Third-Party Snapshot +- **Owner:** Infrastructure +- **Deliverables:** Vendor `runtime/third_party/nanochat` snapshot with commit hash in docs; add provenance blurb to `docs/third_party.md`. +- **Dependencies:** None. + +## Model Registry +- **Owner:** Modeling +- **Deliverables:** Implement `plato/models/nanochat.py` mirroring Nanochat GPT config, register entry in `plato/models/registry.py`, supply weight-loading utilities. +- **Status:** In progress – factory module and registry wiring landed. +- **Dependencies:** Third-party snapshot. + +## Tokenizer & Processor +- **Owner:** Infrastructure +- **Deliverables:** Package Rust BPE via Maturin optional extra; wrap as `plato/processors/nanochat_tokenizer.py` with lazy import and fallbacks; document build steps in README. +- **Status:** Prototype processor and optional dependency group landed; CI build integration remains TODO. +- **Dependencies:** Third-party snapshot, build tooling prototype. + +## Datasource +- **Owner:** Data +- **Deliverables:** Create `plato/datasources/nanochat.py` handling dataset acquisition and sharding; register in datasource registry; store license metadata. +- **Status:** In progress – streaming dataset with synthetic fallback available. +- **Dependencies:** Tokenizer availability. + +## Trainer & Algorithm +- **Owner:** Training +- **Deliverables:** Port Nanochat engine into `plato/trainers/nanochat.py`; add algorithm glue if federated coordination diverges; ensure checkpoint compatibility. +- **Status:** In progress – composable trainer wrapper with Nanochat-specific optimiser/loader strategies in place. +- **Dependencies:** Model registry entry, datasource. + +## Evaluation Strategy +- **Owner:** Evaluation +- **Deliverables:** Translate `nanochat/core_eval.py` into reusable evaluator hooked into Plato testing strategy; add pytest coverage with synthetic data. +- **Status:** CORE evaluation adapter hooked into trainer testing strategy; follow-up coverage to use real eval bundles outstanding. +- **Dependencies:** Model, tokenizer. + +## Configuration & Examples +- **Owner:** Product +- **Deliverables:** Author `configs/Nanochat/*.toml` scenarios and `examples/nanochat/` workspace; include reference scripts and documentation. +- **Status:** Synthetic micro config and workspace README published; larger-scale scenarios pending. +- **Dependencies:** Model, datasource, trainer. + +## Documentation & Release +- **Owner:** Docs +- **Deliverables:** Publish `docs/models/nanochat.md`, extend root README tables, add integration notes and changelog entry; outline hardware requirements. +- **Dependencies:** All prior tracks. + +## Validation +- **Owner:** QA +- **Deliverables:** Expand CI to compile tokenizer, run smoke train/eval, and enforce import order checks; record expected metrics in evaluation baselines. +- **Status:** Initial pytest smoke checks for tokenizer/trainer added; CI enablement still pending. +- **Dependencies:** Evaluation strategy, trainer. diff --git a/docs/third_party.md b/docs/third_party.md new file mode 100644 index 000000000..304d62b4e --- /dev/null +++ b/docs/third_party.md @@ -0,0 +1,19 @@ +# Third-Party Assets + +This page records external projects that are vendored into the Plato repository to support specific integrations. Please update the relevant entry whenever the upstream source, commit hash, or licensing information changes. + +## Nanochat +- **Upstream:** [karpathy/nanochat](https://github.com/karpathy/nanochat) +- **Vendored location:** `runtime/third_party/nanochat` +- **Snapshot commit:** `c75fe54aa7c1fa881701c246f9427bcbe4eee5a4` (captured 2025-03-04) +- **License:** MIT (included in `runtime/third_party/nanochat/LICENSE`) + +### Updating the Snapshot +1. `cd runtime/third_party/nanochat` +2. `git fetch origin && git checkout ` +3. Review upstream changes and confirm compatibility with Plato. +4. Record the new commit hash and date in this document, and call out notable changes in the integration checklist before landing. + +### Notes +- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./runtime/third_party/nanochat maturin develop --release` succeeds before pushing updates. +- Keep the vendored tree free of local modifications unless backporting fixes; prefer upstream contributions when feasible. diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md new file mode 100644 index 000000000..5b772c855 --- /dev/null +++ b/examples/nanochat/README.md @@ -0,0 +1,49 @@ +# Nanochat Integration Workspace + +This workspace hosts Nanochat-focused experiments within Plato. + +## Quick Start + +1. Install dependencies (including the vendored tokenizer build requirements): + + ```bash + uv sync --extra nanochat + uv run --with ./runtime/third_party/nanochat maturin develop --release + ``` + +2. Run the synthetic smoke configuration: + + ```bash + uv run --extra nanochat python plato.py --config configs/Nanochat/synthetic_micro.toml + ``` + + This launches a single-client training round using the Nanochat trainer, synthetic + token streams, and a downsized GPT configuration for CPU debugging. + +## CORE Evaluation + +The Nanochat trainer can invoke the upstream CORE benchmark by adding the section +below to your TOML configuration: + +```toml +[evaluation] +type = "nanochat_core" +max_per_task = 128 # optional; limits evaluation samples per task +# bundle_dir = "/custom/path/to/nanochat" # defaults to ~/.cache/nanochat +``` + +Make sure the official evaluation bundle has been downloaded so the following files +exist (the default location is `~/.cache/nanochat/eval_bundle`): + +- `core.yaml` +- `eval_data/*.jsonl` +- `eval_meta_data.csv` + +The provided `configs/Nanochat/synthetic_micro.toml` can be extended with the +`[evaluation]` block once those assets are present. + +## Roadmap + +- Integrate real Nanochat tokenized datasets and publish download helpers. +- Add baseline evaluation scripts leveraging `nanochat/core_eval.py`. +- Capture reproducible metrics and hardware notes for larger-scale runs. diff --git a/examples/nanochat/pyproject.toml b/examples/nanochat/pyproject.toml new file mode 100644 index 000000000..c0e19f09c --- /dev/null +++ b/examples/nanochat/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "plato-nanochat-examples" +version = "0.1.0" +description = "Nanochat integration examples for Plato." +readme = "README.md" +requires-python = ">=3.10" +dependencies = [] diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py new file mode 100644 index 000000000..28c805943 --- /dev/null +++ b/plato/datasources/nanochat.py @@ -0,0 +1,262 @@ +""" +Streaming datasource backed by the vendored Nanochat project. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +try: + import torch + from torch.utils.data import IterableDataset +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat datasource requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.config import Config +from plato.datasources.base import DataSource as BaseDataSource +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +DEFAULT_VOCAB_SIZE = 50304 +DEFAULT_SEQUENCE_LENGTH = 2048 + + +def _resolve_base_dir(base_dir: str | Path | None) -> Path | None: + if base_dir is None: + return None + return Path(base_dir).expanduser().resolve() + + +def _parquet_available(base_dir: Path | None) -> bool: + try: + ensure_nanochat_importable() + from nanochat.dataset import list_parquet_files, DATA_DIR # type: ignore[attr-defined] + except (ThirdPartyImportError, ImportError): # pragma: no cover - defensive + return False + + if base_dir is not None: + candidate_dir = base_dir / "base_data" + if not candidate_dir.exists(): + return False + parquet_dir = candidate_dir + else: + parquet_dir = Path(DATA_DIR) + + try: + return len(list_parquet_files(str(parquet_dir))) > 0 + except FileNotFoundError: + return False + + +@dataclass +class _SyntheticState: + generator: torch.Generator + + +class NanochatStreamingDataset(IterableDataset): + """Iterable dataset yielding (inputs, targets) token tensors.""" + + def __init__( + self, + *, + split: str, + batch_size: int, + sequence_length: int, + mode: str, + base_dir: Path | None, + max_batches: int | None, + tokenizer_threads: int, + tokenizer_batch_size: int, + device: str, + vocab_size: int, + synthetic_seed: int, + ): + super().__init__() + if split not in {"train", "val"}: + raise ValueError("split must be 'train' or 'val'.") + + if mode not in {"auto", "parquet", "synthetic"}: + raise ValueError("mode must be 'auto', 'parquet', or 'synthetic'.") + + self.split = split + self.batch_size = batch_size + self.sequence_length = sequence_length + self.base_dir = base_dir + self.max_batches = max_batches + self.tokenizer_threads = tokenizer_threads + self.tokenizer_batch_size = tokenizer_batch_size + self.device = device + self.vocab_size = vocab_size + self.synthetic_seed = synthetic_seed + + resolved_mode = mode + if resolved_mode == "auto": + resolved_mode = "parquet" if _parquet_available(base_dir) else "synthetic" + self.mode = resolved_mode + self._synthetic_state: _SyntheticState | None = None + + # Configure Nanochat's base directory if provided. + if self.base_dir is not None: + os.environ.setdefault("NANOCHAT_BASE_DIR", str(self.base_dir)) + + def _synthetic_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + if self._synthetic_state is None: + generator = torch.Generator() + generator.manual_seed(self.synthetic_seed) + self._synthetic_state = _SyntheticState(generator=generator) + + generator = self._synthetic_state.generator + while True: + tokens = torch.randint( + low=0, + high=self.vocab_size, + size=(self.batch_size, self.sequence_length + 1), + dtype=torch.long, + generator=generator, + ) + inputs = tokens[:, :-1].to(dtype=torch.long) + targets = tokens[:, 1:].to(dtype=torch.long) + yield inputs, targets + + def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + ensure_nanochat_importable() + from nanochat.dataloader import ( # type: ignore[attr-defined] + tokenizing_distributed_data_loader, + ) + + loader = tokenizing_distributed_data_loader( + self.batch_size, + self.sequence_length, + split=self.split, + tokenizer_threads=self.tokenizer_threads, + tokenizer_batch_size=self.tokenizer_batch_size, + device=self.device, + ) + for inputs, targets in loader: + yield inputs.to(dtype=torch.long), targets.to(dtype=torch.long) + + def __iter__(self): + iterable = ( + self._parquet_iterable() + if self.mode == "parquet" + else self._synthetic_iterable() + ) + + for batch_index, batch in enumerate(iterable): + if self.max_batches is not None and batch_index >= self.max_batches: + break + yield batch + + def __len__(self) -> int: + if self.max_batches is None: + raise TypeError("Streaming dataset does not have a finite length.") + return self.max_batches + + +class DataSource(BaseDataSource): # type: ignore[misc] + """Plato datasource exposing Nanochat token streams.""" + + def __init__( + self, + *, + batch_size: int | None = None, + sequence_length: int | None = None, + mode: str = "auto", + base_dir: str | Path | None = None, + max_train_batches: int | None = 64, + max_val_batches: int | None = 8, + tokenizer_threads: int = 4, + tokenizer_batch_size: int = 128, + device: str = "cpu", + vocab_size: int = DEFAULT_VOCAB_SIZE, + synthetic_seed: int = 42, + ): + super().__init__() + + cfg_data = getattr(Config(), "data", None) + if cfg_data is not None: + mode = getattr(cfg_data, "mode", mode) + base_dir = getattr(cfg_data, "base_dir", base_dir) + max_train_batches = getattr(cfg_data, "max_train_batches", max_train_batches) + max_val_batches = getattr(cfg_data, "max_val_batches", max_val_batches) + tokenizer_threads = getattr(cfg_data, "tokenizer_threads", tokenizer_threads) + tokenizer_batch_size = getattr( + cfg_data, "tokenizer_batch_size", tokenizer_batch_size + ) + device = getattr(cfg_data, "device", device) + vocab_size = getattr(cfg_data, "vocab_size", vocab_size) + synthetic_seed = getattr(cfg_data, "synthetic_seed", synthetic_seed) + + config = getattr(Config(), "parameters", None) + model_conf = getattr(config, "model", None) + default_seq_len = DEFAULT_SEQUENCE_LENGTH + if model_conf is not None and hasattr(model_conf, "_asdict"): + seq_len_candidate = model_conf._asdict().get("sequence_len") + if isinstance(seq_len_candidate, int) and seq_len_candidate > 0: + default_seq_len = seq_len_candidate + + resolved_sequence_len = sequence_length or default_seq_len + resolved_batch_size = batch_size or getattr( + getattr(Config(), "trainer", None), "batch_size", 1 + ) + + resolved_base_dir = _resolve_base_dir(base_dir) + dataset_mode = mode + if dataset_mode == "auto": + dataset_mode = "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + + self.trainset = NanochatStreamingDataset( + split="train", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_train_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed, + ) + self.testset = NanochatStreamingDataset( + split="val", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_val_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed + 1, + ) + self.sequence_length = resolved_sequence_len + self.batch_size = resolved_batch_size + self.mode = dataset_mode + + @staticmethod + def input_shape(): + """Return the default input shape (sequence length).""" + return (DEFAULT_SEQUENCE_LENGTH,) + + def num_train_examples(self) -> int: + dataset = self.trainset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_train_batches to report size." + ) + return dataset.max_batches * dataset.batch_size + + def num_test_examples(self) -> int: + dataset = self.testset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_val_batches to report size." + ) + return dataset.max_batches * dataset.batch_size diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index e2f101de3..1334c027c 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -11,6 +11,7 @@ feature, femnist, huggingface, + nanochat, lora, purchase, texas, @@ -27,6 +28,7 @@ "Texas": texas, "TinyImageNet": tiny_imagenet, "Feature": feature, + "Nanochat": nanochat, } registered_partitioned_datasources = {"FEMNIST": femnist} diff --git a/plato/evaluators/__init__.py b/plato/evaluators/__init__.py new file mode 100644 index 000000000..62d84e994 --- /dev/null +++ b/plato/evaluators/__init__.py @@ -0,0 +1,3 @@ +"""Evaluation helpers for Plato integrations.""" + +# Intentionally empty for now; modules register themselves when imported. diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py new file mode 100644 index 000000000..74a7c1447 --- /dev/null +++ b/plato/evaluators/nanochat_core.py @@ -0,0 +1,217 @@ +""" +Adapter utilities to run Nanochat's CORE evaluation benchmark within Plato. +""" + +from __future__ import annotations + +import csv +import json +import logging +import random +import time +from pathlib import Path +from typing import Any + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat CORE evaluation requires PyTorch. " + "Install the `nanochat` extra (includes torch)." + ) from exc + +import yaml + +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +LOGGER = logging.getLogger(__name__) + + +def _resolve_bundle_paths(bundle_dir: str | Path | None) -> tuple[Path, Path, Path]: + """Resolve the configuration, metadata, and dataset paths for CORE evaluation.""" + ensure_nanochat_importable() + from nanochat.common import get_base_dir # pylint: disable=import-error + + if bundle_dir is None: + base_path = Path(get_base_dir()) + else: + base_path = Path(bundle_dir).expanduser().resolve() + + eval_bundle_dir = base_path / "eval_bundle" + config_path = eval_bundle_dir / "core.yaml" + data_dir = eval_bundle_dir / "eval_data" + metadata_path = eval_bundle_dir / "eval_meta_data.csv" + + if not config_path.exists(): + raise FileNotFoundError( + f"CORE evaluation config not found at {config_path}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not data_dir.exists(): + raise FileNotFoundError( + f"CORE evaluation data directory not found at {data_dir}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not metadata_path.exists(): + raise FileNotFoundError( + f"CORE evaluation metadata CSV not found at {metadata_path}." + ) + + return config_path, data_dir, metadata_path + + +def _load_core_tasks(config_path: Path) -> list[dict[str, Any]]: + """Load task definitions from the CORE YAML config.""" + with config_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) + tasks = config.get("icl_tasks", []) + if not isinstance(tasks, list) or not tasks: + raise ValueError( + f"No CORE tasks defined in {config_path}. Inspect the eval bundle." + ) + return tasks + + +def _load_metadata(metadata_path: Path) -> dict[str, float]: + """Load random baseline metadata for centering accuracy.""" + baseline_map: dict[str, float] = {} + with metadata_path.open("r", encoding="utf-8") as handle: + reader = csv.DictReader(handle) + for row in reader: + label = row.get("Eval Task") + baseline = row.get("Random baseline") + if label is None or baseline is None: + continue + try: + baseline_map[label] = float(baseline) + except ValueError: + LOGGER.debug("Skipping malformed baseline row: %s", row) + if not baseline_map: + raise ValueError( + f"Random baselines missing in {metadata_path}. Required for CORE metric." + ) + return baseline_map + + +def _load_task_data(data_dir: Path, dataset_uri: str) -> list[dict[str, Any]]: + """Load task dataset rows from newline-delimited JSON.""" + path = data_dir / dataset_uri + if not path.exists(): + raise FileNotFoundError( + f"CORE dataset shard '{dataset_uri}' missing under {data_dir}." + ) + with path.open("r", encoding="utf-8") as handle: + return [json.loads(line.strip()) for line in handle if line.strip()] + + +def _resolve_tokenizer(model) -> Any: + """Obtain a tokenizer compatible with Nanochat core evaluation.""" + tokenizer = getattr(model, "nanochat_tokenizer", None) + if tokenizer is not None: + return tokenizer + + ensure_nanochat_importable() + from nanochat.tokenizer import get_tokenizer # pylint: disable=import-error + + return get_tokenizer() + + +def run_core_evaluation( + model: torch.nn.Module, + *, + tokenizer: Any | None = None, + bundle_dir: str | Path | None = None, + max_per_task: int = -1, + device: torch.device | str | None = None, +) -> dict[str, Any]: + """ + Execute the CORE benchmark for the provided model. + + Args: + model: Nanochat-style autoregressive model. + tokenizer: Optional tokenizer; falls back to nanochat.tokenizer.get_tokenizer(). + bundle_dir: Optional base directory containing `eval_bundle/`. + max_per_task: Optional cap on examples per task for quicker smoke tests (-1 = all). + device: Device to run evaluation on. Defaults to the model's current device. + + Returns: + Dictionary with `results`, `centered_results`, and `core_metric`. + """ + ensure_nanochat_importable() + from nanochat.core_eval import evaluate_task # pylint: disable=import-error + + config_path, data_dir, metadata_path = _resolve_bundle_paths(bundle_dir) + tasks = _load_core_tasks(config_path) + baselines = _load_metadata(metadata_path) + + eval_tokenizer = tokenizer or _resolve_tokenizer(model) + if eval_tokenizer is None: + raise RuntimeError( + "Nanochat CORE evaluation requires a tokenizer. " + "Either attach `model.nanochat_tokenizer` or provide one explicitly." + ) + + if device is None: + try: + first_param = next(model.parameters()) + device = first_param.device + except StopIteration: + device = torch.device("cpu") + if isinstance(device, str): + device = torch.device(device) + + model_device = device + model_was_training = model.training + model = model.to(model_device) + model.eval() + + results: dict[str, float] = {} + centered_results: dict[str, float] = {} + + for task in tasks: + label = task.get("label") + if not label: + LOGGER.debug("Skipping unnamed CORE task entry: %s", task) + continue + + task_meta = { + "task_type": task.get("icl_task_type"), + "dataset_uri": task.get("dataset_uri"), + "num_fewshot": task.get("num_fewshot", [0])[0], + "continuation_delimiter": task.get("continuation_delimiter", " "), + } + start_time = time.perf_counter() + + data = _load_task_data(data_dir, task_meta["dataset_uri"]) + shuffle_rng = random.Random(1337) + shuffle_rng.shuffle(data) + if max_per_task > 0: + data = data[:max_per_task] + + accuracy = evaluate_task(model, eval_tokenizer, data, model_device, task_meta) + baseline = baselines.get(label, 0.0) + centered = (accuracy - 0.01 * baseline) / (1.0 - 0.01 * baseline) + + results[label] = accuracy + centered_results[label] = centered + elapsed = time.perf_counter() - start_time + LOGGER.info( + "CORE task %s | accuracy %.4f | centered %.4f | %.2fs", + label, + accuracy, + centered, + elapsed, + ) + + if model_was_training: + model.train() + + if not centered_results: + raise RuntimeError("No CORE tasks were evaluated; check the eval bundle.") + + core_metric = sum(centered_results.values()) / len(centered_results) + return { + "results": results, + "centered_results": centered_results, + "core_metric": core_metric, + } diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py new file mode 100644 index 000000000..d2a4b5dc2 --- /dev/null +++ b/plato/models/nanochat.py @@ -0,0 +1,135 @@ +""" +Factory for Nanochat GPT models integrated with Plato's registry. +""" + +from __future__ import annotations + +from dataclasses import fields +from typing import Any + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat model integration requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + + +DEFAULT_MODEL_CONFIG: dict[str, int] = { + "sequence_len": 2048, + "vocab_size": 50304, + "n_layer": 12, + "n_head": 6, + "n_kv_head": 6, + "n_embd": 768, +} + + +def _import_nanochat_modules(): + ensure_nanochat_importable() + from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] + from nanochat.checkpoint_manager import ( # type: ignore[attr-defined] + load_model_from_dir, + ) + + return GPT, GPTConfig, load_model_from_dir + + +def _sanitize_kwargs(kwargs: dict[str, Any], valid_fields: set[str]) -> dict[str, Any]: + """Filter kwargs to those accepted by GPTConfig.""" + config_kwargs = DEFAULT_MODEL_CONFIG.copy() + for key, value in kwargs.items(): + if key in valid_fields: + config_kwargs[key] = value + return config_kwargs + + +def _load_from_checkpoint( + load_dir: str, + *, + device: str | torch.device = "cpu", + phase: str = "train", + model_tag: str | None = None, + step: int | None = None, +): + """Load a Nanochat checkpoint via checkpoint_manager.""" + GPT, _, load_model_from_dir = _import_nanochat_modules() + torch_device = torch.device(device) + model, tokenizer, metadata = load_model_from_dir( + load_dir, + device=torch_device, + phase=phase, + model_tag=model_tag, + step=step, + ) + # Attach helpful metadata to the model for downstream use. + setattr(model, "nanochat_tokenizer", tokenizer) + setattr(model, "nanochat_metadata", metadata) + if not isinstance(model, GPT): + raise TypeError( + "Checkpoint loader returned an unexpected model type. " + "Ensure the checkpoint directory points to Nanochat artifacts." + ) + return model + + +class Model: + """Nanochat GPT factory compatible with Plato's model registry.""" + + @staticmethod + def get(model_name: str | None = None, **kwargs: Any): + """ + Instantiate a Nanochat GPT model. + + Keyword Args: + sequence_len: Context length (tokens). + vocab_size: Token vocabulary size. + n_layer: Number of transformer blocks. + n_head: Attention heads for queries. + n_kv_head: Attention heads for keys/values (MQA/GQA). + n_embd: Hidden dimension width. + init_weights: Whether to run Nanochat's weight initialisation (default True). + load_checkpoint_dir: Optional checkpoint directory produced by Nanochat. + load_checkpoint_tag: Optional subdirectory/model tag within checkpoint dir. + load_checkpoint_step: Optional numeric step to load (defaults to latest). + device: Torch device string for checkpoint loading. + phase: "train" or "eval" when loading checkpoints. + """ + try: + GPT, GPTConfig, _ = _import_nanochat_modules() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat vendor tree not found. " + "Ensure runtime/third_party/nanochat is available." + ) from exc + + init_weights = kwargs.pop("init_weights", True) + load_dir = kwargs.pop("load_checkpoint_dir", None) + checkpoint_tag = kwargs.pop("load_checkpoint_tag", None) + checkpoint_step = kwargs.pop("load_checkpoint_step", None) + checkpoint_phase = kwargs.pop("phase", "train") + checkpoint_device = kwargs.pop("device", "cpu") + + # GPTConfig only accepts specific fields; filter unknown kwargs. + config_fields = {field.name for field in fields(GPTConfig)} + config_kwargs = _sanitize_kwargs(kwargs, config_fields) + + if load_dir: + model = _load_from_checkpoint( + load_dir, + device=checkpoint_device, + phase=checkpoint_phase, + model_tag=checkpoint_tag, + step=checkpoint_step, + ) + return model + + config = GPTConfig(**config_kwargs) + model = GPT(config) + if init_weights: + model.init_weights() + setattr(model, "nanochat_config", config_kwargs) + return model diff --git a/plato/models/registry.py b/plato/models/registry.py index e6691e3a0..4ce573e94 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -14,6 +14,7 @@ general_multilayer, huggingface, lenet5, + nanochat, multilayer, resnet, torch_hub, @@ -40,6 +41,7 @@ "torch_hub": torch_hub.Model, "huggingface": huggingface.Model, "vit": vit.Model, + "nanochat": nanochat.Model, } registered_mlx_models = {} diff --git a/plato/processors/nanochat_tokenizer.py b/plato/processors/nanochat_tokenizer.py new file mode 100644 index 000000000..e67565f5f --- /dev/null +++ b/plato/processors/nanochat_tokenizer.py @@ -0,0 +1,133 @@ +""" +Prototype tokenizer processor for Nanochat, wrapping the rustbpe+tiktoken stack. + +This module exercises the build tooling for the Rust extension while providing +an adapter that conforms to Plato's processor interface. +""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +import pickle +from pathlib import Path +from typing import Any + +from plato.processors.base import Processor + +SPECIAL_TOKENS = [ + "<|bos|>", + "<|user_start|>", + "<|user_end|>", + "<|assistant_start|>", + "<|assistant_end|>", + "<|python_start|>", + "<|python_end|>", + "<|output_start|>", + "<|output_end|>", +] + +DEFAULT_PATTERN = ( + r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]" + r"|\s+(?!\S)|\s+" +) + + +class NanochatTokenizerProcessor(Processor): + """Prototype tokenizer that can either load a saved encoding or train via rustbpe.""" + + def __init__( + self, + tokenizer_path: str | Path | None = None, + train_corpus: Iterable[str] | None = None, + vocab_size: int = 32000, + pattern: str = DEFAULT_PATTERN, + bos_token: str = "<|bos|>", + prepend_bos: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.tokenizer_path = Path(tokenizer_path) if tokenizer_path else None + self.train_corpus = train_corpus + self.vocab_size = max(vocab_size, 256 + len(SPECIAL_TOKENS)) + self.pattern = pattern + self.bos_token = bos_token + self.prepend_bos = prepend_bos + + self._encoding = self._load_encoding() + self._special_tokens = self._infer_special_tokens() + self._bos_token_id = self._special_tokens.get(self.bos_token) + + def _load_encoding(self): + if self.tokenizer_path: + return self._load_from_pickle(self.tokenizer_path) + if self.train_corpus is not None: + return self._train_from_corpus(self.train_corpus) + raise ValueError("Either tokenizer_path or train_corpus must be provided.") + + def _load_from_pickle(self, path: Path): + with path.open("rb") as handle: + encoding = pickle.load(handle) + return encoding + + def _train_from_corpus(self, corpus: Iterable[str]): + try: + import rustbpe # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "rustbpe extension is required to train a Nanochat tokenizer." + ) from exc + + try: + import tiktoken # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "tiktoken is required to construct Nanochat tokenizer encodings." + ) from exc + + tokenizer = rustbpe.Tokenizer() + tokenizer.train_from_iterator( + iter(corpus), + self.vocab_size - len(SPECIAL_TOKENS), + pattern=self.pattern, + ) + + mergeable_ranks = { + bytes(piece): rank for piece, rank in tokenizer.get_mergeable_ranks() + } + tokens_offset = len(mergeable_ranks) + special_tokens = { + token: tokens_offset + index for index, token in enumerate(SPECIAL_TOKENS) + } + + return tiktoken.Encoding( + name="nanochat-rustbpe", + pat_str=tokenizer.get_pattern(), + mergeable_ranks=mergeable_ranks, + special_tokens=special_tokens, + ) + + def _infer_special_tokens(self): + try: + encode_single_token = self._encoding.encode_single_token + special_token_set = getattr(self._encoding, "special_tokens_set", set()) + except AttributeError as exc: + raise RuntimeError("tiktoken encoding missing expected interfaces.") from exc + + mapping = {} + for token in SPECIAL_TOKENS: + if token in special_token_set: + mapping[token] = encode_single_token(token) + return mapping + + def _encode_one(self, text: str): + ids = list(self._encoding.encode_ordinary(text)) + if self.prepend_bos and self._bos_token_id is not None: + ids.insert(0, self._bos_token_id) + return ids + + def process(self, data: Any): + if isinstance(data, str): + return self._encode_one(data) + if isinstance(data, Sequence): + return [self._encode_one(item) for item in data] + raise TypeError(f"Unsupported payload type: {type(data)!r}") diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py new file mode 100644 index 000000000..74555ed22 --- /dev/null +++ b/plato/trainers/nanochat.py @@ -0,0 +1,285 @@ +""" +Trainer wiring for Nanochat models within the composable trainer framework. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Iterator, List, Sequence + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat trainer requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.config import Config +from plato.datasources.nanochat import NanochatStreamingDataset +from plato.evaluators.nanochat_core import run_core_evaluation +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies.base import ( + DataLoaderStrategy, + OptimizerStrategy, + TestingStrategy, + TrainingContext, + TrainingStepStrategy, +) +from plato.trainers.strategies.data_loader import DefaultDataLoaderStrategy +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + + +def _first_element_collate(batch: Sequence[Any]) -> Any: + """Return the first (and only) element from a DataLoader batch.""" + if not batch: + raise ValueError("Received empty batch from Nanochat dataset.") + return batch[0] + + +class NanochatDataLoaderStrategy(DataLoaderStrategy): + """Use identity collation for pre-batched Nanochat streaming datasets.""" + + def __init__(self): + self._fallback = DefaultDataLoaderStrategy() + + def create_train_loader( + self, trainset, sampler, batch_size: int, context: TrainingContext + ) -> torch.utils.data.DataLoader | Iterator: + if isinstance(trainset, NanochatStreamingDataset): + return torch.utils.data.DataLoader( + trainset, + batch_size=1, + shuffle=False, + sampler=None, + num_workers=0, + collate_fn=_first_element_collate, + ) + return self._fallback.create_train_loader(trainset, sampler, batch_size, context) + + +class NanochatTrainingStepStrategy(TrainingStepStrategy): + """Call Nanochat's integrated loss computation during training.""" + + def __init__(self, loss_reduction: str = "mean"): + self.loss_reduction = loss_reduction + + def training_step( + self, + model: torch.nn.Module, + optimizer, + examples: torch.Tensor, + labels: torch.Tensor, + loss_criterion, + context: TrainingContext, + ) -> torch.Tensor: + optimizer.zero_grad() + if labels is not None: + loss = model(examples, targets=labels, loss_reduction=self.loss_reduction) + else: + outputs = model(examples) + loss = loss_criterion(outputs, labels) + + if not isinstance(loss, torch.Tensor): + raise TypeError( + "Nanochat model forward pass must return a torch.Tensor loss." + ) + + loss.backward() + optimizer.step() + context.state["optimizer_step_completed"] = True + return loss.detach() + + +@dataclass +class _OptimizerBundle: + """Bundle multiple optimizers under a single interface.""" + + optimizers: List[torch.optim.Optimizer] + + def zero_grad(self) -> None: + for optimizer in self.optimizers: + optimizer.zero_grad() + + def step(self) -> None: + for optimizer in self.optimizers: + optimizer.step() + + def state_dict(self) -> dict[str, Any]: + return { + "optimizers": [optimizer.state_dict() for optimizer in self.optimizers], + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + for optimizer, payload in zip( + self.optimizers, state_dict.get("optimizers", []), strict=False # type: ignore[arg-type] + ): + optimizer.load_state_dict(payload) + + @property + def param_groups(self) -> list[dict[str, Any]]: + groups: list[dict[str, Any]] = [] + for optimizer in self.optimizers: + groups.extend(getattr(optimizer, "param_groups", [])) + return groups + + def params_state_update(self) -> None: + for optimizer in self.optimizers: + hook = getattr(optimizer, "params_state_update", None) + if callable(hook): + hook() + + +class NanochatOptimizerStrategy(OptimizerStrategy): + """Adapter around nanochat.gpt.GPT.setup_optimizers.""" + + def __init__( + self, + *, + unembedding_lr: float = 0.004, + embedding_lr: float = 0.2, + matrix_lr: float = 0.02, + weight_decay: float = 0.0, + ): + self.unembedding_lr = unembedding_lr + self.embedding_lr = embedding_lr + self.matrix_lr = matrix_lr + self.weight_decay = weight_decay + + def create_optimizer( + self, model: torch.nn.Module, context: TrainingContext + ) -> _OptimizerBundle: + if not hasattr(model, "setup_optimizers"): + raise AttributeError( + "Nanochat model is expected to expose setup_optimizers()." + ) + + optimizers = model.setup_optimizers( + unembedding_lr=self.unembedding_lr, + embedding_lr=self.embedding_lr, + matrix_lr=self.matrix_lr, + weight_decay=self.weight_decay, + ) + if not isinstance(optimizers, Iterable): + raise TypeError("setup_optimizers() must return an iterable of optimizers.") + + optimizer_list: list[torch.optim.Optimizer] = list(optimizers) + if not optimizer_list: + raise ValueError("setup_optimizers() returned an empty optimizer list.") + return _OptimizerBundle(optimizer_list) + + +class NanochatTestingStrategy(TestingStrategy): + """Compute average token loss over the validation iterator.""" + + def __init__(self, reduction: str = "sum"): + self.reduction = reduction + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + if not isinstance(testset, NanochatStreamingDataset): + raise TypeError( + "NanochatTestingStrategy expects a NanochatStreamingDataset instance." + ) + + model.eval() + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for inputs, targets in testset: + inputs = inputs.to(context.device) + targets = targets.to(context.device) + loss = model(inputs, targets=targets, loss_reduction=self.reduction) + total_loss += float(loss.item()) + total_tokens += targets.numel() + + model.train() + if total_tokens == 0: + return float("nan") + return total_loss / total_tokens + + +class NanochatCoreTestingStrategy(TestingStrategy): + """Evaluate the CORE benchmark and return the aggregate metric.""" + + def __init__(self, bundle_dir: str | None = None, max_per_task: int = -1): + self.bundle_dir = bundle_dir + self.max_per_task = max_per_task + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + device = context.device or next(model.parameters()).device + tokenizer = getattr(model, "nanochat_tokenizer", None) + results = run_core_evaluation( + model, + tokenizer=tokenizer, + bundle_dir=self.bundle_dir, + max_per_task=self.max_per_task, + device=device, + ) + context.state["nanochat_core_results"] = results + return float(results["core_metric"]) + + +class Trainer(ComposableTrainer): + """Composable trainer specialised for Nanochat workloads.""" + + def __init__( + self, + model=None, + callbacks=None, + *, + optimizer_params: dict[str, Any] | None = None, + loss_reduction: str = "mean", + ): + try: + ensure_nanochat_importable() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat trainer requires the vendored runtime/third_party/nanochat project." + ) from exc + + optimizer_strategy = NanochatOptimizerStrategy( + **(optimizer_params or {}), + ) + training_step_strategy = NanochatTrainingStepStrategy( + loss_reduction=loss_reduction + ) + data_loader_strategy = NanochatDataLoaderStrategy() + + evaluation_cfg = getattr(Config(), "evaluation", None) + evaluation_type = getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + if evaluation_type == "nanochat_core": + max_per_task = getattr(evaluation_cfg, "max_per_task", -1) + max_per_task_value = -1 if max_per_task is None else int(max_per_task) + testing_strategy = NanochatCoreTestingStrategy( + bundle_dir=getattr(evaluation_cfg, "bundle_dir", None), + max_per_task=max_per_task_value, + ) + else: + testing_strategy = NanochatTestingStrategy() + + super().__init__( + model=model, + callbacks=callbacks, + loss_strategy=None, + optimizer_strategy=optimizer_strategy, + training_step_strategy=training_step_strategy, + lr_scheduler_strategy=None, + model_update_strategy=None, + data_loader_strategy=data_loader_strategy, + testing_strategy=testing_strategy, + ) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index ff0db0cb9..977d04fd2 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -10,6 +10,7 @@ basic, composable, gan, + nanochat as nanochat_trainer, split_learning, ) @@ -19,6 +20,7 @@ "timm_basic": basic.TrainerWithTimmScheduler, "gan": gan.Trainer, "split_learning": split_learning.Trainer, + "nanochat": nanochat_trainer.Trainer, } diff --git a/plato/utils/third_party.py b/plato/utils/third_party.py new file mode 100644 index 000000000..9296f86fa --- /dev/null +++ b/plato/utils/third_party.py @@ -0,0 +1,42 @@ +""" +Helpers for accessing vendored third-party projects. +""" + +from __future__ import annotations + +import sys +from functools import lru_cache +from pathlib import Path + + +class ThirdPartyImportError(ImportError): + """Raised when a vendored third-party project is unavailable.""" + + +@lru_cache(maxsize=None) +def _nanochat_root() -> Path: + """Return the root directory of the vendored Nanochat project.""" + repo_root = Path(__file__).resolve().parents[2] + nanochat_root = repo_root / "runtime" / "third_party" / "nanochat" + if not nanochat_root.exists(): + raise ThirdPartyImportError( + "Nanochat is not vendored under runtime/third_party/nanochat." + ) + return nanochat_root + + +def ensure_nanochat_importable() -> Path: + """ + Ensure the vendored Nanochat package is importable. + + Returns: + Path to the Nanochat project root. + + Raises: + ThirdPartyImportError: If the vendored Nanochat tree is missing. + """ + nanochat_root = _nanochat_root() + path_str = str(nanochat_root) + if path_str not in sys.path: + sys.path.insert(0, path_str) + return nanochat_root diff --git a/pyproject.toml b/pyproject.toml index 099d06475..e12d65cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,18 @@ dp = [ mlx = [ "mlx", ] +nanochat = [ + "datasets>=2.14.0", + "numpy==1.26.4", + "psutil>=5.9.0", + "regex>=2023.8.8", + "tiktoken>=0.5.0", + "tokenizers>=0.15.0", + "torch>=2.2.0", + "wandb>=0.16.0", + "jinja2>=3.0", + "PyYAML>=6.0", +] [project.urls] Homepage = "https://github.com/TL-System/plato" @@ -89,6 +101,7 @@ members = [ "examples/detector", "examples/gradient_leakage_attacks", "tools", + "examples/nanochat", ] [dependency-groups] diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py new file mode 100644 index 000000000..4a60cae36 --- /dev/null +++ b/tests/test_nanochat_integration.py @@ -0,0 +1,152 @@ +"""Nanochat integration smoke checks (optional, require nanochat extras).""" + +from __future__ import annotations + +import importlib.util +import os + +import pytest + +from plato.config import Config, ConfigNode + + +pytestmark = pytest.mark.integration + +_RUSTBPE_AVAILABLE = importlib.util.find_spec("rustbpe") is not None + + +@pytest.mark.skipif( + not _RUSTBPE_AVAILABLE, + reason="Nanochat tokenizer tests require rustbpe extension (install nanochat extras).", +) +def test_nanochat_tokenizer_processor_round_trip(tmp_path): + """Train a tiny tokenizer via rustbpe and encode a sample string.""" + pytest.importorskip( + "tiktoken", reason="Nanochat tokenizer tests require tiktoken (nanochat extra)." + ) + from plato.processors.nanochat_tokenizer import NanochatTokenizerProcessor + + corpus = ["hello nanochat", "minimal tokenizer exercise"] + processor = NanochatTokenizerProcessor( + train_corpus=corpus, + vocab_size=300, + prepend_bos=False, + ) + encoded = processor.process("hello nanochat") + assert isinstance(encoded, list) + assert len(encoded) > 0 + + +def test_nanochat_trainer_smoke(temp_config, tmp_path): + """Run one training step with synthetic Nanochat data on CPU.""" + _ = pytest.importorskip( + "torch", reason="Nanochat trainer smoke requires torch (nanochat extra)." + ) + from plato.datasources.nanochat import DataSource as NanochatDataSource + from plato.models.nanochat import Model as NanochatModel + from plato.trainers.nanochat import Trainer as NanochatTrainer + + cfg = Config() + cfg.trainer.type = "nanochat" + cfg.trainer.model_name = "nanochat_smoke" + cfg.trainer.batch_size = 2 + cfg.trainer.rounds = 1 + cfg.trainer.epochs = 1 + + cfg.parameters.model = { + "sequence_len": 16, + "vocab_size": 512, + "n_layer": 1, + "n_head": 2, + "n_kv_head": 2, + "n_embd": 128, + } + + datasource = NanochatDataSource( + batch_size=cfg.trainer.batch_size, + sequence_length=cfg.parameters.model["sequence_len"], + mode="synthetic", + max_train_batches=2, + max_val_batches=1, + device="cpu", + vocab_size=cfg.parameters.model["vocab_size"], + synthetic_seed=123, + ) + + model = NanochatModel.get( + sequence_len=cfg.parameters.model["sequence_len"], + vocab_size=cfg.parameters.model["vocab_size"], + n_layer=cfg.parameters.model["n_layer"], + n_head=cfg.parameters.model["n_head"], + n_kv_head=cfg.parameters.model["n_kv_head"], + n_embd=cfg.parameters.model["n_embd"], + init_weights=True, + ) + + trainer = NanochatTrainer(model=model) + trainset = datasource.get_train_set() + elapsed = trainer.train(trainset, sampler=None) + + assert isinstance(elapsed, float) + assert elapsed >= 0.0 + + model_dir = Config().params["model_path"] + checkpoint_name = ( + f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" + ) + assert os.path.exists(os.path.join(model_dir, checkpoint_name)) + + +def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): + """Ensure evaluation config triggers the CORE testing strategy.""" + _ = pytest.importorskip( + "torch", reason="Nanochat trainer requires torch (nanochat extra)." + ) + from plato.models.nanochat import Model as NanochatModel + from plato.trainers.nanochat import ( + NanochatCoreTestingStrategy, + Trainer as NanochatTrainer, + ) + + monkeypatch.setattr( + "plato.evaluators.nanochat_core.run_core_evaluation", + lambda *args, **kwargs: { + "results": {}, + "centered_results": {}, + "core_metric": 0.0, + }, + ) + + cfg = Config() + cfg.trainer.type = "nanochat" + cfg.trainer.model_name = "nanochat_core" + cfg.trainer.batch_size = 1 + cfg.trainer.rounds = 1 + cfg.trainer.epochs = 1 + cfg.parameters.model = { + "sequence_len": 16, + "vocab_size": 512, + "n_layer": 1, + "n_head": 2, + "n_kv_head": 2, + "n_embd": 128, + } + cfg.evaluation = ConfigNode.from_object( + { + "type": "nanochat_core", + "max_per_task": 1, + } + ) + + model = NanochatModel.get( + sequence_len=cfg.parameters.model["sequence_len"], + vocab_size=cfg.parameters.model["vocab_size"], + n_layer=cfg.parameters.model["n_layer"], + n_head=cfg.parameters.model["n_head"], + n_kv_head=cfg.parameters.model["n_kv_head"], + n_embd=cfg.parameters.model["n_embd"], + init_weights=True, + ) + + trainer = NanochatTrainer(model=model) + assert isinstance(trainer.testing_strategy, NanochatCoreTestingStrategy) From 822a1cb35c1c413e173ebf72b336b7aaa2088d70 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 28 Oct 2025 11:03:40 -0400 Subject: [PATCH 02/33] Added support for vendoring the external Nanochat repo as a git submodule under 'external/nanochat'. --- .gitmodules | 3 +++ docs/nanochat_integration_checklist.md | 8 ++++---- docs/third_party.md | 19 +++++++++---------- examples/nanochat/README.md | 2 +- external/nanochat | 1 + plato/models/nanochat.py | 4 ++-- plato/trainers/nanochat.py | 3 ++- plato/utils/third_party.py | 6 +++--- 8 files changed, 25 insertions(+), 21 deletions(-) create mode 160000 external/nanochat diff --git a/.gitmodules b/.gitmodules index 377fc17b7..1176795af 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "plato/models/t2tvit"] path = plato/models/t2tvit url = https://github.com/yitu-opensource/T2T-ViT +[submodule "external/nanochat"] + path = external/nanochat + url = https://github.com/karpathy/nanochat.git diff --git a/docs/nanochat_integration_checklist.md b/docs/nanochat_integration_checklist.md index 5fa411526..f0b7bf9b6 100644 --- a/docs/nanochat_integration_checklist.md +++ b/docs/nanochat_integration_checklist.md @@ -2,22 +2,22 @@ This checklist coordinates the work to incorporate the Nanochat stack into Plato. Owners are placeholder roles until specific engineers are assigned. -## Third-Party Snapshot +## Third-Party Submodule - **Owner:** Infrastructure -- **Deliverables:** Vendor `runtime/third_party/nanochat` snapshot with commit hash in docs; add provenance blurb to `docs/third_party.md`. +- **Deliverables:** Maintain the `external/nanochat` git submodule; document update procedure in `docs/third_party.md`. - **Dependencies:** None. ## Model Registry - **Owner:** Modeling - **Deliverables:** Implement `plato/models/nanochat.py` mirroring Nanochat GPT config, register entry in `plato/models/registry.py`, supply weight-loading utilities. - **Status:** In progress – factory module and registry wiring landed. -- **Dependencies:** Third-party snapshot. +- **Dependencies:** Third-party submodule. ## Tokenizer & Processor - **Owner:** Infrastructure - **Deliverables:** Package Rust BPE via Maturin optional extra; wrap as `plato/processors/nanochat_tokenizer.py` with lazy import and fallbacks; document build steps in README. - **Status:** Prototype processor and optional dependency group landed; CI build integration remains TODO. -- **Dependencies:** Third-party snapshot, build tooling prototype. +- **Dependencies:** Third-party submodule, build tooling prototype. ## Datasource - **Owner:** Data diff --git a/docs/third_party.md b/docs/third_party.md index 304d62b4e..d481b6829 100644 --- a/docs/third_party.md +++ b/docs/third_party.md @@ -4,16 +4,15 @@ This page records external projects that are vendored into the Plato repository ## Nanochat - **Upstream:** [karpathy/nanochat](https://github.com/karpathy/nanochat) -- **Vendored location:** `runtime/third_party/nanochat` -- **Snapshot commit:** `c75fe54aa7c1fa881701c246f9427bcbe4eee5a4` (captured 2025-03-04) -- **License:** MIT (included in `runtime/third_party/nanochat/LICENSE`) +- **Location:** `external/nanochat` (git submodule) +- **License:** MIT (included in `external/nanochat/LICENSE`) -### Updating the Snapshot -1. `cd runtime/third_party/nanochat` -2. `git fetch origin && git checkout ` -3. Review upstream changes and confirm compatibility with Plato. -4. Record the new commit hash and date in this document, and call out notable changes in the integration checklist before landing. +### Updating the Submodule +1. `git submodule update --remote external/nanochat` +2. Inspect upstream changes for compatibility with Plato. +3. Commit the submodule pointer update and note any required integration work in the checklist. ### Notes -- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./runtime/third_party/nanochat maturin develop --release` succeeds before pushing updates. -- Keep the vendored tree free of local modifications unless backporting fixes; prefer upstream contributions when feasible. +- After cloning Plato, run `git submodule update --init --recursive` to populate all external dependencies. +- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./external/nanochat maturin develop --release` succeeds before pushing updates. +- Avoid local modifications inside the submodule; contribute fixes upstream when possible. diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md index 5b772c855..16ae42fc5 100644 --- a/examples/nanochat/README.md +++ b/examples/nanochat/README.md @@ -8,7 +8,7 @@ This workspace hosts Nanochat-focused experiments within Plato. ```bash uv sync --extra nanochat - uv run --with ./runtime/third_party/nanochat maturin develop --release + uv run --with ./external/nanochat maturin develop --release ``` 2. Run the synthetic smoke configuration: diff --git a/external/nanochat b/external/nanochat new file mode 160000 index 000000000..c75fe54aa --- /dev/null +++ b/external/nanochat @@ -0,0 +1 @@ +Subproject commit c75fe54aa7c1fa881701c246f9427bcbe4eee5a4 diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index d2a4b5dc2..290111353 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -102,8 +102,8 @@ def get(model_name: str | None = None, **kwargs: Any): GPT, GPTConfig, _ = _import_nanochat_modules() except ThirdPartyImportError as exc: # pragma: no cover - defensive branch raise ImportError( - "Nanochat vendor tree not found. " - "Ensure runtime/third_party/nanochat is available." + "Nanochat submodule not found. " + "Run `git submodule update --init --recursive` to populate external/nanochat." ) from exc init_weights = kwargs.pop("init_weights", True) diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py index 74555ed22..5415cb6bb 100644 --- a/plato/trainers/nanochat.py +++ b/plato/trainers/nanochat.py @@ -249,7 +249,8 @@ def __init__( ensure_nanochat_importable() except ThirdPartyImportError as exc: # pragma: no cover - defensive branch raise ImportError( - "Nanochat trainer requires the vendored runtime/third_party/nanochat project." + "Nanochat trainer requires the external/nanochat submodule. " + "Run `git submodule update --init --recursive`." ) from exc optimizer_strategy = NanochatOptimizerStrategy( diff --git a/plato/utils/third_party.py b/plato/utils/third_party.py index 9296f86fa..890764e13 100644 --- a/plato/utils/third_party.py +++ b/plato/utils/third_party.py @@ -15,12 +15,12 @@ class ThirdPartyImportError(ImportError): @lru_cache(maxsize=None) def _nanochat_root() -> Path: - """Return the root directory of the vendored Nanochat project.""" + """Return the root directory of the Nanochat submodule.""" repo_root = Path(__file__).resolve().parents[2] - nanochat_root = repo_root / "runtime" / "third_party" / "nanochat" + nanochat_root = repo_root / "external" / "nanochat" if not nanochat_root.exists(): raise ThirdPartyImportError( - "Nanochat is not vendored under runtime/third_party/nanochat." + "Nanochat submodule missing. Run `git submodule update --init --recursive`." ) return nanochat_root From e039b2c73a9d4b76dfae404ef3dae765fd8b8cbf Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 28 Oct 2025 11:07:41 -0400 Subject: [PATCH 03/33] ruff check --fix & ruff format. --- .../fedunlearning/fedunlearning_server.py | 4 +--- plato/datasources/nanochat.py | 17 +++++++++++++---- plato/datasources/registry.py | 2 +- plato/models/nanochat.py | 3 +-- plato/models/registry.py | 2 +- plato/processors/nanochat_tokenizer.py | 6 ++++-- plato/trainers/nanochat.py | 12 +++++++++--- plato/trainers/registry.py | 4 +++- tests/test_nanochat_integration.py | 7 +++---- 9 files changed, 36 insertions(+), 21 deletions(-) diff --git a/examples/unlearning/fedunlearning/fedunlearning_server.py b/examples/unlearning/fedunlearning/fedunlearning_server.py index 8938b6549..6d6ade576 100644 --- a/examples/unlearning/fedunlearning/fedunlearning_server.py +++ b/examples/unlearning/fedunlearning/fedunlearning_server.py @@ -43,9 +43,7 @@ async def aggregate_deltas(self, updates, deltas_received, context): if not filtered_pairs: if self._fallback_to_original: - return await super().aggregate_deltas( - updates, deltas_received, context - ) + return await super().aggregate_deltas(updates, deltas_received, context) zero_delta = self._zero_delta( context, deltas_received[0] if deltas_received else None diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py index 28c805943..3a0f9201a 100644 --- a/plato/datasources/nanochat.py +++ b/plato/datasources/nanochat.py @@ -35,7 +35,10 @@ def _resolve_base_dir(base_dir: str | Path | None) -> Path | None: def _parquet_available(base_dir: Path | None) -> bool: try: ensure_nanochat_importable() - from nanochat.dataset import list_parquet_files, DATA_DIR # type: ignore[attr-defined] + from nanochat.dataset import ( # type: ignore[attr-defined] + DATA_DIR, + list_parquet_files, + ) except (ThirdPartyImportError, ImportError): # pragma: no cover - defensive return False @@ -182,9 +185,13 @@ def __init__( if cfg_data is not None: mode = getattr(cfg_data, "mode", mode) base_dir = getattr(cfg_data, "base_dir", base_dir) - max_train_batches = getattr(cfg_data, "max_train_batches", max_train_batches) + max_train_batches = getattr( + cfg_data, "max_train_batches", max_train_batches + ) max_val_batches = getattr(cfg_data, "max_val_batches", max_val_batches) - tokenizer_threads = getattr(cfg_data, "tokenizer_threads", tokenizer_threads) + tokenizer_threads = getattr( + cfg_data, "tokenizer_threads", tokenizer_threads + ) tokenizer_batch_size = getattr( cfg_data, "tokenizer_batch_size", tokenizer_batch_size ) @@ -208,7 +215,9 @@ def __init__( resolved_base_dir = _resolve_base_dir(base_dir) dataset_mode = mode if dataset_mode == "auto": - dataset_mode = "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + dataset_mode = ( + "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + ) self.trainset = NanochatStreamingDataset( split="train", diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index 1334c027c..d4bda0c8d 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -11,8 +11,8 @@ feature, femnist, huggingface, - nanochat, lora, + nanochat, purchase, texas, tiny_imagenet, diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index 290111353..afd8d4a59 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -17,7 +17,6 @@ from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable - DEFAULT_MODEL_CONFIG: dict[str, int] = { "sequence_len": 2048, "vocab_size": 50304, @@ -30,10 +29,10 @@ def _import_nanochat_modules(): ensure_nanochat_importable() - from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] from nanochat.checkpoint_manager import ( # type: ignore[attr-defined] load_model_from_dir, ) + from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] return GPT, GPTConfig, load_model_from_dir diff --git a/plato/models/registry.py b/plato/models/registry.py index 4ce573e94..382a2d064 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -14,8 +14,8 @@ general_multilayer, huggingface, lenet5, - nanochat, multilayer, + nanochat, resnet, torch_hub, vgg, diff --git a/plato/processors/nanochat_tokenizer.py b/plato/processors/nanochat_tokenizer.py index e67565f5f..7460ae910 100644 --- a/plato/processors/nanochat_tokenizer.py +++ b/plato/processors/nanochat_tokenizer.py @@ -7,8 +7,8 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence import pickle +from collections.abc import Iterable, Sequence from pathlib import Path from typing import Any @@ -111,7 +111,9 @@ def _infer_special_tokens(self): encode_single_token = self._encoding.encode_single_token special_token_set = getattr(self._encoding, "special_tokens_set", set()) except AttributeError as exc: - raise RuntimeError("tiktoken encoding missing expected interfaces.") from exc + raise RuntimeError( + "tiktoken encoding missing expected interfaces." + ) from exc mapping = {} for token in SPECIAL_TOKENS: diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py index 5415cb6bb..aa025f41a 100644 --- a/plato/trainers/nanochat.py +++ b/plato/trainers/nanochat.py @@ -55,7 +55,9 @@ def create_train_loader( num_workers=0, collate_fn=_first_element_collate, ) - return self._fallback.create_train_loader(trainset, sampler, batch_size, context) + return self._fallback.create_train_loader( + trainset, sampler, batch_size, context + ) class NanochatTrainingStepStrategy(TrainingStepStrategy): @@ -112,7 +114,9 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: for optimizer, payload in zip( - self.optimizers, state_dict.get("optimizers", []), strict=False # type: ignore[arg-type] + self.optimizers, + state_dict.get("optimizers", []), + strict=False, # type: ignore[arg-type] ): optimizer.load_state_dict(payload) @@ -262,7 +266,9 @@ def __init__( data_loader_strategy = NanochatDataLoaderStrategy() evaluation_cfg = getattr(Config(), "evaluation", None) - evaluation_type = getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + evaluation_type = ( + getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + ) if evaluation_type == "nanochat_core": max_per_task = getattr(evaluation_cfg, "max_per_task", -1) max_per_task_value = -1 if max_per_task is None else int(max_per_task) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index 977d04fd2..ef47a287e 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -10,9 +10,11 @@ basic, composable, gan, - nanochat as nanochat_trainer, split_learning, ) +from plato.trainers import ( + nanochat as nanochat_trainer, +) registered_trainers = { "composable": composable.ComposableTrainer, diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py index 4a60cae36..4af61fe8a 100644 --- a/tests/test_nanochat_integration.py +++ b/tests/test_nanochat_integration.py @@ -9,7 +9,6 @@ from plato.config import Config, ConfigNode - pytestmark = pytest.mark.integration _RUSTBPE_AVAILABLE = importlib.util.find_spec("rustbpe") is not None @@ -91,9 +90,7 @@ def test_nanochat_trainer_smoke(temp_config, tmp_path): assert elapsed >= 0.0 model_dir = Config().params["model_path"] - checkpoint_name = ( - f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" - ) + checkpoint_name = f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" assert os.path.exists(os.path.join(model_dir, checkpoint_name)) @@ -105,6 +102,8 @@ def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): from plato.models.nanochat import Model as NanochatModel from plato.trainers.nanochat import ( NanochatCoreTestingStrategy, + ) + from plato.trainers.nanochat import ( Trainer as NanochatTrainer, ) From ac5bebaa932d29936367c0a8d9ea844663a6fa47 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 16:54:51 -0400 Subject: [PATCH 04/33] Added benchmark configuration ([evaluation]) support in config.py. --- plato/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plato/config.py b/plato/config.py index ec7d840a0..56f35cfa8 100644 --- a/plato/config.py +++ b/plato/config.py @@ -153,6 +153,7 @@ class Config: clients: Any server: Any data: Any + evaluation: Any trainer: Any algorithm: Any results: Any @@ -342,6 +343,10 @@ def __new__(cls): Config.params["base_path"], "data" ) + # User-defined evaluation configuration + if hasattr(config, "evaluation"): + Config.evaluation = config.evaluation + # Pretrained models if hasattr(Config().server, "model_path"): Config.params["model_path"] = os.path.join( From 4501781bc99aadb038946218bbab4e91c2f1246d Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 16:56:15 -0400 Subject: [PATCH 05/33] Added test to verify that [evaluation] configuration is properly loaded. --- tests/test_config_loader.py | 52 +++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index c22b0875c..01952518a 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -157,3 +157,55 @@ def test_config_base_path_used_without_cli_override(tmp_path: Path, monkeypatch) if hasattr(Config, "args"): delattr(Config, "args") Config._cli_overrides = {} + + +def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): + """Test that [evaluation] configuration is properly loaded.""" + config_base = tmp_path / "runtime" + config_path = tmp_path / "config.toml" + + config_data = { + "clients": {"type": "simple", "total_clients": 2, "per_round": 1}, + "server": {"address": "127.0.0.1", "port": 8000}, + "data": {"datasource": "MNIST"}, + "trainer": {"type": "basic", "rounds": 1, "epochs": 1, "batch_size": 10}, + "algorithm": {"type": "fedavg"}, + "evaluation": { + "type": "nanochat_core", + "max_per_task": 128, + "bundle_dir": "/custom/path/to/nanochat", + }, + } + + toml_writer.dump(config_data, config_path) + + monkeypatch.delenv("config_file", raising=False) + monkeypatch.setattr( + sys, + "argv", + [ + sys.argv[0], + "--config", + str(config_path), + "--base", + str(config_base), + ], + ) + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + + config = Config() + + assert hasattr(config, "evaluation") + assert config.evaluation.type == "nanochat_core" + assert config.evaluation.max_per_task == 128 + assert config.evaluation.bundle_dir == "/custom/path/to/nanochat" + assert config_base.is_dir() + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} From a8efbeacf91ea7db0951ee227a71d2ec980441c0 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 17:52:47 -0400 Subject: [PATCH 06/33] Fixed tensor contiguity issue in datasource. - Resolved a RuntimeError caused by non-contiguous tensors during view operations (in nanochat - gpt.py): "view size is not compatible with input tensor's size and stride...". Replaced .view() with .reshape() --- plato/datasources/nanochat.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py index 3a0f9201a..9718184cc 100644 --- a/plato/datasources/nanochat.py +++ b/plato/datasources/nanochat.py @@ -122,8 +122,8 @@ def _synthetic_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: dtype=torch.long, generator=generator, ) - inputs = tokens[:, :-1].to(dtype=torch.long) - targets = tokens[:, 1:].to(dtype=torch.long) + inputs = tokens[:, :-1].contiguous() + targets = tokens[:, 1:].contiguous() yield inputs, targets def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: @@ -141,7 +141,10 @@ def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: device=self.device, ) for inputs, targets in loader: - yield inputs.to(dtype=torch.long), targets.to(dtype=torch.long) + yield ( + inputs.to(dtype=torch.long).contiguous(), + targets.to(dtype=torch.long).contiguous(), + ) def __iter__(self): iterable = ( From f0bc22d2fe7313da0382e4ce4e42de77e21492d8 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Wed, 29 Oct 2025 17:54:12 -0400 Subject: [PATCH 07/33] Fixed KeyError: 'train_loss'. - Resolved an issue where the configuration requested 'train_loss' in the results, but the server's get_logged_items() did not include it. --- plato/clients/strategies/defaults.py | 13 +++++++++++++ plato/servers/fedavg.py | 18 +++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index 7b104c163..a7a54c86d 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -383,6 +383,18 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: except TypeError: num_samples = 0 + # Extract train_loss from trainer's run_history if available + train_loss = None + if ( + context.trainer is not None + and hasattr(context.trainer, "run_history") + and context.trainer.run_history is not None + ): + try: + train_loss = context.trainer.run_history.get_latest_metric("train_loss") + except (AttributeError, KeyError, IndexError): + train_loss = None + report = SimpleNamespace( client_id=context.client_id, num_samples=num_samples, @@ -390,6 +402,7 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: training_time=training_time, comm_time=time.time(), update_response=False, + train_loss=train_loss, ) return report, weights diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index ee636cb49..1f5205f80 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -273,7 +273,7 @@ def clients_processed(self) -> None: def get_logged_items(self) -> dict: """Get items to be logged by the LogProgressCallback class in a .csv file.""" - return { + logged = { "round": self.current_round, "accuracy": self.accuracy, "accuracy_std": self.accuracy_std, @@ -291,6 +291,22 @@ def get_logged_items(self) -> dict: "comm_overhead": self.comm_overhead, } + # Add train_loss if available from client reports + if self.updates and hasattr(self.updates[0].report, "train_loss"): + # Compute weighted average of train_loss across clients + total_samples = sum(update.report.num_samples for update in self.updates) + if total_samples > 0: + weighted_loss = sum( + update.report.train_loss * update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) + logged["train_loss"] = weighted_loss / total_samples + else: + logged["train_loss"] = 0.0 + + return logged + @staticmethod def get_accuracy_mean_std(updates): """Compute the accuracy mean and standard deviation across clients.""" From 48100806fd1c8804dcb5f3540ab57a7c04f8683c Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:21:56 -0400 Subject: [PATCH 08/33] Fixed train_loss aggregation in FedAvg server to handle None values. --- plato/servers/fedavg.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 1f5205f80..e9f9c0fde 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -294,7 +294,11 @@ def get_logged_items(self) -> dict: # Add train_loss if available from client reports if self.updates and hasattr(self.updates[0].report, "train_loss"): # Compute weighted average of train_loss across clients - total_samples = sum(update.report.num_samples for update in self.updates) + total_samples = sum( + update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) if total_samples > 0: weighted_loss = sum( update.report.train_loss * update.report.num_samples @@ -303,7 +307,7 @@ def get_logged_items(self) -> dict: ) logged["train_loss"] = weighted_loss / total_samples else: - logged["train_loss"] = 0.0 + logged["train_loss"] = None return logged From eb736ebde6cb967a389fe5fc1f70304f2f6d6455 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:23:48 -0400 Subject: [PATCH 09/33] Added evaluation configs for nanochat CORE metric. --- configs/Nanochat/synthetic_micro.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/configs/Nanochat/synthetic_micro.toml b/configs/Nanochat/synthetic_micro.toml index b22941f53..abc3a6452 100644 --- a/configs/Nanochat/synthetic_micro.toml +++ b/configs/Nanochat/synthetic_micro.toml @@ -26,6 +26,11 @@ device = "cpu" vocab_size = 512 synthetic_seed = 123 +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" # Optional, defaults to nanochat base dir or Plato's data directory +max_per_task = 16 # Optional, -1 means run all examples + [trainer] type = "nanochat" rounds = 1 @@ -46,4 +51,4 @@ n_kv_head = 4 n_embd = 256 [results] -types = "round, elapsed_time, train_loss" +types = "round, elapsed_time, core_metric, train_loss" From d9fe94a949cbc16f02725006fb8cca138182b013 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:44:50 -0400 Subject: [PATCH 10/33] Added automatic download of nanochat CORE evaluation bundle. --- plato/evaluators/nanochat_core.py | 168 +++++++++++++++++++++++++++--- 1 file changed, 151 insertions(+), 17 deletions(-) diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py index 74a7c1447..edac1a898 100644 --- a/plato/evaluators/nanochat_core.py +++ b/plato/evaluators/nanochat_core.py @@ -4,13 +4,20 @@ from __future__ import annotations +import contextlib import csv +import gzip import json import logging +import os import random +import sys +import tarfile import time +import zipfile from pathlib import Path from typing import Any +from urllib.parse import urlparse try: import torch @@ -20,42 +27,169 @@ "Install the `nanochat` extra (includes torch)." ) from exc +import requests import yaml +from plato.config import Config from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable LOGGER = logging.getLogger(__name__) +# URL for the CORE evaluation bundle +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + + +@contextlib.contextmanager +def _download_guard(data_path: str): + """Serialize dataset downloads to avoid concurrent corruption.""" + os.makedirs(data_path, exist_ok=True) + lock_file = os.path.join(data_path, ".download.lock") + lock_fd = None + waited = False + + try: + while True: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR) + break + except FileExistsError: + if not waited: + LOGGER.info( + "Another process is preparing the dataset at %s. Waiting.", + data_path, + ) + waited = True + time.sleep(1) + yield + finally: + if lock_fd is not None: + os.close(lock_fd) + try: + os.remove(lock_file) + except FileNotFoundError: + pass + + +def _download_eval_bundle(url: str, data_path: str) -> None: + """Download the CORE evaluation bundle from a URL if not already available.""" + url_parse = urlparse(url) + file_name = os.path.join(data_path, url_parse.path.split("/")[-1]) + os.makedirs(data_path, exist_ok=True) + sentinel = Path(f"{file_name}.complete") + + if sentinel.exists(): + return + + with _download_guard(data_path): + if sentinel.exists(): + return + + LOGGER.info("Downloading CORE evaluation bundle from %s.", url) + + res = requests.get(url, stream=True, timeout=60) + total_size = int(res.headers.get("Content-Length", 0)) + downloaded_size = 0 + + with open(file_name, "wb+") as file: + for chunk in res.iter_content(chunk_size=1024): + if not chunk: + continue + downloaded_size += len(chunk) + file.write(chunk) + file.flush() + if total_size: + sys.stdout.write(f"\r{100 * downloaded_size / total_size:.1f}%") + sys.stdout.flush() + if total_size: + sys.stdout.write("\n") + + # Unzip the compressed file just downloaded + LOGGER.info("Decompressing the CORE evaluation bundle.") + name, suffix = os.path.splitext(file_name) + + if file_name.endswith("tar.gz"): + with tarfile.open(file_name, "r:gz") as tar: + tar.extractall(data_path) + os.remove(file_name) + elif suffix == ".zip": + LOGGER.info("Extracting %s to %s.", file_name, data_path) + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(data_path) + os.remove(file_name) + elif suffix == ".gz": + with gzip.open(file_name, "rb") as zipped_file: + with open(name, "wb") as unzipped_file: + unzipped_file.write(zipped_file.read()) + os.remove(file_name) + else: + LOGGER.warning("Unknown compressed file type for %s.", file_name) + + sentinel.touch() + LOGGER.info("CORE evaluation bundle downloaded and extracted successfully.") + def _resolve_bundle_paths(bundle_dir: str | Path | None) -> tuple[Path, Path, Path]: """Resolve the configuration, metadata, and dataset paths for CORE evaluation.""" ensure_nanochat_importable() - from nanochat.common import get_base_dir # pylint: disable=import-error - if bundle_dir is None: - base_path = Path(get_base_dir()) + def _get_default_base_path() -> Path: + """Get the default base path, trying nanochat first, then Plato's data directory.""" + try: + from nanochat.common import get_base_dir # pylint: disable=import-error + + return Path(get_base_dir()) + except (ImportError, OSError, PermissionError): + plato_data_path = Config().params.get("data_path", "./runtime/data") + path = Path(plato_data_path) / "nanochat" + LOGGER.info("Using Plato data directory for CORE bundle: %s", path) + return path + + # Determine base path + if bundle_dir is not None: + try: + base_path = Path(bundle_dir).expanduser().resolve() + LOGGER.info("Using bundle_dir from config: %s", base_path) + except (OSError, PermissionError, ValueError) as exc: + LOGGER.warning( + "Cannot use bundle_dir '%s': %s. Using default location.", + bundle_dir, + exc, + ) + base_path = _get_default_base_path() else: - base_path = Path(bundle_dir).expanduser().resolve() + base_path = _get_default_base_path() + + # Ensure base path exists + base_path.mkdir(parents=True, exist_ok=True) eval_bundle_dir = base_path / "eval_bundle" config_path = eval_bundle_dir / "core.yaml" data_dir = eval_bundle_dir / "eval_data" metadata_path = eval_bundle_dir / "eval_meta_data.csv" - if not config_path.exists(): - raise FileNotFoundError( - f"CORE evaluation config not found at {config_path}. " - "Ensure the Nanochat eval bundle is downloaded." - ) - if not data_dir.exists(): - raise FileNotFoundError( - f"CORE evaluation data directory not found at {data_dir}. " - "Ensure the Nanochat eval bundle is downloaded." - ) - if not metadata_path.exists(): - raise FileNotFoundError( - f"CORE evaluation metadata CSV not found at {metadata_path}." + # Check if evaluation bundle exists, download if missing + if not config_path.exists() or not data_dir.exists() or not metadata_path.exists(): + LOGGER.info( + "CORE evaluation bundle not found at %s. Downloading automatically...", + base_path, ) + _download_eval_bundle(EVAL_BUNDLE_URL, str(base_path)) + + # Verify download succeeded + if not config_path.exists(): + raise FileNotFoundError( + f"CORE evaluation config not found at {config_path}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not data_dir.exists(): + raise FileNotFoundError( + f"CORE evaluation data directory not found at {data_dir}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not metadata_path.exists(): + raise FileNotFoundError( + f"CORE evaluation metadata CSV not found at {metadata_path}." + ) return config_path, data_dir, metadata_path From 6f349500331933beeb5866407af8e1ee3f9f347f Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:50:01 -0400 Subject: [PATCH 11/33] Using tokenizer's vocab_size to match between model and tokenizer. - To avoid vocabulary size mismatch between model and tokenizer during CORE evaluation. --- plato/models/nanochat.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index afd8d4a59..429d034b2 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -6,6 +6,7 @@ from dataclasses import fields from typing import Any +import logging try: import torch @@ -126,9 +127,47 @@ def get(model_name: str | None = None, **kwargs: Any): ) return model + # Model vocab_size MUST match tokenizer vocab_size to avoid IndexError + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + actual_vocab_size = tokenizer.get_vocab_size() + + # Override vocab_size with tokenizer's actual vocab_size + if "vocab_size" in config_kwargs: + configured_vocab = config_kwargs["vocab_size"] + if configured_vocab != actual_vocab_size: + logging.warning( + f"[Nanochat Model] Config specifies vocab_size={configured_vocab}, " + f"but tokenizer has vocab_size={actual_vocab_size}. " + f"Using tokenizer's vocab_size={actual_vocab_size} to match tokenizer." + ) + config_kwargs["vocab_size"] = actual_vocab_size + + logging.info( + f"[Nanochat Model] Using vocab_size={actual_vocab_size} from tokenizer" + ) + except Exception as e: + logging.warning( + f"[Nanochat Model] Could not auto-detect vocab_size from tokenizer: {e}. " + f"Using configured or default vocab_size." + ) + config = GPTConfig(**config_kwargs) model = GPT(config) if init_weights: model.init_weights() + + # This allows CORE evaluation and other components to access the tokenizer + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + setattr(model, "nanochat_tokenizer", tokenizer) + except Exception: + pass + # Set max_seq_len for CORE evaluation truncation + setattr(model, "max_seq_len", config.sequence_len) setattr(model, "nanochat_config", config_kwargs) return model From 35f25ebcc7b6f063d329b608a8df29679e2f0008 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 30 Oct 2025 00:55:24 -0400 Subject: [PATCH 12/33] Added outputs for Nanochat CORE evaluation in FedAvg server. --- plato/servers/fedavg.py | 12 ++++++++++++ tests/test_config_loader.py | 1 - 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index e9f9c0fde..fe001cb34 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -252,6 +252,14 @@ async def _process_reports(self): trainer = self.require_trainer() self.accuracy = trainer.test(self.testset, self.testset_sampler) + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", self.accuracy) + if hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( @@ -309,6 +317,10 @@ def get_logged_items(self) -> dict: else: logged["train_loss"] = None + # Add core_metric if Nanochat CORE evaluation was performed + if hasattr(self, "_core_metric"): + logged["core_metric"] = self._core_metric + return logged @staticmethod diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 01952518a..2b1bb9b75 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -203,7 +203,6 @@ def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): assert config.evaluation.type == "nanochat_core" assert config.evaluation.max_per_task == 128 assert config.evaluation.bundle_dir == "/custom/path/to/nanochat" - assert config_base.is_dir() Config._instance = None if hasattr(Config, "args"): From e4ae761265364126ff6c0f67c8359a8e702de685 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 31 Oct 2025 11:08:39 -0400 Subject: [PATCH 13/33] Added specific logging output for CORE benchmark metrics. - Updated log message from "global accuracy" to "Average Centered CORE benchmark metric" - Used ruff to format code --- cleanup.py | 8 ++------ plato/servers/fedavg.py | 12 +++++++++++- plato/servers/fedavg_cs.py | 34 +++++++++++++++++++++++++++++++-- plato/servers/split_learning.py | 27 ++++++++++++++++++++++---- 4 files changed, 68 insertions(+), 13 deletions(-) diff --git a/cleanup.py b/cleanup.py index 68c658019..f03f5f3cf 100644 --- a/cleanup.py +++ b/cleanup.py @@ -146,18 +146,14 @@ def main() -> None: continue cleared = clean_directory(runtime_dir) - print( - f"Failed to delete {runtime_dir}; cleared {cleared} items instead." - ) + print(f"Failed to delete {runtime_dir}; cleared {cleared} items instead.") fallback_dirs += 1 fallback_items += cleared if runtime_total == 0: print("No runtime directories found.") else: - print( - f"Removed {runtime_removed} of {runtime_total} runtime directories." - ) + print(f"Removed {runtime_removed} of {runtime_total} runtime directories.") if fallback_dirs: print( f"Cleared {fallback_items} items in " diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index fe001cb34..36fee944b 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -260,7 +260,16 @@ async def _process_reports(self): core_results = trainer.context.state["nanochat_core_results"] self._core_metric = core_results.get("core_metric", self.accuracy) - if hasattr(Config().trainer, "target_perplexity"): + # If CORE benchmark was run via a Nanochat testing strategy, report the specialized CORE metric instead of the generic 'Global model accuracy' label. + core_metric = getattr(self, "_core_metric", None) + + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -284,6 +293,7 @@ def get_logged_items(self) -> dict: logged = { "round": self.current_round, "accuracy": self.accuracy, + "core_metric": getattr(self, "_core_metric", None), "accuracy_std": self.accuracy_std, "elapsed_time": self.wall_time - self.initial_wall_time, "processing_time": max( diff --git a/plato/servers/fedavg_cs.py b/plato/servers/fedavg_cs.py index fceff363a..eaba3caf6 100644 --- a/plato/servers/fedavg_cs.py +++ b/plato/servers/fedavg_cs.py @@ -253,7 +253,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -274,7 +289,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" diff --git a/plato/servers/split_learning.py b/plato/servers/split_learning.py index 83cf2804a..49fdc2d4c 100644 --- a/plato/servers/split_learning.py +++ b/plato/servers/split_learning.py @@ -99,11 +99,30 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): self.test_accuracy = trainer.test(self.testset, self.testset_sampler) - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + core_metric = core_results.get("core_metric", None) + if core_metric is not None: + logging.warning( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) ) - ) self.phase = "prompt" # Change client in next round self.next_client = True From 432fe50d3ffea35a7ee653b603439230c592fb73 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Fri, 7 Nov 2025 23:41:41 +0000 Subject: [PATCH 14/33] Typed the Nanochat datasource/optimizer plumbing and enforced valid CORE metadata so ty check is clean again. --- plato/datasources/nanochat.py | 4 ++-- plato/evaluators/nanochat_core.py | 11 +++++++++-- plato/trainers/nanochat.py | 8 ++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py index 9718184cc..7e53c71c6 100644 --- a/plato/datasources/nanochat.py +++ b/plato/datasources/nanochat.py @@ -222,7 +222,7 @@ def __init__( "parquet" if _parquet_available(resolved_base_dir) else "synthetic" ) - self.trainset = NanochatStreamingDataset( + self.trainset: NanochatStreamingDataset = NanochatStreamingDataset( split="train", batch_size=resolved_batch_size, sequence_length=resolved_sequence_len, @@ -235,7 +235,7 @@ def __init__( vocab_size=vocab_size, synthetic_seed=synthetic_seed, ) - self.testset = NanochatStreamingDataset( + self.testset: NanochatStreamingDataset = NanochatStreamingDataset( split="val", batch_size=resolved_batch_size, sequence_length=resolved_sequence_len, diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py index edac1a898..97a49e9f2 100644 --- a/plato/evaluators/nanochat_core.py +++ b/plato/evaluators/nanochat_core.py @@ -308,15 +308,22 @@ def run_core_evaluation( LOGGER.debug("Skipping unnamed CORE task entry: %s", task) continue + dataset_uri = task.get("dataset_uri") + if not isinstance(dataset_uri, str): + LOGGER.debug( + "Skipping CORE task %s due to missing dataset_uri metadata.", task + ) + continue + task_meta = { "task_type": task.get("icl_task_type"), - "dataset_uri": task.get("dataset_uri"), + "dataset_uri": dataset_uri, "num_fewshot": task.get("num_fewshot", [0])[0], "continuation_delimiter": task.get("continuation_delimiter", " "), } start_time = time.perf_counter() - data = _load_task_data(data_dir, task_meta["dataset_uri"]) + data = _load_task_data(data_dir, dataset_uri) shuffle_rng = random.Random(1337) shuffle_rng.shuffle(data) if max_per_task > 0: diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py index aa025f41a..307ccf911 100644 --- a/plato/trainers/nanochat.py +++ b/plato/trainers/nanochat.py @@ -153,12 +153,16 @@ def __init__( def create_optimizer( self, model: torch.nn.Module, context: TrainingContext ) -> _OptimizerBundle: - if not hasattr(model, "setup_optimizers"): + if not isinstance(model, torch.nn.Module): + raise TypeError("Nanochat optimizer strategy requires a torch.nn.Module.") + + setup_fn = getattr(model, "setup_optimizers", None) + if not callable(setup_fn): raise AttributeError( "Nanochat model is expected to expose setup_optimizers()." ) - optimizers = model.setup_optimizers( + optimizers = setup_fn( unembedding_lr=self.unembedding_lr, embedding_lr=self.embedding_lr, matrix_lr=self.matrix_lr, From da04815748868e07d3d917dace28c2cd7c444331 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Fri, 7 Nov 2025 23:53:13 +0000 Subject: [PATCH 15/33] All nanochat tests now pass. --- pyproject.toml | 20 ++++++++++---------- tests/test_nanochat_integration.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e12d65cbc..e1e172a6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,16 +48,16 @@ mlx = [ "mlx", ] nanochat = [ - "datasets>=2.14.0", - "numpy==1.26.4", - "psutil>=5.9.0", - "regex>=2023.8.8", - "tiktoken>=0.5.0", - "tokenizers>=0.15.0", - "torch>=2.2.0", - "wandb>=0.16.0", - "jinja2>=3.0", - "PyYAML>=6.0", + "datasets", + "numpy", + "psutil", + "regex", + "tiktoken", + "tokenizers", + "torch", + "wandb", + "jinja2", + "PyYAML", ] [project.urls] diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py index 4af61fe8a..6a9c174e8 100644 --- a/tests/test_nanochat_integration.py +++ b/tests/test_nanochat_integration.py @@ -8,6 +8,7 @@ import pytest from plato.config import Config, ConfigNode +from plato.utils.third_party import ensure_nanochat_importable pytestmark = pytest.mark.integration @@ -41,6 +42,11 @@ def test_nanochat_trainer_smoke(temp_config, tmp_path): _ = pytest.importorskip( "torch", reason="Nanochat trainer smoke requires torch (nanochat extra)." ) + ensure_nanochat_importable() + _ = pytest.importorskip( + "nanochat", + reason="Nanochat trainer smoke requires the nanochat package (nanochat extra).", + ) from plato.datasources.nanochat import DataSource as NanochatDataSource from plato.models.nanochat import Model as NanochatModel from plato.trainers.nanochat import Trainer as NanochatTrainer @@ -99,6 +105,11 @@ def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): _ = pytest.importorskip( "torch", reason="Nanochat trainer requires torch (nanochat extra)." ) + ensure_nanochat_importable() + _ = pytest.importorskip( + "nanochat", + reason="Nanochat trainer requires the nanochat package (nanochat extra).", + ) from plato.models.nanochat import Model as NanochatModel from plato.trainers.nanochat import ( NanochatCoreTestingStrategy, From 279d05ecb6eeed0d79e475b2c69cbb21c020c52c Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sat, 8 Nov 2025 01:38:01 -0500 Subject: [PATCH 16/33] Updated nanochat README with setup and troubleshooting notes. - Added instructions for initializing submodules and resolving maturin build failure. --- examples/nanochat/README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md index 16ae42fc5..1d061345b 100644 --- a/examples/nanochat/README.md +++ b/examples/nanochat/README.md @@ -4,14 +4,26 @@ This workspace hosts Nanochat-focused experiments within Plato. ## Quick Start -1. Install dependencies (including the vendored tokenizer build requirements): +1. Initialize the nanochat submodule (required for the nanochat integration): + + ```bash + git submodule update --init --recursive + ``` + +2. Install dependencies (including the vendored tokenizer build requirements): ```bash uv sync --extra nanochat uv run --with ./external/nanochat maturin develop --release ``` + **Troubleshooting:** If you encounter a `maturin failed` error with "Can't find Cargo.toml", run the maturin command from within the nanochat directory: + + ```bash + uv sync --extra nanochat + cd external/nanochat && uv run maturin develop --release && cd ../.. + ``` -2. Run the synthetic smoke configuration: +3. Run the synthetic smoke configuration: ```bash uv run --extra nanochat python plato.py --config configs/Nanochat/synthetic_micro.toml From af0bafab579b7dff754c154dad523dd55e34e8b7 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 13 Nov 2025 18:37:09 +0000 Subject: [PATCH 17/33] Added configuration file for NanoChat Parquet mode. --- configs/Nanochat/parquet_micro.toml | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 configs/Nanochat/parquet_micro.toml diff --git a/configs/Nanochat/parquet_micro.toml b/configs/Nanochat/parquet_micro.toml new file mode 100644 index 000000000..cfe09a921 --- /dev/null +++ b/configs/Nanochat/parquet_micro.toml @@ -0,0 +1,53 @@ +[clients] +type = "simple" +total_clients = 10 +per_round = 3 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/parquet" +model_path = "models/nanochat/parquet" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "parquet" +max_train_batches = 16 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 32 +device = "cuda" +vocab_size = 512 +synthetic_seed = 123 + +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" +max_per_task = 16 + +[trainer] +type = "nanochat" +rounds = 10000 +epochs = 5 +batch_size = 1 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 256 +vocab_size = 50304 +n_layer = 4 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, core_metric, train_loss" From 2b7cf3d6832dc85c967835ec4738f4049a0a7661 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 13 Nov 2025 20:37:13 +0000 Subject: [PATCH 18/33] Formatted code with Ruff and applied autofixes. --- .../model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py | 3 ++- plato/models/nanochat.py | 2 +- pyproject.toml | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py index 487547209..e8d700c97 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py @@ -22,7 +22,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F -from data import build_loader from misc.config import get_config from misc.loss_ops import AdaptiveLossSoft from misc.lr_scheduler import build_scheduler @@ -37,6 +36,8 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import AverageMeter, ModelEma, accuracy +from data import build_loader + try: from apex import amp except ImportError: diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py index 429d034b2..114a891f9 100644 --- a/plato/models/nanochat.py +++ b/plato/models/nanochat.py @@ -4,9 +4,9 @@ from __future__ import annotations +import logging from dataclasses import fields from typing import Any -import logging try: import torch diff --git a/pyproject.toml b/pyproject.toml index e1e172a6c..228545350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ members = [ [dependency-groups] dev = [ "pytest", + "ruff", "ty", ] From 736b29ec7128c5d0362619c5e39d611beb54f51c Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 13 Nov 2025 20:56:00 +0000 Subject: [PATCH 19/33] Added two configuration files for PatchTSMixer model. - Included configurations for both pre-trained and custom modes. --- configs/TimeSeries/patchtsmixer_custom.toml | 67 ++++++++++++++++++ .../TimeSeries/patchtsmixer_pretrained.toml | 68 +++++++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 configs/TimeSeries/patchtsmixer_custom.toml create mode 100644 configs/TimeSeries/patchtsmixer_pretrained.toml diff --git a/configs/TimeSeries/patchtsmixer_custom.toml b/configs/TimeSeries/patchtsmixer_custom.toml new file mode 100644 index 000000000..d228eb1fb --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_custom.toml @@ -0,0 +1,67 @@ +# Federated Learning with PatchTSMixer for Time Series Forecasting +# This configuration demonstrates using the IBM Granite PatchTSMixer model +# with time series data from HuggingFace datasets + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer" +model_path = "models/timeseries/patchtsmixer" + +[data] +# ETTh1: Electricity Transformer Temperature dataset (7 features) +datasource = "ETTh1" + +partition_size = 100 # Number of training samples +sampler = "iid" +random_seed = 1 + +[trainer] +type = "HuggingFace" +rounds = 3 +max_concurrency = 2 +model_type = "huggingface" + +# Train from scratch - simpler for testing +model_name = "custom_patchtsmixer" + +# Task type: forecasting, classification, regression, or pretraining +task_type = "forecasting" + +# PatchTSMixer specific parameters (smaller model for testing) +context_length = 64 +prediction_length = 24 +num_input_channels = 7 # ETTh1 has 7 features (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT) +patch_length = 8 +patch_stride = 8 +d_model = 32 # Hidden dimension of the model. Recommended to set it as a multiple of patch_length (i.e. 2-8X of patch_len). Larger value indicates more complex model. +num_layers = 3 # Number of layers to use. Recommended range is 3-15. Larger value indicates more complex model. +expansion_factor = 2 # Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model. +dropout = 0.5 +head_dropout = 0.7 +mode = "common_channel" +gated_attn = true +scaling = "std" + +# Training parameters +epochs = 2 +batch_size = 8 +optimizer = "Adam" + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.001 +weight_decay = 0.0 + +[results] +types = "round, elapsed_time, accuracy" diff --git a/configs/TimeSeries/patchtsmixer_pretrained.toml b/configs/TimeSeries/patchtsmixer_pretrained.toml new file mode 100644 index 000000000..4119c1b47 --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_pretrained.toml @@ -0,0 +1,68 @@ +# Federated Learning with PatchTSMixer for Time Series Forecasting +# This configuration demonstrates using the IBM Granite PatchTSMixer model +# with time series data from HuggingFace datasets + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer" +model_path = "models/timeseries/patchtsmixer" + +[data] +# ETTh1: Electricity Transformer Temperature dataset (7 features) +datasource = "ETTh1" + +partition_size = 100 # Number of training samples +sampler = "iid" +random_seed = 1 + +[trainer] +type = "HuggingFace" +rounds = 3 +max_concurrency = 2 +model_type = "huggingface" + +# Use pre-trained IBM Granite model +# For pre-trained model, the some settings must match pretrained model +model_name = "ibm-granite/granite-timeseries-patchtsmixer" + +# Task type: forecasting, classification, regression, or pretraining +task_type = "forecasting" + +# PatchTSMixer specific parameters (matching pretrained model) +context_length = 512 +prediction_length = 96 +num_input_channels = 7 +patch_length = 16 +patch_stride = 8 +d_model = 64 +num_layers = 8 +expansion_factor = 2 # Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model. +dropout = 0.5 +head_dropout = 0.7 +mode = "common_channel" +gated_attn = true +scaling = "std" + +# Training parameters +epochs = 2 # Reduced for testing +batch_size = 8 # Reduced for testing +optimizer = "Adam" + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.001 +weight_decay = 0.0 + +[results] +types = "round, elapsed_time, accuracy" From 1fe0e2270918b7db1d441e9f0efb600bd7f29542 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Thu, 13 Nov 2025 22:04:33 +0000 Subject: [PATCH 20/33] Added MSE metric output for time series models. --- plato/servers/fedavg.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 36fee944b..d6e819bb6 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -269,18 +269,28 @@ async def _process_reports(self): f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) - elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" + trainer = self.require_trainer() + metric_name = getattr(trainer.testing_strategy, "metric_name", "accuracy") + + if metric_name == "mse": + logging.info( + fonts.colourize(f"[{self}] Global model MSE: {self.accuracy:.4f}\n") + ) + elif metric_name == "perplexity" or hasattr( + Config().trainer, "target_perplexity" + ): + logging.info( + fonts.colourize( + f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" + ) + ) + else: + logging.info( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" + ) ) - ) self.clients_processed() self.callback_handler.call_event("on_clients_processed", self) From 205043d0474f03d2be053d179972ba7085ce6640 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 01:19:12 +0000 Subject: [PATCH 21/33] Added GitHub dataset handling (ETT datasets) for PatchTSMixer model. --- plato/datasources/ETT.py | 194 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 plato/datasources/ETT.py diff --git a/plato/datasources/ETT.py b/plato/datasources/ETT.py new file mode 100644 index 000000000..d0b931241 --- /dev/null +++ b/plato/datasources/ETT.py @@ -0,0 +1,194 @@ +""" +ETT (Electricity Transformer Temperature) datasource for time series forecasting. + +Supports all ETT datasets: +- ETTh1, ETTh2: Hourly data (1 point per hour) +- ETTm1, ETTm2: 15-minute data (4 points per hour) + +Data from: https://github.com/zhouhaoyi/ETDataset +""" + +import logging +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from plato.config import Config +from plato.datasources import base + + +class ETTDataset(Dataset): + """ETT time series dataset with sliding window.""" + + def __init__(self, data, context_length, prediction_length, stride=1): + """ + Create dataset with sliding windows. + + Args: + data: pandas DataFrame or numpy array with shape (timesteps, channels) + context_length: Number of historical timesteps + prediction_length: Number of future timesteps to predict + stride: Stride for sliding window + """ + if isinstance(data, pd.DataFrame): + # Remove date column if present + if "date" in data.columns: + data = data.drop("date", axis=1) + data = data.values + + self.data = torch.FloatTensor(data) + self.context_length = context_length + self.prediction_length = prediction_length + self.stride = stride + + # Calculate number of valid windows + total_length = context_length + prediction_length + self.num_windows = max(0, (len(data) - total_length) // stride + 1) + + def __len__(self): + return self.num_windows + + def __getitem__(self, idx): + """Return past_values and future_values for PatchTSMixer.""" + start_idx = idx * self.stride + end_context = start_idx + self.context_length + end_future = end_context + self.prediction_length + + past_values = self.data[start_idx:end_context] + future_values = self.data[end_context:end_future] + + return { + "past_values": past_values, + "future_values": future_values, + } + + +class DataSource(base.DataSource): + """ETT datasource for time series forecasting (ETTh1, ETTh2, ETTm1, ETTm2).""" + + # Dataset configurations + DATASET_INFO = { + "ETTh1": {"freq": "hourly", "points_per_hour": 1}, + "ETTh2": {"freq": "hourly", "points_per_hour": 1}, + "ETTm1": {"freq": "15min", "points_per_hour": 4}, + "ETTm2": {"freq": "15min", "points_per_hour": 4}, + } + + def __init__(self, **kwargs): + super().__init__() + + # Get dataset name + dataset_name = kwargs.get( + "dataset_name", + getattr(Config().data, "dataset_name", "ETTh1") + ) + + # Validate dataset name + if dataset_name not in self.DATASET_INFO: + raise ValueError( + f"Unknown ETT dataset: {dataset_name}. " + f"Supported datasets: {list(self.DATASET_INFO.keys())}" + ) + + logging.info("Using %s (Electricity Transformer Temperature) dataset", dataset_name) + + dataset_info = self.DATASET_INFO[dataset_name] + logging.info( + "Dataset frequency: %s (%d points per hour)", + dataset_info["freq"], + dataset_info["points_per_hour"] + ) + + # Get configuration + context_length = getattr(Config().trainer, "context_length", 512) + prediction_length = getattr(Config().trainer, "prediction_length", 96) + + # Download and load the data + data_path = self._download_data(dataset_name) + df = pd.read_csv(data_path) + + logging.info( + "Loaded %s dataset with %d timesteps and %d channels", + dataset_name, + len(df), + len(df.columns) - 1, + ) # -1 for date column + + # Split into train/val/test following the standard ETT split + # Standard split: 12 months train, 4 months val, 4 months test + points_per_hour = dataset_info["points_per_hour"] + train_end = 12 * 30 * 24 * points_per_hour # 12 months + val_end = train_end + 4 * 30 * 24 * points_per_hour # + 4 months + + train_df = df[:train_end] + val_df = df[train_end:val_end] + test_df = df[val_end:] + + logging.info( + "%s split - train: %d, val: %d, test: %d", + dataset_name, + len(train_df), + len(val_df), + len(test_df), + ) + + # Create datasets with sliding windows + self.trainset = ETTDataset( + train_df, context_length, prediction_length, stride=1 + ) + + # Use validation set as test set for federated learning + self.testset = ETTDataset( + val_df, context_length, prediction_length, stride=prediction_length + ) + + logging.info( + "Created %d training windows and %d test windows", + len(self.trainset), + len(self.testset), + ) + + def _download_data(self, dataset_name): + """Download ETT dataset from GitHub if not already present.""" + data_dir = Path(Config().params["data_path"]) / "ETT-small" + data_dir.mkdir(parents=True, exist_ok=True) + + data_file = data_dir / f"{dataset_name}.csv" + + if data_file.exists(): + logging.info("%s.csv already exists", dataset_name) + return str(data_file) + + # Download from GitHub + logging.info("Downloading %s.csv from GitHub ...", dataset_name) + url = f"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset_name}.csv" + + try: + import urllib.request + + urllib.request.urlretrieve(url, str(data_file)) + logging.info("Successfully downloaded %s.csv", dataset_name) + except Exception as e: + logging.error("Failed to download %s.csv: %s", dataset_name, e) + raise RuntimeError( + f"Could not download {dataset_name} dataset from {url}. " + f"Please download it manually to {data_file}" + ) from e + + return str(data_file) + + def num_train_examples(self): + return len(self.trainset) + + def num_test_examples(self): + return len(self.testset) + + def get_train_set(self): + return self.trainset + + def get_test_set(self): + return self.testset From 12721f1e1519e6863d13bb0099f42be9bff5d9ef Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 03:16:05 +0000 Subject: [PATCH 22/33] Added ETT datasource to the registry. --- plato/datasources/ETT.py | 13 +++++++------ plato/datasources/registry.py | 6 ++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/plato/datasources/ETT.py b/plato/datasources/ETT.py index d0b931241..c05684c96 100644 --- a/plato/datasources/ETT.py +++ b/plato/datasources/ETT.py @@ -83,10 +83,9 @@ def __init__(self, **kwargs): # Get dataset name dataset_name = kwargs.get( - "dataset_name", - getattr(Config().data, "dataset_name", "ETTh1") + "dataset_name", getattr(Config().data, "dataset_name", "ETTh1") ) - + # Validate dataset name if dataset_name not in self.DATASET_INFO: raise ValueError( @@ -94,13 +93,15 @@ def __init__(self, **kwargs): f"Supported datasets: {list(self.DATASET_INFO.keys())}" ) - logging.info("Using %s (Electricity Transformer Temperature) dataset", dataset_name) - + logging.info( + "Using %s (Electricity Transformer Temperature) dataset", dataset_name + ) + dataset_info = self.DATASET_INFO[dataset_name] logging.info( "Dataset frequency: %s (%d points per hour)", dataset_info["freq"], - dataset_info["points_per_hour"] + dataset_info["points_per_hour"], ) # Get configuration diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index d4bda0c8d..f71f49b44 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -8,6 +8,7 @@ from plato.config import Config from plato.datasources import ( cinic10, + ETT, feature, femnist, huggingface, @@ -29,6 +30,7 @@ "TinyImageNet": tiny_imagenet, "Feature": feature, "Nanochat": nanochat, + "ETT": ETT, } registered_partitioned_datasources = {"FEMNIST": femnist} @@ -41,6 +43,10 @@ "CIFAR10": ("Torchvision", {"dataset_name": "CIFAR10"}), "CIFAR100": ("Torchvision", {"dataset_name": "CIFAR100"}), "CelebA": ("Torchvision", {"dataset_name": "CelebA"}), + "ETTh1": ("ETT", {"dataset_name": "ETTh1"}), + "ETTh2": ("ETT", {"dataset_name": "ETTh2"}), + "ETTm1": ("ETT", {"dataset_name": "ETTm1"}), + "ETTm2": ("ETT", {"dataset_name": "ETTm2"}), } From 73d19de0d649d40231b6eacadf49b92ebbfeb7b9 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 03:19:43 +0000 Subject: [PATCH 23/33] Added TimeSeriesDatasetWrapper support for time-series datasets in datasources. --- plato/datasources/huggingface.py | 118 +++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/plato/datasources/huggingface.py b/plato/datasources/huggingface.py index 99267496e..800c79154 100644 --- a/plato/datasources/huggingface.py +++ b/plato/datasources/huggingface.py @@ -9,7 +9,9 @@ import logging import os +import torch from datasets import load_dataset, load_from_disk +from torch.utils.data import Dataset as TorchDataset from transformers import ( AutoConfig, AutoTokenizer, @@ -21,6 +23,78 @@ from plato.config import Config from plato.datasources import base +from plato.utils.timeseries_utils import is_timeseries_model + + +class TimeSeriesDatasetWrapper(TorchDataset): + """ + Wrapper for time series data from HuggingFace datasets. + Converts HuggingFace dataset format to standard time-series format used by HuggingFace time-series models. + """ + + def __init__(self, hf_dataset, context_length, prediction_length): + """ + Args: + hf_dataset: HuggingFace dataset with time series data + context_length: Number of historical timesteps + prediction_length: Number of future timesteps to predict + """ + self.dataset = hf_dataset + self.context_length = context_length + self.prediction_length = prediction_length + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + """ + Returns time series data in PatchTSMixer format. + + Expected HuggingFace dataset format: + - 'past_values'/'future_values': Pre-split time series + - 'target': Full time series to be split + """ + item = self.dataset[idx] + + # Handle different dataset formats + if isinstance(item, dict): + if "past_values" in item and "future_values" in item: + # Already in the right format + return { + "past_values": torch.FloatTensor(item["past_values"]), + "future_values": torch.FloatTensor(item["future_values"]), + } + elif "target" in item: + # Extract from 'target' field and split + target = torch.FloatTensor(item["target"]) + else: + raise ValueError( + f"Dataset must contain either 'past_values'/'future_values' or 'target' field. " + f"Found keys: {list(item.keys())}" + ) + else: + target = item if torch.is_tensor(item) else torch.FloatTensor(item) + + # If 1D, add channel dimension: (length,) -> (length, 1) + if target.dim() == 1: + target = target.unsqueeze(-1) + + # Split into past and future + if len(target) < self.context_length + self.prediction_length: + raise ValueError( + f"Time series too short: got {len(target)} timesteps, " + f"need at least {self.context_length + self.prediction_length}" + ) + + past_values = target[: self.context_length] + future_values = target[ + self.context_length : self.context_length + self.prediction_length + ] + + return { + "past_values": past_values, + "future_values": future_values, + } class DataSource(base.DataSource): @@ -51,6 +125,50 @@ def __init__(self, **kwargs): if callable(save_to_disk): save_to_disk(saved_data_path) + # Determine dataset type from config or model type + model_type = getattr(Config().trainer, "model_type", None) + dataset_type = getattr(Config().data, "dataset_type", "text") + + is_timeseries = is_timeseries_model( + model_type=model_type, dataset_type=dataset_type + ) + + if is_timeseries: + self._init_timeseries_dataset() + else: + self._init_text_dataset() + + def _init_timeseries_dataset(self): + """Initialize time series dataset.""" + logging.info("Initializing time series dataset") + + # Get time series parameters from config + context_length = getattr(Config().trainer, "context_length", 512) + prediction_length = getattr(Config().trainer, "prediction_length", 96) + + # Wrap datasets + train_split = ( + "train" if "train" in self.dataset else list(self.dataset.keys())[0] + ) + test_split = ( + "test" + if "test" in self.dataset + else "validation" + if "validation" in self.dataset + else train_split + ) + + self.trainset = TimeSeriesDatasetWrapper( + self.dataset[train_split], context_length, prediction_length + ) + self.testset = TimeSeriesDatasetWrapper( + self.dataset[test_split], context_length, prediction_length + ) + + def _init_text_dataset(self): + """Initialize text/NLP dataset.""" + logging.info("Initializing text/NLP dataset") + parser = HfArgumentParser(TrainingArguments) (self.training_args,) = parser.parse_args_into_dataclasses( args=["--output_dir=/tmp", "--report_to=none"] From 3bb745c2225f0030a90033c6dfc7a10fa7573b84 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 03:20:51 +0000 Subject: [PATCH 24/33] Added PatchTSMixer model support to HuggingFace model factory. --- plato/models/huggingface.py | 117 +++++++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 4a8d3aafb..992f3a375 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -10,6 +10,22 @@ from transformers import AutoConfig, AutoModelForCausalLM from plato.config import Config +from plato.utils.timeseries_utils import is_timeseries_model + +try: + from transformers import ( + PatchTSMixerConfig, + PatchTSMixerForPrediction, + PatchTSMixerForPretraining, + PatchTSMixerForRegression, + PatchTSMixerForTimeSeriesClassification, + ) +except ImportError: + PatchTSMixerConfig = None + PatchTSMixerForPrediction = None + PatchTSMixerForTimeSeriesClassification = None + PatchTSMixerForRegression = None + PatchTSMixerForPretraining = None try: from peft import LoraConfig, get_peft_model @@ -36,7 +52,87 @@ def _lora_config_dict(lora_config: Any) -> dict[str, Any]: class Model: - """The CausalLM model loaded from HuggingFace.""" + """The HuggingFace model factory supporting various model types.""" + + @staticmethod + def _get_timeseries_task_type(model_task=None): + """Determine the task type for time series models from config or arguments.""" + trainer_config = Config().trainer + return ( + model_task + or getattr(trainer_config, "model_task", None) + or getattr(trainer_config, "task_type", "forecasting") + ) + + @staticmethod + def _get_patchtsmixer_model(resolved_model_name, cache_dir, model_task=None): + """Load or create a PatchTSMixer model.""" + if PatchTSMixerForPrediction is None: + raise ImportError( + "PatchTSMixer models are not available. " + "Ensure you have transformers>=4.35.0 installed." + ) + + task_type = Model._get_timeseries_task_type(model_task) + + # Try to load pretrained model first + task_models = { + "classification": PatchTSMixerForTimeSeriesClassification, + "regression": PatchTSMixerForRegression, + "pretraining": PatchTSMixerForPretraining, + "forecasting": PatchTSMixerForPrediction, + } + model_class = task_models.get(task_type, PatchTSMixerForPrediction) + + try: + logging.info( + "Attempting to load pretrained PatchTSMixer model: %s", + resolved_model_name, + ) + model = model_class.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained model") + except (OSError, ValueError, Exception): + # If loading fails, create new model from config + logging.info( + "Model '%s' not found as pretrained, creating from config settings", + resolved_model_name, + ) + trainer_config = Config().trainer + + config = PatchTSMixerConfig( + context_length=getattr(trainer_config, "context_length", 512), + prediction_length=getattr(trainer_config, "prediction_length", 96), + num_input_channels=getattr(trainer_config, "num_input_channels", 7), + patch_length=getattr(trainer_config, "patch_length", 8), + patch_stride=getattr(trainer_config, "patch_stride", 8), + d_model=getattr(trainer_config, "d_model", 64), + num_layers=getattr(trainer_config, "num_layers", 8), + expansion_factor=getattr(trainer_config, "expansion_factor", 2), + dropout=getattr(trainer_config, "dropout", 0.2), + head_dropout=getattr(trainer_config, "head_dropout", 0.2), + mode=getattr(trainer_config, "mode", "common_channel"), + gated_attn=getattr(trainer_config, "gated_attn", True), + scaling=getattr(trainer_config, "scaling", "std"), + prediction_channel_indices=getattr( + trainer_config, "prediction_channel_indices", None + ), + ) + + # Set task-specific parameters and create model + if task_type == "classification": + config.num_labels = getattr(trainer_config, "num_classes", 2) + model = PatchTSMixerForTimeSeriesClassification(config) + elif task_type == "regression": + config.num_targets = getattr(trainer_config, "num_targets", 1) + model = PatchTSMixerForRegression(config) + elif task_type == "pretraining": + model = PatchTSMixerForPretraining(config) + else: # forecasting + model = PatchTSMixerForPrediction(config) + + return model @staticmethod def get(model_name=None, **kwargs): # pylint: disable=unused-argument @@ -55,6 +151,25 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument if not isinstance(resolved_model_name, str) or not resolved_model_name: raise ValueError("A valid HuggingFace model name must be provided.") + cache_dir = Config().params["model_path"] + "/huggingface" + + # Determine model type from config or model name + model_type = kwargs.get("model_type") or getattr( + getattr(Config(), "trainer", None), "model_type", None + ) + + # Detect if this is a time series model and which type + is_timeseries = is_timeseries_model( + model_name=resolved_model_name, model_type=model_type + ) + + if is_timeseries: + model_task = kwargs.get("model_task") + return Model._get_patchtsmixer_model( + resolved_model_name, cache_dir, model_task + ) + + # Default to CausalLM for backward compatibility config = AutoConfig.from_pretrained(resolved_model_name, **config_kwargs) model = AutoModelForCausalLM.from_pretrained( From 03abe2facf5d37ebb436d38076be8639c6cc9d80 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 03:28:12 +0000 Subject: [PATCH 25/33] Added timeseries_utils module with is_timeseries_model function. --- plato/utils/timeseries_utils.py | 39 +++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 plato/utils/timeseries_utils.py diff --git a/plato/utils/timeseries_utils.py b/plato/utils/timeseries_utils.py new file mode 100644 index 000000000..8f2a0e935 --- /dev/null +++ b/plato/utils/timeseries_utils.py @@ -0,0 +1,39 @@ +""" +Utility functions for time series model detection and handling. +""" + +from typing import Optional, Tuple + + +def is_timeseries_model( + model_name: Optional[str] = None, + model_type: Optional[str] = None, + dataset_type: Optional[str] = None, +) -> bool: + """ + Check if a model/dataset is for time series. + + Args: + model_name: Name of the model + model_type: Type of model from config + dataset_type: Type of dataset from config + + Returns: + True if this is a time series model, False otherwise + """ + model_name_lower = model_name.lower() if model_name else "" + model_type_lower = model_type.lower() if model_type else "" + + # Check for PatchTSMixer + if ( + model_type_lower == "patchtsmixer" + or "patchtsmixer" in model_name_lower + or "timeseries" in model_name_lower + ): + return True + + # Check dataset type + if dataset_type and dataset_type.lower() == "timeseries": + return True + + return False From c30dafc9da9bccc952465a2eac0aca3111487eeb Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 03:32:40 +0000 Subject: [PATCH 26/33] Added time-series support to the HuggingFace trainer. --- plato/trainers/huggingface.py | 299 ++++++++++++++++++++++++---------- 1 file changed, 209 insertions(+), 90 deletions(-) diff --git a/plato/trainers/huggingface.py b/plato/trainers/huggingface.py index 61365a638..7f4ed694f 100644 --- a/plato/trainers/huggingface.py +++ b/plato/trainers/huggingface.py @@ -5,6 +5,7 @@ HuggingFace data handling through strategy objects instead of overriding `load_model`/`save_model` hooks. +Supports both text/NLP models and time series models (e.g., PatchTSMixer). """ import logging @@ -39,6 +40,7 @@ TrainingContext, TrainingStepStrategy, ) +from plato.utils.timeseries_utils import is_timeseries_model class HuggingFaceBatch(dict): @@ -79,6 +81,23 @@ def __call__( return HuggingFaceBatch(batch), labels +class TimeSeriesCollateWrapper: + """Collator for time series data (PatchTSMixer format).""" + + def __call__( + self, examples: Iterable[dict] + ) -> tuple[HuggingFaceBatch, torch.Tensor | None]: + """ + Collate time series examples into batches. + + Expected format: {"past_values": tensor, "future_values": tensor} + """ + batch = default_data_collator(list(examples)) + labels = batch.get("future_values", None) + + return HuggingFaceBatch(batch), labels + + def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): """ Resolve a loss tensor from HuggingFace model outputs. @@ -110,8 +129,10 @@ def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): raise ValueError("HuggingFace model did not return a tensor loss.") logits = getattr(outputs, "logits", None) + if logits is None: + logits = getattr(outputs, "prediction_outputs", None) # PatchTSMixer if logits is None and isinstance(outputs, dict): - logits = outputs.get("logits") + logits = outputs.get("logits") or outputs.get("prediction_outputs") if logits is None and isinstance(outputs, tuple) and len(outputs) > 0: logits = outputs[0] @@ -133,6 +154,13 @@ def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): logits = logits.to(labels.device) if labels.device != logits.device else logits labels = labels.to(logits.device) + # Check if this is a regression task (shapes match) -> use MSE + # Time series: logits (batch, pred_len, channels), labels (batch, pred_len, channels) + # Text generation: logits (batch, seq_len, vocab_size), labels (batch, seq_len) + if logits.shape == labels.shape: + return F.mse_loss(logits, labels) + + # Text generation with causal LM -> use cross-entropy vocab_size = logits.size(-1) if logits.ndim > 2: shift_logits = logits[..., :-1, :].contiguous() @@ -196,12 +224,25 @@ def training_step( optimizer.zero_grad() batch_inputs = dict(examples) - if labels is not None: + + # For time series models like PatchTSMixer, future_values should not be passed as 'labels' + # TODO: Need to check if other time series models follow this + is_timeseries = ( + "past_values" in batch_inputs and "future_values" in batch_inputs + ) + + if not is_timeseries and labels is not None: batch_inputs["labels"] = labels batch_inputs.setdefault("return_dict", True) outputs = model(**batch_inputs) - labels_tensor = batch_inputs.get("labels") + + # For time series, get labels from batch_inputs, otherwise from labels argument + labels_tensor = ( + batch_inputs.get("future_values") + if is_timeseries + else batch_inputs.get("labels") + ) loss = _resolve_hf_loss(outputs, labels_tensor) loss_for_backward = loss.div(accum_steps) if accum_steps > 1 else loss @@ -291,10 +332,21 @@ def finalize(self, model, optimizer, context: TrainingContext): class HuggingFaceTestingStrategy(TestingStrategy): - """Evaluates HuggingFace models and reports perplexity based on loss.""" + """Evaluates HuggingFace models (text: perplexity, time series: MSE).""" - def __init__(self, collate_fn: HuggingFaceCollateWrapper): + def __init__(self, collate_fn, is_timeseries=False): self.collate_fn = collate_fn + self.is_timeseries = is_timeseries + + @property + def metric_name(self) -> str: + """Return the name of the metric this strategy computes.""" + if self.is_timeseries: + return "mse" # For time series models, using mean squared error. + elif hasattr(Config().trainer, "target_perplexity"): + return "perplexity" + else: + return "accuracy" def test_model(self, model, config, testset, sampler, context: TrainingContext): batch_size = config.get("batch_size", 1) @@ -324,41 +376,80 @@ def test_model(self, model, config, testset, sampler, context: TrainingContext): model.eval() context.state["eval_loader"] = data_loader - total_loss = 0.0 - total_weight = 0 - - with torch.no_grad(): - for batch_inputs, labels in data_loader: - batch_inputs = batch_inputs.to(context.device) - if labels is not None: - labels = labels.to(context.device) - batch_inputs["labels"] = labels - - batch_inputs.setdefault("return_dict", True) - outputs = model(**batch_inputs) - loss = _resolve_hf_loss(outputs, labels) - - if labels is not None: - weight = labels.ne(-100).sum().item() - if weight == 0: - continue - else: - weight = 1 - - total_loss += loss.item() * weight - total_weight += weight - - model.train() - context.state.pop("eval_loader", None) - - if total_weight == 0: - return float("inf") - - avg_loss = total_loss / total_weight - try: - return math.exp(avg_loss) - except OverflowError: - return float("inf") + if self.is_timeseries: + total_loss = 0.0 + total_samples = 0 + + with torch.no_grad(): + for batch_inputs, labels in data_loader: + batch_inputs = batch_inputs.to(context.device) + if labels is not None: + labels = labels.to(context.device) + batch_inputs["future_values"] = labels + + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + + loss = getattr(outputs, "loss", None) + if loss is None: + loss = ( + outputs.get("loss") if isinstance(outputs, dict) else None + ) + + if loss is not None: + batch_size = ( + batch_inputs["past_values"].size(0) + if "past_values" in batch_inputs + else 1 + ) + total_loss += loss.item() * batch_size + total_samples += batch_size + + model.train() + context.state.pop("eval_loader", None) + + if total_samples == 0: + return float("inf") + + # Return MSE + return total_loss / total_samples + else: + # Text/NLP: compute perplexity + total_loss = 0.0 + total_weight = 0 + + with torch.no_grad(): + for batch_inputs, labels in data_loader: + batch_inputs = batch_inputs.to(context.device) + if labels is not None: + labels = labels.to(context.device) + batch_inputs["labels"] = labels + + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + loss = _resolve_hf_loss(outputs, labels) + + if labels is not None: + weight = labels.ne(-100).sum().item() + if weight == 0: + continue + else: + weight = 1 + + total_loss += loss.item() * weight + total_weight += weight + + model.train() + context.state.pop("eval_loader", None) + + if total_weight == 0: + return float("inf") + + avg_loss = total_loss / total_weight + try: + return math.exp(avg_loss) + except OverflowError: + return float("inf") def _split_callback_types( @@ -433,57 +524,75 @@ def __init__(self, model=None, callbacks=None): ] ) - model_name = Config().trainer.model_name - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, - } - self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + model_name = getattr(Config().trainer, "model_name", "") + model_type = getattr(Config().trainer, "model_type", None) - cache_dir = Config().params["data_path"] - use_fast_tokenizer = True - revision = "main" - auth_token = getattr( - getattr(Config(), "parameters", None), "huggingface_token", None + # Detect if this is a time series model + self._is_timeseries = is_timeseries_model( + model_name=model_name, model_type=model_type ) - if "llama" in model_name: - if isinstance(auth_token, str) and auth_token: - self.tokenizer = LlamaTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - use_auth_token=auth_token, - ) - else: - self.tokenizer = LlamaTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - ) - else: - if isinstance(auth_token, str) and auth_token: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - use_auth_token=auth_token, - ) + if self._is_timeseries: + logging.info( + "Detected time series model (type: %s, name: %s)", + model_type, + model_name, + ) + + self.config = None + if not self._is_timeseries: + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } + self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + + self.tokenizer = None + if not self._is_timeseries: + cache_dir = Config().params["data_path"] + use_fast_tokenizer = True + revision = "main" + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) + + if "llama" in model_name: + if isinstance(auth_token, str) and auth_token: + self.tokenizer = LlamaTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + use_auth_token=auth_token, + ) + else: + self.tokenizer = LlamaTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + ) else: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - ) + if isinstance(auth_token, str) and auth_token: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + use_auth_token=auth_token, + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + ) grad_accum_steps = getattr(Config().trainer, "gradient_accumulation_steps", 1) try: @@ -491,7 +600,15 @@ def __init__(self, model=None, callbacks=None): except (TypeError, ValueError): grad_accum_steps = 1 self._gradient_accumulation_steps = max(grad_accum_steps, 1) - self._collate_wrapper = HuggingFaceCollateWrapper(self.tokenizer) + + # Choose collator based on model type + if self._is_timeseries: + self._collate_wrapper = TimeSeriesCollateWrapper() + logging.info("Using TimeSeriesCollateWrapper for time series model") + else: + self._collate_wrapper = HuggingFaceCollateWrapper(self.tokenizer) + logging.info("Using HuggingFaceCollateWrapper for text model") + self.training_args.gradient_accumulation_steps = ( self._gradient_accumulation_steps ) @@ -513,14 +630,16 @@ def __init__(self, model=None, callbacks=None): num_workers=0, pin_memory=True, ), - testing_strategy=HuggingFaceTestingStrategy(self._collate_wrapper), + testing_strategy=HuggingFaceTestingStrategy( + self._collate_wrapper, is_timeseries=self._is_timeseries + ), ) if hf_callbacks: self.add_callbacks(hf_callbacks) model_instance = self._require_model() - if hasattr(model_instance, "loss_type"): + if hasattr(model_instance, "loss_type") and not self._is_timeseries: setattr(model_instance, "loss_type", "ForCausalLM") # Ensure model checkpoints can be saved when model names include slashes. From dc90468f617a69589c269589414ecbcebd6c42c8 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 04:35:32 +0000 Subject: [PATCH 27/33] Added documentation for time series model PatchTSMixer. --- docs/docs/examples/Getting Started.md | 5 +++++ .../examples/algorithms/15. Time Series Models.md | 15 +++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 docs/docs/examples/algorithms/15. Time Series Models.md diff --git a/docs/docs/examples/Getting Started.md b/docs/docs/examples/Getting Started.md index f39bded79..ebc481430 100644 --- a/docs/docs/examples/Getting Started.md +++ b/docs/docs/examples/Getting Started.md @@ -45,6 +45,11 @@ Plato supports both Linux with NVIDIA GPUs and macOS with M1/M2/M4/M4 GPUs. It w - [Model Pruning Algorithms](algorithms/13.%20Model%20Pruning%20Algorithms.md) +- [Gradient Leakage Attacks and Defences](algorithms/14.%20Gradient%20Leakage%20Attacks%20and%20Defences.md) + +- [Time Series Models](algorithms/15.%20Time%20Series%20Models.md) + + ## Case Studies - [Federated LoRA Fine-Tuning](case-studies/1.%20LoRA.md) diff --git a/docs/docs/examples/algorithms/15. Time Series Models.md b/docs/docs/examples/algorithms/15. Time Series Models.md new file mode 100644 index 000000000..9c4c7e502 --- /dev/null +++ b/docs/docs/examples/algorithms/15. Time Series Models.md @@ -0,0 +1,15 @@ +### PatchTSMixer + +PatchTSMixer is a lightweight time-series modeling approach based on the MLP-Mixer architecture. The model can be pretrained and subsequently used for various downstream tasks such as forecasting, classification and regression. + +```bash +uv run python plato.py -c configs/TimeSeries/patchtsmixer_pretrained.toml +``` + +For custom model configurations without using pretrained weights: + +```bash +uv run python plato.py -c configs/TimeSeries/patchtsmixer_custom.toml +``` + +**Reference:** V. Ekambaram, A. Jati, N. Nguyen, S. Sinthong, K. Kalagnanam. "[TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://dl.acm.org/doi/abs/10.1145/3580305.3599533)," in Proc. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2023. \ No newline at end of file From ed13defa933625f449c44bda205c223917517531 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 14 Nov 2025 04:44:25 +0000 Subject: [PATCH 28/33] Added links to time series model in docs. --- docs/docs/examples/algorithms/15. Time Series Models.md | 2 +- docs/docs/index.md | 1 + docs/mkdocs.yml | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/docs/examples/algorithms/15. Time Series Models.md b/docs/docs/examples/algorithms/15. Time Series Models.md index 9c4c7e502..c6dea9759 100644 --- a/docs/docs/examples/algorithms/15. Time Series Models.md +++ b/docs/docs/examples/algorithms/15. Time Series Models.md @@ -12,4 +12,4 @@ For custom model configurations without using pretrained weights: uv run python plato.py -c configs/TimeSeries/patchtsmixer_custom.toml ``` -**Reference:** V. Ekambaram, A. Jati, N. Nguyen, S. Sinthong, K. Kalagnanam. "[TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://dl.acm.org/doi/abs/10.1145/3580305.3599533)," in Proc. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2023. \ No newline at end of file +**Reference:** V. Ekambaram, A. Jati, N. Nguyen, S. Sinthong, K. Kalagnanam. "[TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://dl.acm.org/doi/abs/10.1145/3580305.3599533)," in Proc. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2023. – [[Code available]](https://github.com/ibm-granite/granite-tsfm) \ No newline at end of file diff --git a/docs/docs/index.md b/docs/docs/index.md index 1d186eda4..63a850aab 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -32,6 +32,7 @@ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, a - **[Poisoning Detection](examples/algorithms/12.%20Poisoning%20Detection%20Algorithms.md)** - **[Model Pruning](examples/algorithms/13.%20Model%20Pruning%20Algorithms.md)** - **[Gradient Leakage Attacks and Defences](examples/algorithms/14.%20Gradient%20Leakage%20Attacks%20and%20Defences.md)** + - **[Time Series Models](examples/algorithms/15.%20Time%20Series%20Models.md)** ## Configuration Settings diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c5e428749..1289b4e1a 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -66,6 +66,7 @@ nav: - Poisoning Detection: examples/algorithms/12. Poisoning Detection Algorithms.md - Model Pruning: examples/algorithms/13. Model Pruning Algorithms.md - Gradient Leakage Attacks and Defences: examples/algorithms/14. Gradient Leakage Attacks and Defences.md + - Time Series Models: examples/algorithms/15. Time Series Models.md - Case Studies: - Federated LoRA Fine-Tuning: examples/case-studies/1. LoRA.md - Composable Trainer API: examples/case-studies/2. Composable Trainer.md From 7bc7f43f8ac6a0acd3cc99ec4454c91e94f4dc96 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Fri, 21 Nov 2025 21:41:13 +0000 Subject: [PATCH 29/33] Revised dataset split to improve training performance. --- plato/datasources/ETT.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plato/datasources/ETT.py b/plato/datasources/ETT.py index c05684c96..e8032b34e 100644 --- a/plato/datasources/ETT.py +++ b/plato/datasources/ETT.py @@ -120,10 +120,10 @@ def __init__(self, **kwargs): ) # -1 for date column # Split into train/val/test following the standard ETT split - # Standard split: 12 months train, 4 months val, 4 months test + # Standard split: 16 months train, 2 months val, 2 months test points_per_hour = dataset_info["points_per_hour"] - train_end = 12 * 30 * 24 * points_per_hour # 12 months - val_end = train_end + 4 * 30 * 24 * points_per_hour # + 4 months + train_end = 16 * 30 * 24 * points_per_hour # 16 months + val_end = train_end + 2 * 30 * 24 * points_per_hour # + 2 months train_df = df[:train_end] val_df = df[train_end:val_end] From 173095aaf4cc4eee14299e9e16e06b84a9e7056c Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 1 Dec 2025 01:41:19 +0000 Subject: [PATCH 30/33] Added a larger PatchTSMixer config file with extended hyperparameters. --- configs/TimeSeries/patchtsmixer_large.toml | 67 ++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 configs/TimeSeries/patchtsmixer_large.toml diff --git a/configs/TimeSeries/patchtsmixer_large.toml b/configs/TimeSeries/patchtsmixer_large.toml new file mode 100644 index 000000000..6a9253671 --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_large.toml @@ -0,0 +1,67 @@ +# Federated Learning with Large PatchTSMixer for Time Series Forecasting +# This configuration matches the PatchTSMixer paper parameters for ETTh1 + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = true # Enable testing to evaluate model on test set + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer" +model_path = "models/timeseries/patchtsmixer" + +[data] +# ETTh1: Electricity Transformer Temperature dataset (7 features) +datasource = "ETTh1" + +partition_size = 6960 # Full ETTh1 training set +sampler = "iid" +random_seed = 1 + +[trainer] +type = "HuggingFace" +rounds = 1000 +max_concurrency = 10 +model_type = "huggingface" +model_name = "custom_patchtsmixer" + +# Task type: forecasting, classification, regression, or pretraining +task_type = "forecasting" + +# PatchTSMixer specific parameters +context_length = 512 # Paper uses 512 context length +prediction_length = 96 # Standard benchmark (paper tests 96, 192, 336, 720) +num_input_channels = 7 # ETTh1 has 7 features (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT) +patch_length = 16 +patch_stride = 8 + +d_model = 128 +num_layers = 8 +expansion_factor = 2 + +dropout = 0.3 # Increase regularization to prevent overfitting +head_dropout = 0.3 # Increase regularization to prevent overfitting + +# Model configuration +mode = "common_channel" +gated_attn = true +scaling = "std" + +epochs = 100 +batch_size = 64 +optimizer = "Adam" + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.001 + +[results] +types = "round, elapsed_time, mse" From 5562465ec3e97f21bc05ba0626ef674bb76b9201 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 1 Dec 2025 04:46:46 +0000 Subject: [PATCH 31/33] Revised MSE evaluation logs for time series models. --- plato/clients/strategies/defaults.py | 10 +++++++++- plato/servers/fedavg.py | 26 ++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index a7a54c86d..382dcdd3d 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -339,7 +339,15 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: if context.sio is not None: await context.sio.disconnect() - if hasattr(Config().trainer, "target_perplexity"): + metric_name = None + if hasattr(context.trainer, "testing_strategy") and hasattr( + context.trainer.testing_strategy, "metric_name" + ): + metric_name = context.trainer.testing_strategy.metric_name + + if metric_name == "mse": + LOGGER.info("[%s] Test MSE: %.2f", context, accuracy) + elif hasattr(Config().trainer, "target_perplexity") or metric_name == "perplexity": LOGGER.info("[%s] Test perplexity: %.2f", context, accuracy) else: LOGGER.info("[%s] Test accuracy: %.2f%%", context, 100 * accuracy) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index d6e819bb6..3d9fe1e2c 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -243,9 +243,24 @@ async def _process_reports(self): if hasattr(Config().server, "do_test") and not Config().server.do_test: # Compute the average accuracy from client reports self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy - ) + + trainer = self.require_trainer() + metric_name = getattr(trainer.testing_strategy, "metric_name", "accuracy") + + if metric_name == "mse": + logging.info( + "[%s] Average client MSE: %.2f.", self, self.accuracy + ) + elif metric_name == "perplexity" or hasattr( + Config().trainer, "target_perplexity" + ): + logging.info( + "[%s] Average client perplexity: %.2f.", self, self.accuracy + ) + else: + logging.info( + "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy + ) else: # Testing the updated model directly at the server logging.info("[%s] Started model testing.", self) @@ -275,7 +290,7 @@ async def _process_reports(self): if metric_name == "mse": logging.info( - fonts.colourize(f"[{self}] Global model MSE: {self.accuracy:.4f}\n") + fonts.colourize(f"[{self}] Global model MSE: {self.accuracy:.2f}\n") ) elif metric_name == "perplexity" or hasattr( Config().trainer, "target_perplexity" @@ -341,6 +356,9 @@ def get_logged_items(self) -> dict: if hasattr(self, "_core_metric"): logged["core_metric"] = self._core_metric + logged["mse"] = self.accuracy + logged["perplexity"] = self.accuracy + return logged @staticmethod From b9778ecd12b2a17c18b0350e144aa1ecdbb72f41 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 1 Dec 2025 04:52:43 +0000 Subject: [PATCH 32/33] Used uv ruff format . --- plato/clients/strategies/defaults.py | 5 ++++- plato/servers/fedavg.py | 6 ++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index 382dcdd3d..2582f70d8 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -347,7 +347,10 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: if metric_name == "mse": LOGGER.info("[%s] Test MSE: %.2f", context, accuracy) - elif hasattr(Config().trainer, "target_perplexity") or metric_name == "perplexity": + elif ( + hasattr(Config().trainer, "target_perplexity") + or metric_name == "perplexity" + ): LOGGER.info("[%s] Test perplexity: %.2f", context, accuracy) else: LOGGER.info("[%s] Test accuracy: %.2f%%", context, 100 * accuracy) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 3d9fe1e2c..f93541f2c 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -246,11 +246,9 @@ async def _process_reports(self): trainer = self.require_trainer() metric_name = getattr(trainer.testing_strategy, "metric_name", "accuracy") - + if metric_name == "mse": - logging.info( - "[%s] Average client MSE: %.2f.", self, self.accuracy - ) + logging.info("[%s] Average client MSE: %.2f.", self, self.accuracy) elif metric_name == "perplexity" or hasattr( Config().trainer, "target_perplexity" ): From 20ab5746f9e4cd8c134f74965a31edd976fd17b6 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Tue, 2 Dec 2025 20:43:06 +0000 Subject: [PATCH 33/33] Refactored ETT data splitting and normalization for consistency with HF examples. --- plato/datasources/ETT.py | 36 +++++++++++++++++++++++----------- plato/samplers/iid.py | 42 ++++++++++++++++++++++------------------ 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/plato/datasources/ETT.py b/plato/datasources/ETT.py index e8032b34e..fc73b82fe 100644 --- a/plato/datasources/ETT.py +++ b/plato/datasources/ETT.py @@ -119,15 +119,31 @@ def __init__(self, **kwargs): len(df.columns) - 1, ) # -1 for date column - # Split into train/val/test following the standard ETT split - # Standard split: 16 months train, 2 months val, 2 months test + # Split into train/val/test following the standard ETT split used by HF examples + # Standard split: 12 months train, 4 months val, 4 months test points_per_hour = dataset_info["points_per_hour"] - train_end = 16 * 30 * 24 * points_per_hour # 16 months - val_end = train_end + 2 * 30 * 24 * points_per_hour # + 2 months + train_end = 12 * 30 * 24 * points_per_hour # 12 months + val_end = train_end + 4 * 30 * 24 * points_per_hour # + 4 months + test_end = train_end + 8 * 30 * 24 * points_per_hour # + 8 months + + # Shift val/test start back by context_length so their first window has history + val_start = max(0, train_end - context_length) + test_start = max(0, val_end - context_length) train_df = df[:train_end] - val_df = df[train_end:val_end] - test_df = df[val_end:] + val_df = df[val_start:val_end] + test_df = df[test_start:test_end] + + # Compute train mean/std per channel and normalize all splits (matches HF demo preprocessing) + feature_cols = [col for col in df.columns if col != "date"] + train_features = train_df[feature_cols] + eps = 1e-6 + feature_mean = train_features.mean() + feature_std = train_features.std().replace(0, eps) + + train_norm = ((train_features - feature_mean) / feature_std).to_numpy() + val_norm = ((val_df[feature_cols] - feature_mean) / feature_std).to_numpy() + test_norm = ((test_df[feature_cols] - feature_mean) / feature_std).to_numpy() logging.info( "%s split - train: %d, val: %d, test: %d", @@ -139,13 +155,11 @@ def __init__(self, **kwargs): # Create datasets with sliding windows self.trainset = ETTDataset( - train_df, context_length, prediction_length, stride=1 + train_norm, context_length, prediction_length, stride=1 ) - # Use validation set as test set for federated learning - self.testset = ETTDataset( - val_df, context_length, prediction_length, stride=prediction_length - ) + # Evaluate on the standard test split with full coverage + self.testset = ETTDataset(test_norm, context_length, prediction_length, stride=1) logging.info( "Created %d training windows and %d test windows", diff --git a/plato/samplers/iid.py b/plato/samplers/iid.py index 815899cce..848775e99 100644 --- a/plato/samplers/iid.py +++ b/plato/samplers/iid.py @@ -18,29 +18,33 @@ def __init__(self, datasource, client_id, testing): super().__init__() if testing: + # Use the full test set for evaluation to avoid sampling/duplication dataset = datasource.get_test_set() + self.subset_indices = list(range(len(dataset))) else: dataset = datasource.get_train_set() - self.dataset_size = len(dataset) - indices = list(range(self.dataset_size)) - np.random.seed(self.random_seed) - np.random.shuffle(indices) - - partition_size = Config().data.partition_size - total_clients = Config().clients.total_clients - total_size = partition_size * total_clients - - # add extra samples to make it evenly divisible, if needed - if len(indices) < total_size: - while len(indices) < total_size: - indices += indices[: (total_size - len(indices))] - else: - indices = indices[:total_size] - assert len(indices) == total_size - - # Compute the indices of data in the subset for this client - self.subset_indices = indices[(int(client_id) - 1) : total_size : total_clients] + self.dataset_size = len(dataset) + indices = list(range(self.dataset_size)) + np.random.seed(self.random_seed) + np.random.shuffle(indices) + + partition_size = Config().data.partition_size + total_clients = Config().clients.total_clients + total_size = partition_size * total_clients + + # add extra samples to make it evenly divisible, if needed + if len(indices) < total_size: + while len(indices) < total_size: + indices += indices[: (total_size - len(indices))] + else: + indices = indices[:total_size] + assert len(indices) == total_size + + # Compute the indices of data in the subset for this client + self.subset_indices = indices[ + (int(client_id) - 1) : total_size : total_clients + ] def get(self): """Obtains an instance of the sampler."""