diff --git a/.gitmodules b/.gitmodules index 377fc17b7..1176795af 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "plato/models/t2tvit"] path = plato/models/t2tvit url = https://github.com/yitu-opensource/T2T-ViT +[submodule "external/nanochat"] + path = external/nanochat + url = https://github.com/karpathy/nanochat.git diff --git a/cleanup.py b/cleanup.py index 68c658019..f03f5f3cf 100644 --- a/cleanup.py +++ b/cleanup.py @@ -146,18 +146,14 @@ def main() -> None: continue cleared = clean_directory(runtime_dir) - print( - f"Failed to delete {runtime_dir}; cleared {cleared} items instead." - ) + print(f"Failed to delete {runtime_dir}; cleared {cleared} items instead.") fallback_dirs += 1 fallback_items += cleared if runtime_total == 0: print("No runtime directories found.") else: - print( - f"Removed {runtime_removed} of {runtime_total} runtime directories." - ) + print(f"Removed {runtime_removed} of {runtime_total} runtime directories.") if fallback_dirs: print( f"Cleared {fallback_items} items in " diff --git a/configs/Nanochat/parquet_micro.toml b/configs/Nanochat/parquet_micro.toml new file mode 100644 index 000000000..cfe09a921 --- /dev/null +++ b/configs/Nanochat/parquet_micro.toml @@ -0,0 +1,53 @@ +[clients] +type = "simple" +total_clients = 10 +per_round = 3 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/parquet" +model_path = "models/nanochat/parquet" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "parquet" +max_train_batches = 16 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 32 +device = "cuda" +vocab_size = 512 +synthetic_seed = 123 + +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" +max_per_task = 16 + +[trainer] +type = "nanochat" +rounds = 10000 +epochs = 5 +batch_size = 1 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 256 +vocab_size = 50304 +n_layer = 4 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, core_metric, train_loss" diff --git a/configs/Nanochat/synthetic_micro.toml b/configs/Nanochat/synthetic_micro.toml new file mode 100644 index 000000000..abc3a6452 --- /dev/null +++ b/configs/Nanochat/synthetic_micro.toml @@ -0,0 +1,54 @@ +[clients] + +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/nanochat/synthetic" +model_path = "models/nanochat/synthetic" + +[data] +datasource = "Nanochat" +sampler = "iid" +partition_size = 1 +random_seed = 1 +mode = "synthetic" +max_train_batches = 4 +max_val_batches = 1 +tokenizer_threads = 2 +tokenizer_batch_size = 64 +device = "cpu" +vocab_size = 512 +synthetic_seed = 123 + +[evaluation] +type = "nanochat_core" +# bundle_dir = "~/nanochat" # Optional, defaults to nanochat base dir or Plato's data directory +max_per_task = 16 # Optional, -1 means run all examples + +[trainer] +type = "nanochat" +rounds = 1 +epochs = 1 +batch_size = 2 +model_name = "nanochat" +optimizer = "nanochat" + +[algorithm] +type = "fedavg" + +[parameters.model] +sequence_len = 128 +vocab_size = 512 +n_layer = 2 +n_head = 4 +n_kv_head = 4 +n_embd = 256 + +[results] +types = "round, elapsed_time, core_metric, train_loss" diff --git a/configs/TimeSeries/patchtsmixer_custom.toml b/configs/TimeSeries/patchtsmixer_custom.toml new file mode 
100644 index 000000000..d228eb1fb --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_custom.toml @@ -0,0 +1,67 @@ +# Federated Learning with PatchTSMixer for Time Series Forecasting +# This configuration demonstrates using the IBM Granite PatchTSMixer model +# with time series data from HuggingFace datasets + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer" +model_path = "models/timeseries/patchtsmixer" + +[data] +# ETTh1: Electricity Transformer Temperature dataset (7 features) +datasource = "ETTh1" + +partition_size = 100 # Number of training samples +sampler = "iid" +random_seed = 1 + +[trainer] +type = "HuggingFace" +rounds = 3 +max_concurrency = 2 +model_type = "huggingface" + +# Train from scratch - simpler for testing +model_name = "custom_patchtsmixer" + +# Task type: forecasting, classification, regression, or pretraining +task_type = "forecasting" + +# PatchTSMixer specific parameters (smaller model for testing) +context_length = 64 +prediction_length = 24 +num_input_channels = 7 # ETTh1 has 7 features (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT) +patch_length = 8 +patch_stride = 8 +d_model = 32 # Hidden dimension of the model. Recommended to set it as a multiple of patch_length (i.e. 2-8X of patch_len). Larger value indicates more complex model. +num_layers = 3 # Number of layers to use. Recommended range is 3-15. Larger value indicates more complex model. +expansion_factor = 2 # Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model. +dropout = 0.5 +head_dropout = 0.7 +mode = "common_channel" +gated_attn = true +scaling = "std" + +# Training parameters +epochs = 2 +batch_size = 8 +optimizer = "Adam" + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.001 +weight_decay = 0.0 + +[results] +types = "round, elapsed_time, accuracy" diff --git a/configs/TimeSeries/patchtsmixer_large.toml b/configs/TimeSeries/patchtsmixer_large.toml new file mode 100644 index 000000000..6a9253671 --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_large.toml @@ -0,0 +1,67 @@ +# Federated Learning with Large PatchTSMixer for Time Series Forecasting +# This configuration matches the PatchTSMixer paper parameters for ETTh1 + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = true # Enable testing to evaluate model on test set + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer" +model_path = "models/timeseries/patchtsmixer" + +[data] +# ETTh1: Electricity Transformer Temperature dataset (7 features) +datasource = "ETTh1" + +partition_size = 6960 # Full ETTh1 training set +sampler = "iid" +random_seed = 1 + +[trainer] +type = "HuggingFace" +rounds = 1000 +max_concurrency = 10 +model_type = "huggingface" +model_name = "custom_patchtsmixer" + +# Task type: forecasting, classification, regression, or pretraining +task_type = "forecasting" + +# PatchTSMixer specific parameters +context_length = 512 # Paper uses 512 context length +prediction_length = 96 # Standard benchmark (paper tests 96, 192, 336, 720) +num_input_channels = 7 # ETTh1 has 7 features (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT) +patch_length = 16 +patch_stride = 8 + +d_model = 128 +num_layers = 8 +expansion_factor = 2 + +dropout = 0.3 # Increase regularization to prevent overfitting +head_dropout = 0.3 # Increase 
regularization to prevent overfitting + +# Model configuration +mode = "common_channel" +gated_attn = true +scaling = "std" + +epochs = 100 +batch_size = 64 +optimizer = "Adam" + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.001 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/patchtsmixer_pretrained.toml b/configs/TimeSeries/patchtsmixer_pretrained.toml new file mode 100644 index 000000000..4119c1b47 --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_pretrained.toml @@ -0,0 +1,68 @@ +# Federated Learning with PatchTSMixer for Time Series Forecasting +# This configuration demonstrates using the IBM Granite PatchTSMixer model +# with time series data from HuggingFace datasets + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer" +model_path = "models/timeseries/patchtsmixer" + +[data] +# ETTh1: Electricity Transformer Temperature dataset (7 features) +datasource = "ETTh1" + +partition_size = 100 # Number of training samples +sampler = "iid" +random_seed = 1 + +[trainer] +type = "HuggingFace" +rounds = 3 +max_concurrency = 2 +model_type = "huggingface" + +# Use pre-trained IBM Granite model +# For pre-trained model, the some settings must match pretrained model +model_name = "ibm-granite/granite-timeseries-patchtsmixer" + +# Task type: forecasting, classification, regression, or pretraining +task_type = "forecasting" + +# PatchTSMixer specific parameters (matching pretrained model) +context_length = 512 +prediction_length = 96 +num_input_channels = 7 +patch_length = 16 +patch_stride = 8 +d_model = 64 +num_layers = 8 +expansion_factor = 2 # Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model. +dropout = 0.5 +head_dropout = 0.7 +mode = "common_channel" +gated_attn = true +scaling = "std" + +# Training parameters +epochs = 2 # Reduced for testing +batch_size = 8 # Reduced for testing +optimizer = "Adam" + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.001 +weight_decay = 0.0 + +[results] +types = "round, elapsed_time, accuracy" diff --git a/docs/docs/examples/Getting Started.md b/docs/docs/examples/Getting Started.md index f39bded79..ebc481430 100644 --- a/docs/docs/examples/Getting Started.md +++ b/docs/docs/examples/Getting Started.md @@ -45,6 +45,11 @@ Plato supports both Linux with NVIDIA GPUs and macOS with M1/M2/M4/M4 GPUs. It w - [Model Pruning Algorithms](algorithms/13.%20Model%20Pruning%20Algorithms.md) +- [Gradient Leakage Attacks and Defences](algorithms/14.%20Gradient%20Leakage%20Attacks%20and%20Defences.md) + +- [Time Series Models](algorithms/15.%20Time%20Series%20Models.md) + + ## Case Studies - [Federated LoRA Fine-Tuning](case-studies/1.%20LoRA.md) diff --git a/docs/docs/examples/algorithms/15. Time Series Models.md b/docs/docs/examples/algorithms/15. Time Series Models.md new file mode 100644 index 000000000..c6dea9759 --- /dev/null +++ b/docs/docs/examples/algorithms/15. Time Series Models.md @@ -0,0 +1,15 @@ +### PatchTSMixer + +PatchTSMixer is a lightweight time-series modeling approach based on the MLP-Mixer architecture. The model can be pretrained and subsequently used for various downstream tasks such as forecasting, classification and regression. 
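+
+To run a federated session starting from the pretrained IBM Granite checkpoint (`ibm-granite/granite-timeseries-patchtsmixer`):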
+ +```bash +uv run python plato.py -c configs/TimeSeries/patchtsmixer_pretrained.toml +``` + +For custom model configurations without using pretrained weights: + +```bash +uv run python plato.py -c configs/TimeSeries/patchtsmixer_custom.toml +``` + +**Reference:** V. Ekambaram, A. Jati, N. Nguyen, S. Sinthong, K. Kalagnanam. "[TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://dl.acm.org/doi/abs/10.1145/3580305.3599533)," in Proc. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2023. – [[Code available]](https://github.com/ibm-granite/granite-tsfm) \ No newline at end of file diff --git a/docs/docs/index.md b/docs/docs/index.md index 1d186eda4..63a850aab 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -32,6 +32,7 @@ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, a - **[Poisoning Detection](examples/algorithms/12.%20Poisoning%20Detection%20Algorithms.md)** - **[Model Pruning](examples/algorithms/13.%20Model%20Pruning%20Algorithms.md)** - **[Gradient Leakage Attacks and Defences](examples/algorithms/14.%20Gradient%20Leakage%20Attacks%20and%20Defences.md)** + - **[Time Series Models](examples/algorithms/15.%20Time%20Series%20Models.md)** ## Configuration Settings diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c5e428749..1289b4e1a 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -66,6 +66,7 @@ nav: - Poisoning Detection: examples/algorithms/12. Poisoning Detection Algorithms.md - Model Pruning: examples/algorithms/13. Model Pruning Algorithms.md - Gradient Leakage Attacks and Defences: examples/algorithms/14. Gradient Leakage Attacks and Defences.md + - Time Series Models: examples/algorithms/15. Time Series Models.md - Case Studies: - Federated LoRA Fine-Tuning: examples/case-studies/1. LoRA.md - Composable Trainer API: examples/case-studies/2. Composable Trainer.md diff --git a/docs/nanochat_integration_checklist.md b/docs/nanochat_integration_checklist.md new file mode 100644 index 000000000..f0b7bf9b6 --- /dev/null +++ b/docs/nanochat_integration_checklist.md @@ -0,0 +1,55 @@ +# Nanochat Integration Checklist + +This checklist coordinates the work to incorporate the Nanochat stack into Plato. Owners are placeholder roles until specific engineers are assigned. + +## Third-Party Submodule +- **Owner:** Infrastructure +- **Deliverables:** Maintain the `external/nanochat` git submodule; document update procedure in `docs/third_party.md`. +- **Dependencies:** None. + +## Model Registry +- **Owner:** Modeling +- **Deliverables:** Implement `plato/models/nanochat.py` mirroring Nanochat GPT config, register entry in `plato/models/registry.py`, supply weight-loading utilities. +- **Status:** In progress – factory module and registry wiring landed. +- **Dependencies:** Third-party submodule. + +## Tokenizer & Processor +- **Owner:** Infrastructure +- **Deliverables:** Package Rust BPE via Maturin optional extra; wrap as `plato/processors/nanochat_tokenizer.py` with lazy import and fallbacks; document build steps in README. +- **Status:** Prototype processor and optional dependency group landed; CI build integration remains TODO. +- **Dependencies:** Third-party submodule, build tooling prototype. + +## Datasource +- **Owner:** Data +- **Deliverables:** Create `plato/datasources/nanochat.py` handling dataset acquisition and sharding; register in datasource registry; store license metadata. +- **Status:** In progress – streaming dataset with synthetic fallback available. 
+- **Dependencies:** Tokenizer availability. + +## Trainer & Algorithm +- **Owner:** Training +- **Deliverables:** Port Nanochat engine into `plato/trainers/nanochat.py`; add algorithm glue if federated coordination diverges; ensure checkpoint compatibility. +- **Status:** In progress – composable trainer wrapper with Nanochat-specific optimiser/loader strategies in place. +- **Dependencies:** Model registry entry, datasource. + +## Evaluation Strategy +- **Owner:** Evaluation +- **Deliverables:** Translate `nanochat/core_eval.py` into reusable evaluator hooked into Plato testing strategy; add pytest coverage with synthetic data. +- **Status:** CORE evaluation adapter hooked into trainer testing strategy; follow-up coverage to use real eval bundles outstanding. +- **Dependencies:** Model, tokenizer. + +## Configuration & Examples +- **Owner:** Product +- **Deliverables:** Author `configs/Nanochat/*.toml` scenarios and `examples/nanochat/` workspace; include reference scripts and documentation. +- **Status:** Synthetic micro config and workspace README published; larger-scale scenarios pending. +- **Dependencies:** Model, datasource, trainer. + +## Documentation & Release +- **Owner:** Docs +- **Deliverables:** Publish `docs/models/nanochat.md`, extend root README tables, add integration notes and changelog entry; outline hardware requirements. +- **Dependencies:** All prior tracks. + +## Validation +- **Owner:** QA +- **Deliverables:** Expand CI to compile tokenizer, run smoke train/eval, and enforce import order checks; record expected metrics in evaluation baselines. +- **Status:** Initial pytest smoke checks for tokenizer/trainer added; CI enablement still pending. +- **Dependencies:** Evaluation strategy, trainer. diff --git a/docs/third_party.md b/docs/third_party.md new file mode 100644 index 000000000..d481b6829 --- /dev/null +++ b/docs/third_party.md @@ -0,0 +1,18 @@ +# Third-Party Assets + +This page records external projects that are vendored into the Plato repository to support specific integrations. Please update the relevant entry whenever the upstream source, commit hash, or licensing information changes. + +## Nanochat +- **Upstream:** [karpathy/nanochat](https://github.com/karpathy/nanochat) +- **Location:** `external/nanochat` (git submodule) +- **License:** MIT (included in `external/nanochat/LICENSE`) + +### Updating the Submodule +1. `git submodule update --remote external/nanochat` +2. Inspect upstream changes for compatibility with Plato. +3. Commit the submodule pointer update and note any required integration work in the checklist. + +### Notes +- After cloning Plato, run `git submodule update --init --recursive` to populate all external dependencies. +- The Rust tokenizer (`rustbpe`) builds via `maturin`. Ensure `uv run --with ./external/nanochat maturin develop --release` succeeds before pushing updates. +- Avoid local modifications inside the submodule; contribute fixes upstream when possible. 
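+
+A typical end-to-end update, combining the steps above, might look like the following sketch (the final import check is only a suggested smoke test, not part of the official procedure):
+
+```bash
+# Pull the latest upstream commit into the vendored submodule
+git submodule update --remote external/nanochat
+
+# Rebuild the Rust BPE tokenizer against the updated sources
+uv sync --extra nanochat
+uv run --with ./external/nanochat maturin develop --release
+
+# Suggested smoke test: confirm the rebuilt extension imports
+uv run python -c "import rustbpe; print('rustbpe OK')"
+
+# Record the new submodule pointer
+git add external/nanochat
+git commit -m "Bump external/nanochat submodule"
+```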
diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py index 487547209..e8d700c97 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/main.py @@ -22,7 +22,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F -from data import build_loader from misc.config import get_config from misc.loss_ops import AdaptiveLossSoft from misc.lr_scheduler import build_scheduler @@ -37,6 +36,8 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import AverageMeter, ModelEma, accuracy +from data import build_loader + try: from apex import amp except ImportError: diff --git a/examples/nanochat/README.md b/examples/nanochat/README.md new file mode 100644 index 000000000..1d061345b --- /dev/null +++ b/examples/nanochat/README.md @@ -0,0 +1,61 @@ +# Nanochat Integration Workspace + +This workspace hosts Nanochat-focused experiments within Plato. + +## Quick Start + +1. Initialize the nanochat submodule (required for the nanochat integration): + + ```bash + git submodule update --init --recursive + ``` + +2. Install dependencies (including the vendored tokenizer build requirements): + + ```bash + uv sync --extra nanochat + uv run --with ./external/nanochat maturin develop --release + ``` + **Troubleshooting:** If you encounter a `maturin failed` error with "Can't find Cargo.toml", run the maturin command from within the nanochat directory: + + ```bash + uv sync --extra nanochat + cd external/nanochat && uv run maturin develop --release && cd ../.. + ``` + +3. Run the synthetic smoke configuration: + + ```bash + uv run --extra nanochat python plato.py --config configs/Nanochat/synthetic_micro.toml + ``` + + This launches a single-client training round using the Nanochat trainer, synthetic + token streams, and a downsized GPT configuration for CPU debugging. + +## CORE Evaluation + +The Nanochat trainer can invoke the upstream CORE benchmark by adding the section +below to your TOML configuration: + +```toml +[evaluation] +type = "nanochat_core" +max_per_task = 128 # optional; limits evaluation samples per task +# bundle_dir = "/custom/path/to/nanochat" # defaults to ~/.cache/nanochat +``` + +Make sure the official evaluation bundle has been downloaded so the following files +exist (the default location is `~/.cache/nanochat/eval_bundle`): + +- `core.yaml` +- `eval_data/*.jsonl` +- `eval_meta_data.csv` + +The provided `configs/Nanochat/synthetic_micro.toml` can be extended with the +`[evaluation]` block once those assets are present. + +## Roadmap + +- Integrate real Nanochat tokenized datasets and publish download helpers. +- Add baseline evaluation scripts leveraging `nanochat/core_eval.py`. +- Capture reproducible metrics and hardware notes for larger-scale runs. diff --git a/examples/nanochat/pyproject.toml b/examples/nanochat/pyproject.toml new file mode 100644 index 000000000..c0e19f09c --- /dev/null +++ b/examples/nanochat/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "plato-nanochat-examples" +version = "0.1.0" +description = "Nanochat integration examples for Plato." 
+readme = "README.md" +requires-python = ">=3.10" +dependencies = [] diff --git a/examples/unlearning/fedunlearning/fedunlearning_server.py b/examples/unlearning/fedunlearning/fedunlearning_server.py index 8938b6549..6d6ade576 100644 --- a/examples/unlearning/fedunlearning/fedunlearning_server.py +++ b/examples/unlearning/fedunlearning/fedunlearning_server.py @@ -43,9 +43,7 @@ async def aggregate_deltas(self, updates, deltas_received, context): if not filtered_pairs: if self._fallback_to_original: - return await super().aggregate_deltas( - updates, deltas_received, context - ) + return await super().aggregate_deltas(updates, deltas_received, context) zero_delta = self._zero_delta( context, deltas_received[0] if deltas_received else None diff --git a/external/nanochat b/external/nanochat new file mode 160000 index 000000000..c75fe54aa --- /dev/null +++ b/external/nanochat @@ -0,0 +1 @@ +Subproject commit c75fe54aa7c1fa881701c246f9427bcbe4eee5a4 diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index 7b104c163..2582f70d8 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -339,7 +339,18 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: if context.sio is not None: await context.sio.disconnect() - if hasattr(Config().trainer, "target_perplexity"): + metric_name = None + if hasattr(context.trainer, "testing_strategy") and hasattr( + context.trainer.testing_strategy, "metric_name" + ): + metric_name = context.trainer.testing_strategy.metric_name + + if metric_name == "mse": + LOGGER.info("[%s] Test MSE: %.2f", context, accuracy) + elif ( + hasattr(Config().trainer, "target_perplexity") + or metric_name == "perplexity" + ): LOGGER.info("[%s] Test perplexity: %.2f", context, accuracy) else: LOGGER.info("[%s] Test accuracy: %.2f%%", context, 100 * accuracy) @@ -383,6 +394,18 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: except TypeError: num_samples = 0 + # Extract train_loss from trainer's run_history if available + train_loss = None + if ( + context.trainer is not None + and hasattr(context.trainer, "run_history") + and context.trainer.run_history is not None + ): + try: + train_loss = context.trainer.run_history.get_latest_metric("train_loss") + except (AttributeError, KeyError, IndexError): + train_loss = None + report = SimpleNamespace( client_id=context.client_id, num_samples=num_samples, @@ -390,6 +413,7 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: training_time=training_time, comm_time=time.time(), update_response=False, + train_loss=train_loss, ) return report, weights diff --git a/plato/config.py b/plato/config.py index ec7d840a0..56f35cfa8 100644 --- a/plato/config.py +++ b/plato/config.py @@ -153,6 +153,7 @@ class Config: clients: Any server: Any data: Any + evaluation: Any trainer: Any algorithm: Any results: Any @@ -342,6 +343,10 @@ def __new__(cls): Config.params["base_path"], "data" ) + # User-defined evaluation configuration + if hasattr(config, "evaluation"): + Config.evaluation = config.evaluation + # Pretrained models if hasattr(Config().server, "model_path"): Config.params["model_path"] = os.path.join( diff --git a/plato/datasources/ETT.py b/plato/datasources/ETT.py new file mode 100644 index 000000000..fc73b82fe --- /dev/null +++ b/plato/datasources/ETT.py @@ -0,0 +1,209 @@ +""" +ETT (Electricity Transformer Temperature) datasource for time series forecasting. 
+ +Supports all ETT datasets: +- ETTh1, ETTh2: Hourly data (1 point per hour) +- ETTm1, ETTm2: 15-minute data (4 points per hour) + +Data from: https://github.com/zhouhaoyi/ETDataset +""" + +import logging +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from plato.config import Config +from plato.datasources import base + + +class ETTDataset(Dataset): + """ETT time series dataset with sliding window.""" + + def __init__(self, data, context_length, prediction_length, stride=1): + """ + Create dataset with sliding windows. + + Args: + data: pandas DataFrame or numpy array with shape (timesteps, channels) + context_length: Number of historical timesteps + prediction_length: Number of future timesteps to predict + stride: Stride for sliding window + """ + if isinstance(data, pd.DataFrame): + # Remove date column if present + if "date" in data.columns: + data = data.drop("date", axis=1) + data = data.values + + self.data = torch.FloatTensor(data) + self.context_length = context_length + self.prediction_length = prediction_length + self.stride = stride + + # Calculate number of valid windows + total_length = context_length + prediction_length + self.num_windows = max(0, (len(data) - total_length) // stride + 1) + + def __len__(self): + return self.num_windows + + def __getitem__(self, idx): + """Return past_values and future_values for PatchTSMixer.""" + start_idx = idx * self.stride + end_context = start_idx + self.context_length + end_future = end_context + self.prediction_length + + past_values = self.data[start_idx:end_context] + future_values = self.data[end_context:end_future] + + return { + "past_values": past_values, + "future_values": future_values, + } + + +class DataSource(base.DataSource): + """ETT datasource for time series forecasting (ETTh1, ETTh2, ETTm1, ETTm2).""" + + # Dataset configurations + DATASET_INFO = { + "ETTh1": {"freq": "hourly", "points_per_hour": 1}, + "ETTh2": {"freq": "hourly", "points_per_hour": 1}, + "ETTm1": {"freq": "15min", "points_per_hour": 4}, + "ETTm2": {"freq": "15min", "points_per_hour": 4}, + } + + def __init__(self, **kwargs): + super().__init__() + + # Get dataset name + dataset_name = kwargs.get( + "dataset_name", getattr(Config().data, "dataset_name", "ETTh1") + ) + + # Validate dataset name + if dataset_name not in self.DATASET_INFO: + raise ValueError( + f"Unknown ETT dataset: {dataset_name}. 
" + f"Supported datasets: {list(self.DATASET_INFO.keys())}" + ) + + logging.info( + "Using %s (Electricity Transformer Temperature) dataset", dataset_name + ) + + dataset_info = self.DATASET_INFO[dataset_name] + logging.info( + "Dataset frequency: %s (%d points per hour)", + dataset_info["freq"], + dataset_info["points_per_hour"], + ) + + # Get configuration + context_length = getattr(Config().trainer, "context_length", 512) + prediction_length = getattr(Config().trainer, "prediction_length", 96) + + # Download and load the data + data_path = self._download_data(dataset_name) + df = pd.read_csv(data_path) + + logging.info( + "Loaded %s dataset with %d timesteps and %d channels", + dataset_name, + len(df), + len(df.columns) - 1, + ) # -1 for date column + + # Split into train/val/test following the standard ETT split used by HF examples + # Standard split: 12 months train, 4 months val, 4 months test + points_per_hour = dataset_info["points_per_hour"] + train_end = 12 * 30 * 24 * points_per_hour # 12 months + val_end = train_end + 4 * 30 * 24 * points_per_hour # + 4 months + test_end = train_end + 8 * 30 * 24 * points_per_hour # + 8 months + + # Shift val/test start back by context_length so their first window has history + val_start = max(0, train_end - context_length) + test_start = max(0, val_end - context_length) + + train_df = df[:train_end] + val_df = df[val_start:val_end] + test_df = df[test_start:test_end] + + # Compute train mean/std per channel and normalize all splits (matches HF demo preprocessing) + feature_cols = [col for col in df.columns if col != "date"] + train_features = train_df[feature_cols] + eps = 1e-6 + feature_mean = train_features.mean() + feature_std = train_features.std().replace(0, eps) + + train_norm = ((train_features - feature_mean) / feature_std).to_numpy() + val_norm = ((val_df[feature_cols] - feature_mean) / feature_std).to_numpy() + test_norm = ((test_df[feature_cols] - feature_mean) / feature_std).to_numpy() + + logging.info( + "%s split - train: %d, val: %d, test: %d", + dataset_name, + len(train_df), + len(val_df), + len(test_df), + ) + + # Create datasets with sliding windows + self.trainset = ETTDataset( + train_norm, context_length, prediction_length, stride=1 + ) + + # Evaluate on the standard test split with full coverage + self.testset = ETTDataset(test_norm, context_length, prediction_length, stride=1) + + logging.info( + "Created %d training windows and %d test windows", + len(self.trainset), + len(self.testset), + ) + + def _download_data(self, dataset_name): + """Download ETT dataset from GitHub if not already present.""" + data_dir = Path(Config().params["data_path"]) / "ETT-small" + data_dir.mkdir(parents=True, exist_ok=True) + + data_file = data_dir / f"{dataset_name}.csv" + + if data_file.exists(): + logging.info("%s.csv already exists", dataset_name) + return str(data_file) + + # Download from GitHub + logging.info("Downloading %s.csv from GitHub ...", dataset_name) + url = f"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset_name}.csv" + + try: + import urllib.request + + urllib.request.urlretrieve(url, str(data_file)) + logging.info("Successfully downloaded %s.csv", dataset_name) + except Exception as e: + logging.error("Failed to download %s.csv: %s", dataset_name, e) + raise RuntimeError( + f"Could not download {dataset_name} dataset from {url}. 
" + f"Please download it manually to {data_file}" + ) from e + + return str(data_file) + + def num_train_examples(self): + return len(self.trainset) + + def num_test_examples(self): + return len(self.testset) + + def get_train_set(self): + return self.trainset + + def get_test_set(self): + return self.testset diff --git a/plato/datasources/huggingface.py b/plato/datasources/huggingface.py index 99267496e..800c79154 100644 --- a/plato/datasources/huggingface.py +++ b/plato/datasources/huggingface.py @@ -9,7 +9,9 @@ import logging import os +import torch from datasets import load_dataset, load_from_disk +from torch.utils.data import Dataset as TorchDataset from transformers import ( AutoConfig, AutoTokenizer, @@ -21,6 +23,78 @@ from plato.config import Config from plato.datasources import base +from plato.utils.timeseries_utils import is_timeseries_model + + +class TimeSeriesDatasetWrapper(TorchDataset): + """ + Wrapper for time series data from HuggingFace datasets. + Converts HuggingFace dataset format to standard time-series format used by HuggingFace time-series models. + """ + + def __init__(self, hf_dataset, context_length, prediction_length): + """ + Args: + hf_dataset: HuggingFace dataset with time series data + context_length: Number of historical timesteps + prediction_length: Number of future timesteps to predict + """ + self.dataset = hf_dataset + self.context_length = context_length + self.prediction_length = prediction_length + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + """ + Returns time series data in PatchTSMixer format. + + Expected HuggingFace dataset format: + - 'past_values'/'future_values': Pre-split time series + - 'target': Full time series to be split + """ + item = self.dataset[idx] + + # Handle different dataset formats + if isinstance(item, dict): + if "past_values" in item and "future_values" in item: + # Already in the right format + return { + "past_values": torch.FloatTensor(item["past_values"]), + "future_values": torch.FloatTensor(item["future_values"]), + } + elif "target" in item: + # Extract from 'target' field and split + target = torch.FloatTensor(item["target"]) + else: + raise ValueError( + f"Dataset must contain either 'past_values'/'future_values' or 'target' field. 
" + f"Found keys: {list(item.keys())}" + ) + else: + target = item if torch.is_tensor(item) else torch.FloatTensor(item) + + # If 1D, add channel dimension: (length,) -> (length, 1) + if target.dim() == 1: + target = target.unsqueeze(-1) + + # Split into past and future + if len(target) < self.context_length + self.prediction_length: + raise ValueError( + f"Time series too short: got {len(target)} timesteps, " + f"need at least {self.context_length + self.prediction_length}" + ) + + past_values = target[: self.context_length] + future_values = target[ + self.context_length : self.context_length + self.prediction_length + ] + + return { + "past_values": past_values, + "future_values": future_values, + } class DataSource(base.DataSource): @@ -51,6 +125,50 @@ def __init__(self, **kwargs): if callable(save_to_disk): save_to_disk(saved_data_path) + # Determine dataset type from config or model type + model_type = getattr(Config().trainer, "model_type", None) + dataset_type = getattr(Config().data, "dataset_type", "text") + + is_timeseries = is_timeseries_model( + model_type=model_type, dataset_type=dataset_type + ) + + if is_timeseries: + self._init_timeseries_dataset() + else: + self._init_text_dataset() + + def _init_timeseries_dataset(self): + """Initialize time series dataset.""" + logging.info("Initializing time series dataset") + + # Get time series parameters from config + context_length = getattr(Config().trainer, "context_length", 512) + prediction_length = getattr(Config().trainer, "prediction_length", 96) + + # Wrap datasets + train_split = ( + "train" if "train" in self.dataset else list(self.dataset.keys())[0] + ) + test_split = ( + "test" + if "test" in self.dataset + else "validation" + if "validation" in self.dataset + else train_split + ) + + self.trainset = TimeSeriesDatasetWrapper( + self.dataset[train_split], context_length, prediction_length + ) + self.testset = TimeSeriesDatasetWrapper( + self.dataset[test_split], context_length, prediction_length + ) + + def _init_text_dataset(self): + """Initialize text/NLP dataset.""" + logging.info("Initializing text/NLP dataset") + parser = HfArgumentParser(TrainingArguments) (self.training_args,) = parser.parse_args_into_dataclasses( args=["--output_dir=/tmp", "--report_to=none"] diff --git a/plato/datasources/nanochat.py b/plato/datasources/nanochat.py new file mode 100644 index 000000000..7e53c71c6 --- /dev/null +++ b/plato/datasources/nanochat.py @@ -0,0 +1,274 @@ +""" +Streaming datasource backed by the vendored Nanochat project. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +try: + import torch + from torch.utils.data import IterableDataset +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat datasource requires PyTorch. " + "Install torch via the project's optional dependencies." 
+ ) from exc + +from plato.config import Config +from plato.datasources.base import DataSource as BaseDataSource +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +DEFAULT_VOCAB_SIZE = 50304 +DEFAULT_SEQUENCE_LENGTH = 2048 + + +def _resolve_base_dir(base_dir: str | Path | None) -> Path | None: + if base_dir is None: + return None + return Path(base_dir).expanduser().resolve() + + +def _parquet_available(base_dir: Path | None) -> bool: + try: + ensure_nanochat_importable() + from nanochat.dataset import ( # type: ignore[attr-defined] + DATA_DIR, + list_parquet_files, + ) + except (ThirdPartyImportError, ImportError): # pragma: no cover - defensive + return False + + if base_dir is not None: + candidate_dir = base_dir / "base_data" + if not candidate_dir.exists(): + return False + parquet_dir = candidate_dir + else: + parquet_dir = Path(DATA_DIR) + + try: + return len(list_parquet_files(str(parquet_dir))) > 0 + except FileNotFoundError: + return False + + +@dataclass +class _SyntheticState: + generator: torch.Generator + + +class NanochatStreamingDataset(IterableDataset): + """Iterable dataset yielding (inputs, targets) token tensors.""" + + def __init__( + self, + *, + split: str, + batch_size: int, + sequence_length: int, + mode: str, + base_dir: Path | None, + max_batches: int | None, + tokenizer_threads: int, + tokenizer_batch_size: int, + device: str, + vocab_size: int, + synthetic_seed: int, + ): + super().__init__() + if split not in {"train", "val"}: + raise ValueError("split must be 'train' or 'val'.") + + if mode not in {"auto", "parquet", "synthetic"}: + raise ValueError("mode must be 'auto', 'parquet', or 'synthetic'.") + + self.split = split + self.batch_size = batch_size + self.sequence_length = sequence_length + self.base_dir = base_dir + self.max_batches = max_batches + self.tokenizer_threads = tokenizer_threads + self.tokenizer_batch_size = tokenizer_batch_size + self.device = device + self.vocab_size = vocab_size + self.synthetic_seed = synthetic_seed + + resolved_mode = mode + if resolved_mode == "auto": + resolved_mode = "parquet" if _parquet_available(base_dir) else "synthetic" + self.mode = resolved_mode + self._synthetic_state: _SyntheticState | None = None + + # Configure Nanochat's base directory if provided. 
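+        # setdefault() keeps any NANOCHAT_BASE_DIR the user already exported,
+        # so an explicit environment override wins over this configured value.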
+ if self.base_dir is not None: + os.environ.setdefault("NANOCHAT_BASE_DIR", str(self.base_dir)) + + def _synthetic_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + if self._synthetic_state is None: + generator = torch.Generator() + generator.manual_seed(self.synthetic_seed) + self._synthetic_state = _SyntheticState(generator=generator) + + generator = self._synthetic_state.generator + while True: + tokens = torch.randint( + low=0, + high=self.vocab_size, + size=(self.batch_size, self.sequence_length + 1), + dtype=torch.long, + generator=generator, + ) + inputs = tokens[:, :-1].contiguous() + targets = tokens[:, 1:].contiguous() + yield inputs, targets + + def _parquet_iterable(self) -> Iterable[tuple[torch.Tensor, torch.Tensor]]: + ensure_nanochat_importable() + from nanochat.dataloader import ( # type: ignore[attr-defined] + tokenizing_distributed_data_loader, + ) + + loader = tokenizing_distributed_data_loader( + self.batch_size, + self.sequence_length, + split=self.split, + tokenizer_threads=self.tokenizer_threads, + tokenizer_batch_size=self.tokenizer_batch_size, + device=self.device, + ) + for inputs, targets in loader: + yield ( + inputs.to(dtype=torch.long).contiguous(), + targets.to(dtype=torch.long).contiguous(), + ) + + def __iter__(self): + iterable = ( + self._parquet_iterable() + if self.mode == "parquet" + else self._synthetic_iterable() + ) + + for batch_index, batch in enumerate(iterable): + if self.max_batches is not None and batch_index >= self.max_batches: + break + yield batch + + def __len__(self) -> int: + if self.max_batches is None: + raise TypeError("Streaming dataset does not have a finite length.") + return self.max_batches + + +class DataSource(BaseDataSource): # type: ignore[misc] + """Plato datasource exposing Nanochat token streams.""" + + def __init__( + self, + *, + batch_size: int | None = None, + sequence_length: int | None = None, + mode: str = "auto", + base_dir: str | Path | None = None, + max_train_batches: int | None = 64, + max_val_batches: int | None = 8, + tokenizer_threads: int = 4, + tokenizer_batch_size: int = 128, + device: str = "cpu", + vocab_size: int = DEFAULT_VOCAB_SIZE, + synthetic_seed: int = 42, + ): + super().__init__() + + cfg_data = getattr(Config(), "data", None) + if cfg_data is not None: + mode = getattr(cfg_data, "mode", mode) + base_dir = getattr(cfg_data, "base_dir", base_dir) + max_train_batches = getattr( + cfg_data, "max_train_batches", max_train_batches + ) + max_val_batches = getattr(cfg_data, "max_val_batches", max_val_batches) + tokenizer_threads = getattr( + cfg_data, "tokenizer_threads", tokenizer_threads + ) + tokenizer_batch_size = getattr( + cfg_data, "tokenizer_batch_size", tokenizer_batch_size + ) + device = getattr(cfg_data, "device", device) + vocab_size = getattr(cfg_data, "vocab_size", vocab_size) + synthetic_seed = getattr(cfg_data, "synthetic_seed", synthetic_seed) + + config = getattr(Config(), "parameters", None) + model_conf = getattr(config, "model", None) + default_seq_len = DEFAULT_SEQUENCE_LENGTH + if model_conf is not None and hasattr(model_conf, "_asdict"): + seq_len_candidate = model_conf._asdict().get("sequence_len") + if isinstance(seq_len_candidate, int) and seq_len_candidate > 0: + default_seq_len = seq_len_candidate + + resolved_sequence_len = sequence_length or default_seq_len + resolved_batch_size = batch_size or getattr( + getattr(Config(), "trainer", None), "batch_size", 1 + ) + + resolved_base_dir = _resolve_base_dir(base_dir) + dataset_mode = mode + if dataset_mode == 
"auto": + dataset_mode = ( + "parquet" if _parquet_available(resolved_base_dir) else "synthetic" + ) + + self.trainset: NanochatStreamingDataset = NanochatStreamingDataset( + split="train", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_train_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed, + ) + self.testset: NanochatStreamingDataset = NanochatStreamingDataset( + split="val", + batch_size=resolved_batch_size, + sequence_length=resolved_sequence_len, + mode=dataset_mode, + base_dir=resolved_base_dir, + max_batches=max_val_batches, + tokenizer_threads=tokenizer_threads, + tokenizer_batch_size=tokenizer_batch_size, + device=device, + vocab_size=vocab_size, + synthetic_seed=synthetic_seed + 1, + ) + self.sequence_length = resolved_sequence_len + self.batch_size = resolved_batch_size + self.mode = dataset_mode + + @staticmethod + def input_shape(): + """Return the default input shape (sequence length).""" + return (DEFAULT_SEQUENCE_LENGTH,) + + def num_train_examples(self) -> int: + dataset = self.trainset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_train_batches to report size." + ) + return dataset.max_batches * dataset.batch_size + + def num_test_examples(self) -> int: + dataset = self.testset + if dataset.max_batches is None: + raise RuntimeError( + "Nanochat datasource streams infinity; configure max_val_batches to report size." + ) + return dataset.max_batches * dataset.batch_size diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index e2f101de3..f71f49b44 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -8,10 +8,12 @@ from plato.config import Config from plato.datasources import ( cinic10, + ETT, feature, femnist, huggingface, lora, + nanochat, purchase, texas, tiny_imagenet, @@ -27,6 +29,8 @@ "Texas": texas, "TinyImageNet": tiny_imagenet, "Feature": feature, + "Nanochat": nanochat, + "ETT": ETT, } registered_partitioned_datasources = {"FEMNIST": femnist} @@ -39,6 +43,10 @@ "CIFAR10": ("Torchvision", {"dataset_name": "CIFAR10"}), "CIFAR100": ("Torchvision", {"dataset_name": "CIFAR100"}), "CelebA": ("Torchvision", {"dataset_name": "CelebA"}), + "ETTh1": ("ETT", {"dataset_name": "ETTh1"}), + "ETTh2": ("ETT", {"dataset_name": "ETTh2"}), + "ETTm1": ("ETT", {"dataset_name": "ETTm1"}), + "ETTm2": ("ETT", {"dataset_name": "ETTm2"}), } diff --git a/plato/evaluators/__init__.py b/plato/evaluators/__init__.py new file mode 100644 index 000000000..62d84e994 --- /dev/null +++ b/plato/evaluators/__init__.py @@ -0,0 +1,3 @@ +"""Evaluation helpers for Plato integrations.""" + +# Intentionally empty for now; modules register themselves when imported. diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py new file mode 100644 index 000000000..97a49e9f2 --- /dev/null +++ b/plato/evaluators/nanochat_core.py @@ -0,0 +1,358 @@ +""" +Adapter utilities to run Nanochat's CORE evaluation benchmark within Plato. 
+""" + +from __future__ import annotations + +import contextlib +import csv +import gzip +import json +import logging +import os +import random +import sys +import tarfile +import time +import zipfile +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat CORE evaluation requires PyTorch. " + "Install the `nanochat` extra (includes torch)." + ) from exc + +import requests +import yaml + +from plato.config import Config +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +LOGGER = logging.getLogger(__name__) + +# URL for the CORE evaluation bundle +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + + +@contextlib.contextmanager +def _download_guard(data_path: str): + """Serialize dataset downloads to avoid concurrent corruption.""" + os.makedirs(data_path, exist_ok=True) + lock_file = os.path.join(data_path, ".download.lock") + lock_fd = None + waited = False + + try: + while True: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR) + break + except FileExistsError: + if not waited: + LOGGER.info( + "Another process is preparing the dataset at %s. Waiting.", + data_path, + ) + waited = True + time.sleep(1) + yield + finally: + if lock_fd is not None: + os.close(lock_fd) + try: + os.remove(lock_file) + except FileNotFoundError: + pass + + +def _download_eval_bundle(url: str, data_path: str) -> None: + """Download the CORE evaluation bundle from a URL if not already available.""" + url_parse = urlparse(url) + file_name = os.path.join(data_path, url_parse.path.split("/")[-1]) + os.makedirs(data_path, exist_ok=True) + sentinel = Path(f"{file_name}.complete") + + if sentinel.exists(): + return + + with _download_guard(data_path): + if sentinel.exists(): + return + + LOGGER.info("Downloading CORE evaluation bundle from %s.", url) + + res = requests.get(url, stream=True, timeout=60) + total_size = int(res.headers.get("Content-Length", 0)) + downloaded_size = 0 + + with open(file_name, "wb+") as file: + for chunk in res.iter_content(chunk_size=1024): + if not chunk: + continue + downloaded_size += len(chunk) + file.write(chunk) + file.flush() + if total_size: + sys.stdout.write(f"\r{100 * downloaded_size / total_size:.1f}%") + sys.stdout.flush() + if total_size: + sys.stdout.write("\n") + + # Unzip the compressed file just downloaded + LOGGER.info("Decompressing the CORE evaluation bundle.") + name, suffix = os.path.splitext(file_name) + + if file_name.endswith("tar.gz"): + with tarfile.open(file_name, "r:gz") as tar: + tar.extractall(data_path) + os.remove(file_name) + elif suffix == ".zip": + LOGGER.info("Extracting %s to %s.", file_name, data_path) + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(data_path) + os.remove(file_name) + elif suffix == ".gz": + with gzip.open(file_name, "rb") as zipped_file: + with open(name, "wb") as unzipped_file: + unzipped_file.write(zipped_file.read()) + os.remove(file_name) + else: + LOGGER.warning("Unknown compressed file type for %s.", file_name) + + sentinel.touch() + LOGGER.info("CORE evaluation bundle downloaded and extracted successfully.") + + +def _resolve_bundle_paths(bundle_dir: str | Path | None) -> tuple[Path, Path, Path]: + """Resolve the configuration, metadata, and dataset paths for CORE evaluation.""" + ensure_nanochat_importable() + + def _get_default_base_path() -> Path: + 
"""Get the default base path, trying nanochat first, then Plato's data directory.""" + try: + from nanochat.common import get_base_dir # pylint: disable=import-error + + return Path(get_base_dir()) + except (ImportError, OSError, PermissionError): + plato_data_path = Config().params.get("data_path", "./runtime/data") + path = Path(plato_data_path) / "nanochat" + LOGGER.info("Using Plato data directory for CORE bundle: %s", path) + return path + + # Determine base path + if bundle_dir is not None: + try: + base_path = Path(bundle_dir).expanduser().resolve() + LOGGER.info("Using bundle_dir from config: %s", base_path) + except (OSError, PermissionError, ValueError) as exc: + LOGGER.warning( + "Cannot use bundle_dir '%s': %s. Using default location.", + bundle_dir, + exc, + ) + base_path = _get_default_base_path() + else: + base_path = _get_default_base_path() + + # Ensure base path exists + base_path.mkdir(parents=True, exist_ok=True) + + eval_bundle_dir = base_path / "eval_bundle" + config_path = eval_bundle_dir / "core.yaml" + data_dir = eval_bundle_dir / "eval_data" + metadata_path = eval_bundle_dir / "eval_meta_data.csv" + + # Check if evaluation bundle exists, download if missing + if not config_path.exists() or not data_dir.exists() or not metadata_path.exists(): + LOGGER.info( + "CORE evaluation bundle not found at %s. Downloading automatically...", + base_path, + ) + _download_eval_bundle(EVAL_BUNDLE_URL, str(base_path)) + + # Verify download succeeded + if not config_path.exists(): + raise FileNotFoundError( + f"CORE evaluation config not found at {config_path}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not data_dir.exists(): + raise FileNotFoundError( + f"CORE evaluation data directory not found at {data_dir}. " + "Ensure the Nanochat eval bundle is downloaded." + ) + if not metadata_path.exists(): + raise FileNotFoundError( + f"CORE evaluation metadata CSV not found at {metadata_path}." + ) + + return config_path, data_dir, metadata_path + + +def _load_core_tasks(config_path: Path) -> list[dict[str, Any]]: + """Load task definitions from the CORE YAML config.""" + with config_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) + tasks = config.get("icl_tasks", []) + if not isinstance(tasks, list) or not tasks: + raise ValueError( + f"No CORE tasks defined in {config_path}. Inspect the eval bundle." + ) + return tasks + + +def _load_metadata(metadata_path: Path) -> dict[str, float]: + """Load random baseline metadata for centering accuracy.""" + baseline_map: dict[str, float] = {} + with metadata_path.open("r", encoding="utf-8") as handle: + reader = csv.DictReader(handle) + for row in reader: + label = row.get("Eval Task") + baseline = row.get("Random baseline") + if label is None or baseline is None: + continue + try: + baseline_map[label] = float(baseline) + except ValueError: + LOGGER.debug("Skipping malformed baseline row: %s", row) + if not baseline_map: + raise ValueError( + f"Random baselines missing in {metadata_path}. Required for CORE metric." + ) + return baseline_map + + +def _load_task_data(data_dir: Path, dataset_uri: str) -> list[dict[str, Any]]: + """Load task dataset rows from newline-delimited JSON.""" + path = data_dir / dataset_uri + if not path.exists(): + raise FileNotFoundError( + f"CORE dataset shard '{dataset_uri}' missing under {data_dir}." 
+ ) + with path.open("r", encoding="utf-8") as handle: + return [json.loads(line.strip()) for line in handle if line.strip()] + + +def _resolve_tokenizer(model) -> Any: + """Obtain a tokenizer compatible with Nanochat core evaluation.""" + tokenizer = getattr(model, "nanochat_tokenizer", None) + if tokenizer is not None: + return tokenizer + + ensure_nanochat_importable() + from nanochat.tokenizer import get_tokenizer # pylint: disable=import-error + + return get_tokenizer() + + +def run_core_evaluation( + model: torch.nn.Module, + *, + tokenizer: Any | None = None, + bundle_dir: str | Path | None = None, + max_per_task: int = -1, + device: torch.device | str | None = None, +) -> dict[str, Any]: + """ + Execute the CORE benchmark for the provided model. + + Args: + model: Nanochat-style autoregressive model. + tokenizer: Optional tokenizer; falls back to nanochat.tokenizer.get_tokenizer(). + bundle_dir: Optional base directory containing `eval_bundle/`. + max_per_task: Optional cap on examples per task for quicker smoke tests (-1 = all). + device: Device to run evaluation on. Defaults to the model's current device. + + Returns: + Dictionary with `results`, `centered_results`, and `core_metric`. + """ + ensure_nanochat_importable() + from nanochat.core_eval import evaluate_task # pylint: disable=import-error + + config_path, data_dir, metadata_path = _resolve_bundle_paths(bundle_dir) + tasks = _load_core_tasks(config_path) + baselines = _load_metadata(metadata_path) + + eval_tokenizer = tokenizer or _resolve_tokenizer(model) + if eval_tokenizer is None: + raise RuntimeError( + "Nanochat CORE evaluation requires a tokenizer. " + "Either attach `model.nanochat_tokenizer` or provide one explicitly." + ) + + if device is None: + try: + first_param = next(model.parameters()) + device = first_param.device + except StopIteration: + device = torch.device("cpu") + if isinstance(device, str): + device = torch.device(device) + + model_device = device + model_was_training = model.training + model = model.to(model_device) + model.eval() + + results: dict[str, float] = {} + centered_results: dict[str, float] = {} + + for task in tasks: + label = task.get("label") + if not label: + LOGGER.debug("Skipping unnamed CORE task entry: %s", task) + continue + + dataset_uri = task.get("dataset_uri") + if not isinstance(dataset_uri, str): + LOGGER.debug( + "Skipping CORE task %s due to missing dataset_uri metadata.", task + ) + continue + + task_meta = { + "task_type": task.get("icl_task_type"), + "dataset_uri": dataset_uri, + "num_fewshot": task.get("num_fewshot", [0])[0], + "continuation_delimiter": task.get("continuation_delimiter", " "), + } + start_time = time.perf_counter() + + data = _load_task_data(data_dir, dataset_uri) + shuffle_rng = random.Random(1337) + shuffle_rng.shuffle(data) + if max_per_task > 0: + data = data[:max_per_task] + + accuracy = evaluate_task(model, eval_tokenizer, data, model_device, task_meta) + baseline = baselines.get(label, 0.0) + centered = (accuracy - 0.01 * baseline) / (1.0 - 0.01 * baseline) + + results[label] = accuracy + centered_results[label] = centered + elapsed = time.perf_counter() - start_time + LOGGER.info( + "CORE task %s | accuracy %.4f | centered %.4f | %.2fs", + label, + accuracy, + centered, + elapsed, + ) + + if model_was_training: + model.train() + + if not centered_results: + raise RuntimeError("No CORE tasks were evaluated; check the eval bundle.") + + core_metric = sum(centered_results.values()) / len(centered_results) + return { + "results": results, + 
"centered_results": centered_results, + "core_metric": core_metric, + } diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 4a8d3aafb..992f3a375 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -10,6 +10,22 @@ from transformers import AutoConfig, AutoModelForCausalLM from plato.config import Config +from plato.utils.timeseries_utils import is_timeseries_model + +try: + from transformers import ( + PatchTSMixerConfig, + PatchTSMixerForPrediction, + PatchTSMixerForPretraining, + PatchTSMixerForRegression, + PatchTSMixerForTimeSeriesClassification, + ) +except ImportError: + PatchTSMixerConfig = None + PatchTSMixerForPrediction = None + PatchTSMixerForTimeSeriesClassification = None + PatchTSMixerForRegression = None + PatchTSMixerForPretraining = None try: from peft import LoraConfig, get_peft_model @@ -36,7 +52,87 @@ def _lora_config_dict(lora_config: Any) -> dict[str, Any]: class Model: - """The CausalLM model loaded from HuggingFace.""" + """The HuggingFace model factory supporting various model types.""" + + @staticmethod + def _get_timeseries_task_type(model_task=None): + """Determine the task type for time series models from config or arguments.""" + trainer_config = Config().trainer + return ( + model_task + or getattr(trainer_config, "model_task", None) + or getattr(trainer_config, "task_type", "forecasting") + ) + + @staticmethod + def _get_patchtsmixer_model(resolved_model_name, cache_dir, model_task=None): + """Load or create a PatchTSMixer model.""" + if PatchTSMixerForPrediction is None: + raise ImportError( + "PatchTSMixer models are not available. " + "Ensure you have transformers>=4.35.0 installed." + ) + + task_type = Model._get_timeseries_task_type(model_task) + + # Try to load pretrained model first + task_models = { + "classification": PatchTSMixerForTimeSeriesClassification, + "regression": PatchTSMixerForRegression, + "pretraining": PatchTSMixerForPretraining, + "forecasting": PatchTSMixerForPrediction, + } + model_class = task_models.get(task_type, PatchTSMixerForPrediction) + + try: + logging.info( + "Attempting to load pretrained PatchTSMixer model: %s", + resolved_model_name, + ) + model = model_class.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained model") + except (OSError, ValueError, Exception): + # If loading fails, create new model from config + logging.info( + "Model '%s' not found as pretrained, creating from config settings", + resolved_model_name, + ) + trainer_config = Config().trainer + + config = PatchTSMixerConfig( + context_length=getattr(trainer_config, "context_length", 512), + prediction_length=getattr(trainer_config, "prediction_length", 96), + num_input_channels=getattr(trainer_config, "num_input_channels", 7), + patch_length=getattr(trainer_config, "patch_length", 8), + patch_stride=getattr(trainer_config, "patch_stride", 8), + d_model=getattr(trainer_config, "d_model", 64), + num_layers=getattr(trainer_config, "num_layers", 8), + expansion_factor=getattr(trainer_config, "expansion_factor", 2), + dropout=getattr(trainer_config, "dropout", 0.2), + head_dropout=getattr(trainer_config, "head_dropout", 0.2), + mode=getattr(trainer_config, "mode", "common_channel"), + gated_attn=getattr(trainer_config, "gated_attn", True), + scaling=getattr(trainer_config, "scaling", "std"), + prediction_channel_indices=getattr( + trainer_config, "prediction_channel_indices", None + ), + ) + + # Set task-specific parameters and create model + if 
task_type == "classification": + config.num_labels = getattr(trainer_config, "num_classes", 2) + model = PatchTSMixerForTimeSeriesClassification(config) + elif task_type == "regression": + config.num_targets = getattr(trainer_config, "num_targets", 1) + model = PatchTSMixerForRegression(config) + elif task_type == "pretraining": + model = PatchTSMixerForPretraining(config) + else: # forecasting + model = PatchTSMixerForPrediction(config) + + return model @staticmethod def get(model_name=None, **kwargs): # pylint: disable=unused-argument @@ -55,6 +151,25 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument if not isinstance(resolved_model_name, str) or not resolved_model_name: raise ValueError("A valid HuggingFace model name must be provided.") + cache_dir = Config().params["model_path"] + "/huggingface" + + # Determine model type from config or model name + model_type = kwargs.get("model_type") or getattr( + getattr(Config(), "trainer", None), "model_type", None + ) + + # Detect if this is a time series model and which type + is_timeseries = is_timeseries_model( + model_name=resolved_model_name, model_type=model_type + ) + + if is_timeseries: + model_task = kwargs.get("model_task") + return Model._get_patchtsmixer_model( + resolved_model_name, cache_dir, model_task + ) + + # Default to CausalLM for backward compatibility config = AutoConfig.from_pretrained(resolved_model_name, **config_kwargs) model = AutoModelForCausalLM.from_pretrained( diff --git a/plato/models/nanochat.py b/plato/models/nanochat.py new file mode 100644 index 000000000..114a891f9 --- /dev/null +++ b/plato/models/nanochat.py @@ -0,0 +1,173 @@ +""" +Factory for Nanochat GPT models integrated with Plato's registry. +""" + +from __future__ import annotations + +import logging +from dataclasses import fields +from typing import Any + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat model integration requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + +DEFAULT_MODEL_CONFIG: dict[str, int] = { + "sequence_len": 2048, + "vocab_size": 50304, + "n_layer": 12, + "n_head": 6, + "n_kv_head": 6, + "n_embd": 768, +} + + +def _import_nanochat_modules(): + ensure_nanochat_importable() + from nanochat.checkpoint_manager import ( # type: ignore[attr-defined] + load_model_from_dir, + ) + from nanochat.gpt import GPT, GPTConfig # type: ignore[attr-defined] + + return GPT, GPTConfig, load_model_from_dir + + +def _sanitize_kwargs(kwargs: dict[str, Any], valid_fields: set[str]) -> dict[str, Any]: + """Filter kwargs to those accepted by GPTConfig.""" + config_kwargs = DEFAULT_MODEL_CONFIG.copy() + for key, value in kwargs.items(): + if key in valid_fields: + config_kwargs[key] = value + return config_kwargs + + +def _load_from_checkpoint( + load_dir: str, + *, + device: str | torch.device = "cpu", + phase: str = "train", + model_tag: str | None = None, + step: int | None = None, +): + """Load a Nanochat checkpoint via checkpoint_manager.""" + GPT, _, load_model_from_dir = _import_nanochat_modules() + torch_device = torch.device(device) + model, tokenizer, metadata = load_model_from_dir( + load_dir, + device=torch_device, + phase=phase, + model_tag=model_tag, + step=step, + ) + # Attach helpful metadata to the model for downstream use. 
+ setattr(model, "nanochat_tokenizer", tokenizer) + setattr(model, "nanochat_metadata", metadata) + if not isinstance(model, GPT): + raise TypeError( + "Checkpoint loader returned an unexpected model type. " + "Ensure the checkpoint directory points to Nanochat artifacts." + ) + return model + + +class Model: + """Nanochat GPT factory compatible with Plato's model registry.""" + + @staticmethod + def get(model_name: str | None = None, **kwargs: Any): + """ + Instantiate a Nanochat GPT model. + + Keyword Args: + sequence_len: Context length (tokens). + vocab_size: Token vocabulary size. + n_layer: Number of transformer blocks. + n_head: Attention heads for queries. + n_kv_head: Attention heads for keys/values (MQA/GQA). + n_embd: Hidden dimension width. + init_weights: Whether to run Nanochat's weight initialisation (default True). + load_checkpoint_dir: Optional checkpoint directory produced by Nanochat. + load_checkpoint_tag: Optional subdirectory/model tag within checkpoint dir. + load_checkpoint_step: Optional numeric step to load (defaults to latest). + device: Torch device string for checkpoint loading. + phase: "train" or "eval" when loading checkpoints. + """ + try: + GPT, GPTConfig, _ = _import_nanochat_modules() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat submodule not found. " + "Run `git submodule update --init --recursive` to populate external/nanochat." + ) from exc + + init_weights = kwargs.pop("init_weights", True) + load_dir = kwargs.pop("load_checkpoint_dir", None) + checkpoint_tag = kwargs.pop("load_checkpoint_tag", None) + checkpoint_step = kwargs.pop("load_checkpoint_step", None) + checkpoint_phase = kwargs.pop("phase", "train") + checkpoint_device = kwargs.pop("device", "cpu") + + # GPTConfig only accepts specific fields; filter unknown kwargs. + config_fields = {field.name for field in fields(GPTConfig)} + config_kwargs = _sanitize_kwargs(kwargs, config_fields) + + if load_dir: + model = _load_from_checkpoint( + load_dir, + device=checkpoint_device, + phase=checkpoint_phase, + model_tag=checkpoint_tag, + step=checkpoint_step, + ) + return model + + # Model vocab_size MUST match tokenizer vocab_size to avoid IndexError + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + actual_vocab_size = tokenizer.get_vocab_size() + + # Override vocab_size with tokenizer's actual vocab_size + if "vocab_size" in config_kwargs: + configured_vocab = config_kwargs["vocab_size"] + if configured_vocab != actual_vocab_size: + logging.warning( + f"[Nanochat Model] Config specifies vocab_size={configured_vocab}, " + f"but tokenizer has vocab_size={actual_vocab_size}. " + f"Using tokenizer's vocab_size={actual_vocab_size} to match tokenizer." + ) + config_kwargs["vocab_size"] = actual_vocab_size + + logging.info( + f"[Nanochat Model] Using vocab_size={actual_vocab_size} from tokenizer" + ) + except Exception as e: + logging.warning( + f"[Nanochat Model] Could not auto-detect vocab_size from tokenizer: {e}. " + f"Using configured or default vocab_size." 
+ ) + + config = GPTConfig(**config_kwargs) + model = GPT(config) + if init_weights: + model.init_weights() + + # This allows CORE evaluation and other components to access the tokenizer + try: + from nanochat.tokenizer import get_tokenizer + + tokenizer = get_tokenizer() + setattr(model, "nanochat_tokenizer", tokenizer) + except Exception: + pass + # Set max_seq_len for CORE evaluation truncation + setattr(model, "max_seq_len", config.sequence_len) + setattr(model, "nanochat_config", config_kwargs) + return model diff --git a/plato/models/registry.py b/plato/models/registry.py index e6691e3a0..382a2d064 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -15,6 +15,7 @@ huggingface, lenet5, multilayer, + nanochat, resnet, torch_hub, vgg, @@ -40,6 +41,7 @@ "torch_hub": torch_hub.Model, "huggingface": huggingface.Model, "vit": vit.Model, + "nanochat": nanochat.Model, } registered_mlx_models = {} diff --git a/plato/processors/nanochat_tokenizer.py b/plato/processors/nanochat_tokenizer.py new file mode 100644 index 000000000..7460ae910 --- /dev/null +++ b/plato/processors/nanochat_tokenizer.py @@ -0,0 +1,135 @@ +""" +Prototype tokenizer processor for Nanochat, wrapping the rustbpe+tiktoken stack. + +This module exercises the build tooling for the Rust extension while providing +an adapter that conforms to Plato's processor interface. +""" + +from __future__ import annotations + +import pickle +from collections.abc import Iterable, Sequence +from pathlib import Path +from typing import Any + +from plato.processors.base import Processor + +SPECIAL_TOKENS = [ + "<|bos|>", + "<|user_start|>", + "<|user_end|>", + "<|assistant_start|>", + "<|assistant_end|>", + "<|python_start|>", + "<|python_end|>", + "<|output_start|>", + "<|output_end|>", +] + +DEFAULT_PATTERN = ( + r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]" + r"|\s+(?!\S)|\s+" +) + + +class NanochatTokenizerProcessor(Processor): + """Prototype tokenizer that can either load a saved encoding or train via rustbpe.""" + + def __init__( + self, + tokenizer_path: str | Path | None = None, + train_corpus: Iterable[str] | None = None, + vocab_size: int = 32000, + pattern: str = DEFAULT_PATTERN, + bos_token: str = "<|bos|>", + prepend_bos: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.tokenizer_path = Path(tokenizer_path) if tokenizer_path else None + self.train_corpus = train_corpus + self.vocab_size = max(vocab_size, 256 + len(SPECIAL_TOKENS)) + self.pattern = pattern + self.bos_token = bos_token + self.prepend_bos = prepend_bos + + self._encoding = self._load_encoding() + self._special_tokens = self._infer_special_tokens() + self._bos_token_id = self._special_tokens.get(self.bos_token) + + def _load_encoding(self): + if self.tokenizer_path: + return self._load_from_pickle(self.tokenizer_path) + if self.train_corpus is not None: + return self._train_from_corpus(self.train_corpus) + raise ValueError("Either tokenizer_path or train_corpus must be provided.") + + def _load_from_pickle(self, path: Path): + with path.open("rb") as handle: + encoding = pickle.load(handle) + return encoding + + def _train_from_corpus(self, corpus: Iterable[str]): + try: + import rustbpe # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "rustbpe extension is required to train a Nanochat tokenizer." 
+ ) from exc + + try: + import tiktoken # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - guarded import + raise RuntimeError( + "tiktoken is required to construct Nanochat tokenizer encodings." + ) from exc + + tokenizer = rustbpe.Tokenizer() + tokenizer.train_from_iterator( + iter(corpus), + self.vocab_size - len(SPECIAL_TOKENS), + pattern=self.pattern, + ) + + mergeable_ranks = { + bytes(piece): rank for piece, rank in tokenizer.get_mergeable_ranks() + } + tokens_offset = len(mergeable_ranks) + special_tokens = { + token: tokens_offset + index for index, token in enumerate(SPECIAL_TOKENS) + } + + return tiktoken.Encoding( + name="nanochat-rustbpe", + pat_str=tokenizer.get_pattern(), + mergeable_ranks=mergeable_ranks, + special_tokens=special_tokens, + ) + + def _infer_special_tokens(self): + try: + encode_single_token = self._encoding.encode_single_token + special_token_set = getattr(self._encoding, "special_tokens_set", set()) + except AttributeError as exc: + raise RuntimeError( + "tiktoken encoding missing expected interfaces." + ) from exc + + mapping = {} + for token in SPECIAL_TOKENS: + if token in special_token_set: + mapping[token] = encode_single_token(token) + return mapping + + def _encode_one(self, text: str): + ids = list(self._encoding.encode_ordinary(text)) + if self.prepend_bos and self._bos_token_id is not None: + ids.insert(0, self._bos_token_id) + return ids + + def process(self, data: Any): + if isinstance(data, str): + return self._encode_one(data) + if isinstance(data, Sequence): + return [self._encode_one(item) for item in data] + raise TypeError(f"Unsupported payload type: {type(data)!r}") diff --git a/plato/samplers/iid.py b/plato/samplers/iid.py index 815899cce..848775e99 100644 --- a/plato/samplers/iid.py +++ b/plato/samplers/iid.py @@ -18,29 +18,33 @@ def __init__(self, datasource, client_id, testing): super().__init__() if testing: + # Use the full test set for evaluation to avoid sampling/duplication dataset = datasource.get_test_set() + self.subset_indices = list(range(len(dataset))) else: dataset = datasource.get_train_set() - self.dataset_size = len(dataset) - indices = list(range(self.dataset_size)) - np.random.seed(self.random_seed) - np.random.shuffle(indices) - - partition_size = Config().data.partition_size - total_clients = Config().clients.total_clients - total_size = partition_size * total_clients - - # add extra samples to make it evenly divisible, if needed - if len(indices) < total_size: - while len(indices) < total_size: - indices += indices[: (total_size - len(indices))] - else: - indices = indices[:total_size] - assert len(indices) == total_size - - # Compute the indices of data in the subset for this client - self.subset_indices = indices[(int(client_id) - 1) : total_size : total_clients] + self.dataset_size = len(dataset) + indices = list(range(self.dataset_size)) + np.random.seed(self.random_seed) + np.random.shuffle(indices) + + partition_size = Config().data.partition_size + total_clients = Config().clients.total_clients + total_size = partition_size * total_clients + + # add extra samples to make it evenly divisible, if needed + if len(indices) < total_size: + while len(indices) < total_size: + indices += indices[: (total_size - len(indices))] + else: + indices = indices[:total_size] + assert len(indices) == total_size + + # Compute the indices of data in the subset for this client + self.subset_indices = indices[ + (int(client_id) - 1) : total_size : total_clients + ] def get(self): """Obtains an 
instance of the sampler.""" diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index ee636cb49..f93541f2c 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -243,27 +243,67 @@ async def _process_reports(self): if hasattr(Config().server, "do_test") and not Config().server.do_test: # Compute the average accuracy from client reports self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy - ) + + trainer = self.require_trainer() + metric_name = getattr(trainer.testing_strategy, "metric_name", "accuracy") + + if metric_name == "mse": + logging.info("[%s] Average client MSE: %.2f.", self, self.accuracy) + elif metric_name == "perplexity" or hasattr( + Config().trainer, "target_perplexity" + ): + logging.info( + "[%s] Average client perplexity: %.2f.", self, self.accuracy + ) + else: + logging.info( + "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy + ) else: # Testing the updated model directly at the server logging.info("[%s] Started model testing.", self) trainer = self.require_trainer() self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", self.accuracy) + + # If CORE benchmark was run via a Nanochat testing strategy, report the specialized CORE metric instead of the generic 'Global model accuracy' label. + core_metric = getattr(self, "_core_metric", None) + + if core_metric is not None: logging.info( fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" + trainer = self.require_trainer() + metric_name = getattr(trainer.testing_strategy, "metric_name", "accuracy") + + if metric_name == "mse": + logging.info( + fonts.colourize(f"[{self}] Global model MSE: {self.accuracy:.2f}\n") + ) + elif metric_name == "perplexity" or hasattr( + Config().trainer, "target_perplexity" + ): + logging.info( + fonts.colourize( + f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" + ) + ) + else: + logging.info( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" + ) ) - ) self.clients_processed() self.callback_handler.call_event("on_clients_processed", self) @@ -273,9 +313,10 @@ def clients_processed(self) -> None: def get_logged_items(self) -> dict: """Get items to be logged by the LogProgressCallback class in a .csv file.""" - return { + logged = { "round": self.current_round, "accuracy": self.accuracy, + "core_metric": getattr(self, "_core_metric", None), "accuracy_std": self.accuracy_std, "elapsed_time": self.wall_time - self.initial_wall_time, "processing_time": max( @@ -291,6 +332,33 @@ def get_logged_items(self) -> dict: "comm_overhead": self.comm_overhead, } + # Add train_loss if available from client reports + if self.updates and hasattr(self.updates[0].report, "train_loss"): + # Compute weighted average of train_loss across clients + total_samples = sum( + update.report.num_samples + for update in self.updates + if update.report.train_loss is not None 
+ ) + if total_samples > 0: + weighted_loss = sum( + update.report.train_loss * update.report.num_samples + for update in self.updates + if update.report.train_loss is not None + ) + logged["train_loss"] = weighted_loss / total_samples + else: + logged["train_loss"] = None + + # Add core_metric if Nanochat CORE evaluation was performed + if hasattr(self, "_core_metric"): + logged["core_metric"] = self._core_metric + + logged["mse"] = self.accuracy + logged["perplexity"] = self.accuracy + + return logged + @staticmethod def get_accuracy_mean_std(updates): """Compute the accuracy mean and standard deviation across clients.""" diff --git a/plato/servers/fedavg_cs.py b/plato/servers/fedavg_cs.py index fceff363a..eaba3caf6 100644 --- a/plato/servers/fedavg_cs.py +++ b/plato/servers/fedavg_cs.py @@ -253,7 +253,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" @@ -274,7 +289,22 @@ async def _process_reports(self): logging.info("[%s] Started model testing.", self) self.accuracy = trainer.test(self.testset, self.testset_sampler) - if hasattr(Config().trainer, "target_perplexity"): + # Extract CORE evaluation results if available (Nanochat CORE evaluation) + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + self._core_metric = core_results.get("core_metric", None) + + core_metric = getattr(self, "_core_metric", None) + if core_metric is not None: + logging.info( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" diff --git a/plato/servers/split_learning.py b/plato/servers/split_learning.py index 83cf2804a..49fdc2d4c 100644 --- a/plato/servers/split_learning.py +++ b/plato/servers/split_learning.py @@ -99,11 +99,30 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): self.test_accuracy = trainer.test(self.testset, self.testset_sampler) - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + if ( + hasattr(trainer, "context") + and "nanochat_core_results" in trainer.context.state + ): + core_results = trainer.context.state["nanochat_core_results"] + core_metric = core_results.get("core_metric", None) + if core_metric is not None: + logging.warning( + fonts.colourize( + f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" + ) + ) + else: + logging.warning( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) + ) + else: + logging.warning( + 
fonts.colourize( + f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" + ) ) - ) self.phase = "prompt" # Change client in next round self.next_client = True diff --git a/plato/trainers/huggingface.py b/plato/trainers/huggingface.py index 61365a638..7f4ed694f 100644 --- a/plato/trainers/huggingface.py +++ b/plato/trainers/huggingface.py @@ -5,6 +5,7 @@ HuggingFace data handling through strategy objects instead of overriding `load_model`/`save_model` hooks. +Supports both text/NLP models and time series models (e.g., PatchTSMixer). """ import logging @@ -39,6 +40,7 @@ TrainingContext, TrainingStepStrategy, ) +from plato.utils.timeseries_utils import is_timeseries_model class HuggingFaceBatch(dict): @@ -79,6 +81,23 @@ def __call__( return HuggingFaceBatch(batch), labels +class TimeSeriesCollateWrapper: + """Collator for time series data (PatchTSMixer format).""" + + def __call__( + self, examples: Iterable[dict] + ) -> tuple[HuggingFaceBatch, torch.Tensor | None]: + """ + Collate time series examples into batches. + + Expected format: {"past_values": tensor, "future_values": tensor} + """ + batch = default_data_collator(list(examples)) + labels = batch.get("future_values", None) + + return HuggingFaceBatch(batch), labels + + def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): """ Resolve a loss tensor from HuggingFace model outputs. @@ -110,8 +129,10 @@ def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): raise ValueError("HuggingFace model did not return a tensor loss.") logits = getattr(outputs, "logits", None) + if logits is None: + logits = getattr(outputs, "prediction_outputs", None) # PatchTSMixer if logits is None and isinstance(outputs, dict): - logits = outputs.get("logits") + logits = outputs.get("logits") or outputs.get("prediction_outputs") if logits is None and isinstance(outputs, tuple) and len(outputs) > 0: logits = outputs[0] @@ -133,6 +154,13 @@ def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): logits = logits.to(labels.device) if labels.device != logits.device else logits labels = labels.to(logits.device) + # Check if this is a regression task (shapes match) -> use MSE + # Time series: logits (batch, pred_len, channels), labels (batch, pred_len, channels) + # Text generation: logits (batch, seq_len, vocab_size), labels (batch, seq_len) + if logits.shape == labels.shape: + return F.mse_loss(logits, labels) + + # Text generation with causal LM -> use cross-entropy vocab_size = logits.size(-1) if logits.ndim > 2: shift_logits = logits[..., :-1, :].contiguous() @@ -196,12 +224,25 @@ def training_step( optimizer.zero_grad() batch_inputs = dict(examples) - if labels is not None: + + # For time series models like PatchTSMixer, future_values should not be passed as 'labels' + # TODO: Need to check if other time series models follow this + is_timeseries = ( + "past_values" in batch_inputs and "future_values" in batch_inputs + ) + + if not is_timeseries and labels is not None: batch_inputs["labels"] = labels batch_inputs.setdefault("return_dict", True) outputs = model(**batch_inputs) - labels_tensor = batch_inputs.get("labels") + + # For time series, get labels from batch_inputs, otherwise from labels argument + labels_tensor = ( + batch_inputs.get("future_values") + if is_timeseries + else batch_inputs.get("labels") + ) loss = _resolve_hf_loss(outputs, labels_tensor) loss_for_backward = loss.div(accum_steps) if accum_steps > 1 else loss @@ -291,10 +332,21 @@ def finalize(self, model, optimizer, context: 
TrainingContext): class HuggingFaceTestingStrategy(TestingStrategy): - """Evaluates HuggingFace models and reports perplexity based on loss.""" + """Evaluates HuggingFace models (text: perplexity, time series: MSE).""" - def __init__(self, collate_fn: HuggingFaceCollateWrapper): + def __init__(self, collate_fn, is_timeseries=False): self.collate_fn = collate_fn + self.is_timeseries = is_timeseries + + @property + def metric_name(self) -> str: + """Return the name of the metric this strategy computes.""" + if self.is_timeseries: + return "mse" # For time series models, using mean squared error. + elif hasattr(Config().trainer, "target_perplexity"): + return "perplexity" + else: + return "accuracy" def test_model(self, model, config, testset, sampler, context: TrainingContext): batch_size = config.get("batch_size", 1) @@ -324,41 +376,80 @@ def test_model(self, model, config, testset, sampler, context: TrainingContext): model.eval() context.state["eval_loader"] = data_loader - total_loss = 0.0 - total_weight = 0 - - with torch.no_grad(): - for batch_inputs, labels in data_loader: - batch_inputs = batch_inputs.to(context.device) - if labels is not None: - labels = labels.to(context.device) - batch_inputs["labels"] = labels - - batch_inputs.setdefault("return_dict", True) - outputs = model(**batch_inputs) - loss = _resolve_hf_loss(outputs, labels) - - if labels is not None: - weight = labels.ne(-100).sum().item() - if weight == 0: - continue - else: - weight = 1 - - total_loss += loss.item() * weight - total_weight += weight - - model.train() - context.state.pop("eval_loader", None) - - if total_weight == 0: - return float("inf") - - avg_loss = total_loss / total_weight - try: - return math.exp(avg_loss) - except OverflowError: - return float("inf") + if self.is_timeseries: + total_loss = 0.0 + total_samples = 0 + + with torch.no_grad(): + for batch_inputs, labels in data_loader: + batch_inputs = batch_inputs.to(context.device) + if labels is not None: + labels = labels.to(context.device) + batch_inputs["future_values"] = labels + + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + + loss = getattr(outputs, "loss", None) + if loss is None: + loss = ( + outputs.get("loss") if isinstance(outputs, dict) else None + ) + + if loss is not None: + batch_size = ( + batch_inputs["past_values"].size(0) + if "past_values" in batch_inputs + else 1 + ) + total_loss += loss.item() * batch_size + total_samples += batch_size + + model.train() + context.state.pop("eval_loader", None) + + if total_samples == 0: + return float("inf") + + # Return MSE + return total_loss / total_samples + else: + # Text/NLP: compute perplexity + total_loss = 0.0 + total_weight = 0 + + with torch.no_grad(): + for batch_inputs, labels in data_loader: + batch_inputs = batch_inputs.to(context.device) + if labels is not None: + labels = labels.to(context.device) + batch_inputs["labels"] = labels + + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + loss = _resolve_hf_loss(outputs, labels) + + if labels is not None: + weight = labels.ne(-100).sum().item() + if weight == 0: + continue + else: + weight = 1 + + total_loss += loss.item() * weight + total_weight += weight + + model.train() + context.state.pop("eval_loader", None) + + if total_weight == 0: + return float("inf") + + avg_loss = total_loss / total_weight + try: + return math.exp(avg_loss) + except OverflowError: + return float("inf") def _split_callback_types( @@ -433,57 +524,75 @@ def __init__(self, model=None, 
callbacks=None): ] ) - model_name = Config().trainer.model_name - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, - } - self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + model_name = getattr(Config().trainer, "model_name", "") + model_type = getattr(Config().trainer, "model_type", None) - cache_dir = Config().params["data_path"] - use_fast_tokenizer = True - revision = "main" - auth_token = getattr( - getattr(Config(), "parameters", None), "huggingface_token", None + # Detect if this is a time series model + self._is_timeseries = is_timeseries_model( + model_name=model_name, model_type=model_type ) - if "llama" in model_name: - if isinstance(auth_token, str) and auth_token: - self.tokenizer = LlamaTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - use_auth_token=auth_token, - ) - else: - self.tokenizer = LlamaTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - ) - else: - if isinstance(auth_token, str) and auth_token: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - use_auth_token=auth_token, - ) + if self._is_timeseries: + logging.info( + "Detected time series model (type: %s, name: %s)", + model_type, + model_name, + ) + + self.config = None + if not self._is_timeseries: + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } + self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + + self.tokenizer = None + if not self._is_timeseries: + cache_dir = Config().params["data_path"] + use_fast_tokenizer = True + revision = "main" + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) + + if "llama" in model_name: + if isinstance(auth_token, str) and auth_token: + self.tokenizer = LlamaTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + use_auth_token=auth_token, + ) + else: + self.tokenizer = LlamaTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + ) else: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - ) + if isinstance(auth_token, str) and auth_token: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + use_auth_token=auth_token, + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + ) grad_accum_steps = getattr(Config().trainer, "gradient_accumulation_steps", 1) try: @@ -491,7 +600,15 @@ def __init__(self, model=None, callbacks=None): except (TypeError, ValueError): grad_accum_steps = 1 self._gradient_accumulation_steps = max(grad_accum_steps, 1) - self._collate_wrapper = HuggingFaceCollateWrapper(self.tokenizer) + + # Choose collator based on model type + if self._is_timeseries: + self._collate_wrapper = TimeSeriesCollateWrapper() + logging.info("Using TimeSeriesCollateWrapper for time series model") + else: + self._collate_wrapper = 
HuggingFaceCollateWrapper(self.tokenizer) + logging.info("Using HuggingFaceCollateWrapper for text model") + self.training_args.gradient_accumulation_steps = ( self._gradient_accumulation_steps ) @@ -513,14 +630,16 @@ def __init__(self, model=None, callbacks=None): num_workers=0, pin_memory=True, ), - testing_strategy=HuggingFaceTestingStrategy(self._collate_wrapper), + testing_strategy=HuggingFaceTestingStrategy( + self._collate_wrapper, is_timeseries=self._is_timeseries + ), ) if hf_callbacks: self.add_callbacks(hf_callbacks) model_instance = self._require_model() - if hasattr(model_instance, "loss_type"): + if hasattr(model_instance, "loss_type") and not self._is_timeseries: setattr(model_instance, "loss_type", "ForCausalLM") # Ensure model checkpoints can be saved when model names include slashes. diff --git a/plato/trainers/nanochat.py b/plato/trainers/nanochat.py new file mode 100644 index 000000000..307ccf911 --- /dev/null +++ b/plato/trainers/nanochat.py @@ -0,0 +1,296 @@ +""" +Trainer wiring for Nanochat models within the composable trainer framework. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Iterator, List, Sequence + +try: + import torch +except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "Nanochat trainer requires PyTorch. " + "Install torch via the project's optional dependencies." + ) from exc + +from plato.config import Config +from plato.datasources.nanochat import NanochatStreamingDataset +from plato.evaluators.nanochat_core import run_core_evaluation +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies.base import ( + DataLoaderStrategy, + OptimizerStrategy, + TestingStrategy, + TrainingContext, + TrainingStepStrategy, +) +from plato.trainers.strategies.data_loader import DefaultDataLoaderStrategy +from plato.utils.third_party import ThirdPartyImportError, ensure_nanochat_importable + + +def _first_element_collate(batch: Sequence[Any]) -> Any: + """Return the first (and only) element from a DataLoader batch.""" + if not batch: + raise ValueError("Received empty batch from Nanochat dataset.") + return batch[0] + + +class NanochatDataLoaderStrategy(DataLoaderStrategy): + """Use identity collation for pre-batched Nanochat streaming datasets.""" + + def __init__(self): + self._fallback = DefaultDataLoaderStrategy() + + def create_train_loader( + self, trainset, sampler, batch_size: int, context: TrainingContext + ) -> torch.utils.data.DataLoader | Iterator: + if isinstance(trainset, NanochatStreamingDataset): + return torch.utils.data.DataLoader( + trainset, + batch_size=1, + shuffle=False, + sampler=None, + num_workers=0, + collate_fn=_first_element_collate, + ) + return self._fallback.create_train_loader( + trainset, sampler, batch_size, context + ) + + +class NanochatTrainingStepStrategy(TrainingStepStrategy): + """Call Nanochat's integrated loss computation during training.""" + + def __init__(self, loss_reduction: str = "mean"): + self.loss_reduction = loss_reduction + + def training_step( + self, + model: torch.nn.Module, + optimizer, + examples: torch.Tensor, + labels: torch.Tensor, + loss_criterion, + context: TrainingContext, + ) -> torch.Tensor: + optimizer.zero_grad() + if labels is not None: + loss = model(examples, targets=labels, loss_reduction=self.loss_reduction) + else: + outputs = model(examples) + loss = loss_criterion(outputs, labels) + + if not isinstance(loss, torch.Tensor): + raise TypeError( + "Nanochat 
model forward pass must return a torch.Tensor loss." + ) + + loss.backward() + optimizer.step() + context.state["optimizer_step_completed"] = True + return loss.detach() + + +@dataclass +class _OptimizerBundle: + """Bundle multiple optimizers under a single interface.""" + + optimizers: List[torch.optim.Optimizer] + + def zero_grad(self) -> None: + for optimizer in self.optimizers: + optimizer.zero_grad() + + def step(self) -> None: + for optimizer in self.optimizers: + optimizer.step() + + def state_dict(self) -> dict[str, Any]: + return { + "optimizers": [optimizer.state_dict() for optimizer in self.optimizers], + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + for optimizer, payload in zip( + self.optimizers, + state_dict.get("optimizers", []), + strict=False, # type: ignore[arg-type] + ): + optimizer.load_state_dict(payload) + + @property + def param_groups(self) -> list[dict[str, Any]]: + groups: list[dict[str, Any]] = [] + for optimizer in self.optimizers: + groups.extend(getattr(optimizer, "param_groups", [])) + return groups + + def params_state_update(self) -> None: + for optimizer in self.optimizers: + hook = getattr(optimizer, "params_state_update", None) + if callable(hook): + hook() + + +class NanochatOptimizerStrategy(OptimizerStrategy): + """Adapter around nanochat.gpt.GPT.setup_optimizers.""" + + def __init__( + self, + *, + unembedding_lr: float = 0.004, + embedding_lr: float = 0.2, + matrix_lr: float = 0.02, + weight_decay: float = 0.0, + ): + self.unembedding_lr = unembedding_lr + self.embedding_lr = embedding_lr + self.matrix_lr = matrix_lr + self.weight_decay = weight_decay + + def create_optimizer( + self, model: torch.nn.Module, context: TrainingContext + ) -> _OptimizerBundle: + if not isinstance(model, torch.nn.Module): + raise TypeError("Nanochat optimizer strategy requires a torch.nn.Module.") + + setup_fn = getattr(model, "setup_optimizers", None) + if not callable(setup_fn): + raise AttributeError( + "Nanochat model is expected to expose setup_optimizers()." + ) + + optimizers = setup_fn( + unembedding_lr=self.unembedding_lr, + embedding_lr=self.embedding_lr, + matrix_lr=self.matrix_lr, + weight_decay=self.weight_decay, + ) + if not isinstance(optimizers, Iterable): + raise TypeError("setup_optimizers() must return an iterable of optimizers.") + + optimizer_list: list[torch.optim.Optimizer] = list(optimizers) + if not optimizer_list: + raise ValueError("setup_optimizers() returned an empty optimizer list.") + return _OptimizerBundle(optimizer_list) + + +class NanochatTestingStrategy(TestingStrategy): + """Compute average token loss over the validation iterator.""" + + def __init__(self, reduction: str = "sum"): + self.reduction = reduction + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + if not isinstance(testset, NanochatStreamingDataset): + raise TypeError( + "NanochatTestingStrategy expects a NanochatStreamingDataset instance." 
+ ) + + model.eval() + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for inputs, targets in testset: + inputs = inputs.to(context.device) + targets = targets.to(context.device) + loss = model(inputs, targets=targets, loss_reduction=self.reduction) + total_loss += float(loss.item()) + total_tokens += targets.numel() + + model.train() + if total_tokens == 0: + return float("nan") + return total_loss / total_tokens + + +class NanochatCoreTestingStrategy(TestingStrategy): + """Evaluate the CORE benchmark and return the aggregate metric.""" + + def __init__(self, bundle_dir: str | None = None, max_per_task: int = -1): + self.bundle_dir = bundle_dir + self.max_per_task = max_per_task + + def test_model( + self, + model: torch.nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + device = context.device or next(model.parameters()).device + tokenizer = getattr(model, "nanochat_tokenizer", None) + results = run_core_evaluation( + model, + tokenizer=tokenizer, + bundle_dir=self.bundle_dir, + max_per_task=self.max_per_task, + device=device, + ) + context.state["nanochat_core_results"] = results + return float(results["core_metric"]) + + +class Trainer(ComposableTrainer): + """Composable trainer specialised for Nanochat workloads.""" + + def __init__( + self, + model=None, + callbacks=None, + *, + optimizer_params: dict[str, Any] | None = None, + loss_reduction: str = "mean", + ): + try: + ensure_nanochat_importable() + except ThirdPartyImportError as exc: # pragma: no cover - defensive branch + raise ImportError( + "Nanochat trainer requires the external/nanochat submodule. " + "Run `git submodule update --init --recursive`." + ) from exc + + optimizer_strategy = NanochatOptimizerStrategy( + **(optimizer_params or {}), + ) + training_step_strategy = NanochatTrainingStepStrategy( + loss_reduction=loss_reduction + ) + data_loader_strategy = NanochatDataLoaderStrategy() + + evaluation_cfg = getattr(Config(), "evaluation", None) + evaluation_type = ( + getattr(evaluation_cfg, "type", "").lower() if evaluation_cfg else "" + ) + if evaluation_type == "nanochat_core": + max_per_task = getattr(evaluation_cfg, "max_per_task", -1) + max_per_task_value = -1 if max_per_task is None else int(max_per_task) + testing_strategy = NanochatCoreTestingStrategy( + bundle_dir=getattr(evaluation_cfg, "bundle_dir", None), + max_per_task=max_per_task_value, + ) + else: + testing_strategy = NanochatTestingStrategy() + + super().__init__( + model=model, + callbacks=callbacks, + loss_strategy=None, + optimizer_strategy=optimizer_strategy, + training_step_strategy=training_step_strategy, + lr_scheduler_strategy=None, + model_update_strategy=None, + data_loader_strategy=data_loader_strategy, + testing_strategy=testing_strategy, + ) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index ff0db0cb9..ef47a287e 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -12,6 +12,9 @@ gan, split_learning, ) +from plato.trainers import ( + nanochat as nanochat_trainer, +) registered_trainers = { "composable": composable.ComposableTrainer, @@ -19,6 +22,7 @@ "timm_basic": basic.TrainerWithTimmScheduler, "gan": gan.Trainer, "split_learning": split_learning.Trainer, + "nanochat": nanochat_trainer.Trainer, } diff --git a/plato/utils/third_party.py b/plato/utils/third_party.py new file mode 100644 index 000000000..890764e13 --- /dev/null +++ b/plato/utils/third_party.py @@ -0,0 +1,42 @@ +""" +Helpers for accessing vendored third-party 
projects. +""" + +from __future__ import annotations + +import sys +from functools import lru_cache +from pathlib import Path + + +class ThirdPartyImportError(ImportError): + """Raised when a vendored third-party project is unavailable.""" + + +@lru_cache(maxsize=None) +def _nanochat_root() -> Path: + """Return the root directory of the Nanochat submodule.""" + repo_root = Path(__file__).resolve().parents[2] + nanochat_root = repo_root / "external" / "nanochat" + if not nanochat_root.exists(): + raise ThirdPartyImportError( + "Nanochat submodule missing. Run `git submodule update --init --recursive`." + ) + return nanochat_root + + +def ensure_nanochat_importable() -> Path: + """ + Ensure the vendored Nanochat package is importable. + + Returns: + Path to the Nanochat project root. + + Raises: + ThirdPartyImportError: If the vendored Nanochat tree is missing. + """ + nanochat_root = _nanochat_root() + path_str = str(nanochat_root) + if path_str not in sys.path: + sys.path.insert(0, path_str) + return nanochat_root diff --git a/plato/utils/timeseries_utils.py b/plato/utils/timeseries_utils.py new file mode 100644 index 000000000..8f2a0e935 --- /dev/null +++ b/plato/utils/timeseries_utils.py @@ -0,0 +1,39 @@ +""" +Utility functions for time series model detection and handling. +""" + +from typing import Optional, Tuple + + +def is_timeseries_model( + model_name: Optional[str] = None, + model_type: Optional[str] = None, + dataset_type: Optional[str] = None, +) -> bool: + """ + Check if a model/dataset is for time series. + + Args: + model_name: Name of the model + model_type: Type of model from config + dataset_type: Type of dataset from config + + Returns: + True if this is a time series model, False otherwise + """ + model_name_lower = model_name.lower() if model_name else "" + model_type_lower = model_type.lower() if model_type else "" + + # Check for PatchTSMixer + if ( + model_type_lower == "patchtsmixer" + or "patchtsmixer" in model_name_lower + or "timeseries" in model_name_lower + ): + return True + + # Check dataset type + if dataset_type and dataset_type.lower() == "timeseries": + return True + + return False diff --git a/pyproject.toml b/pyproject.toml index 099d06475..228545350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,18 @@ dp = [ mlx = [ "mlx", ] +nanochat = [ + "datasets", + "numpy", + "psutil", + "regex", + "tiktoken", + "tokenizers", + "torch", + "wandb", + "jinja2", + "PyYAML", +] [project.urls] Homepage = "https://github.com/TL-System/plato" @@ -89,11 +101,13 @@ members = [ "examples/detector", "examples/gradient_leakage_attacks", "tools", + "examples/nanochat", ] [dependency-groups] dev = [ "pytest", + "ruff", "ty", ] diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index c22b0875c..2b1bb9b75 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -157,3 +157,54 @@ def test_config_base_path_used_without_cli_override(tmp_path: Path, monkeypatch) if hasattr(Config, "args"): delattr(Config, "args") Config._cli_overrides = {} + + +def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): + """Test that [evaluation] configuration is properly loaded.""" + config_base = tmp_path / "runtime" + config_path = tmp_path / "config.toml" + + config_data = { + "clients": {"type": "simple", "total_clients": 2, "per_round": 1}, + "server": {"address": "127.0.0.1", "port": 8000}, + "data": {"datasource": "MNIST"}, + "trainer": {"type": "basic", "rounds": 1, "epochs": 1, "batch_size": 10}, + "algorithm": {"type": 
"fedavg"}, + "evaluation": { + "type": "nanochat_core", + "max_per_task": 128, + "bundle_dir": "/custom/path/to/nanochat", + }, + } + + toml_writer.dump(config_data, config_path) + + monkeypatch.delenv("config_file", raising=False) + monkeypatch.setattr( + sys, + "argv", + [ + sys.argv[0], + "--config", + str(config_path), + "--base", + str(config_base), + ], + ) + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + + config = Config() + + assert hasattr(config, "evaluation") + assert config.evaluation.type == "nanochat_core" + assert config.evaluation.max_per_task == 128 + assert config.evaluation.bundle_dir == "/custom/path/to/nanochat" + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} diff --git a/tests/test_nanochat_integration.py b/tests/test_nanochat_integration.py new file mode 100644 index 000000000..6a9c174e8 --- /dev/null +++ b/tests/test_nanochat_integration.py @@ -0,0 +1,162 @@ +"""Nanochat integration smoke checks (optional, require nanochat extras).""" + +from __future__ import annotations + +import importlib.util +import os + +import pytest + +from plato.config import Config, ConfigNode +from plato.utils.third_party import ensure_nanochat_importable + +pytestmark = pytest.mark.integration + +_RUSTBPE_AVAILABLE = importlib.util.find_spec("rustbpe") is not None + + +@pytest.mark.skipif( + not _RUSTBPE_AVAILABLE, + reason="Nanochat tokenizer tests require rustbpe extension (install nanochat extras).", +) +def test_nanochat_tokenizer_processor_round_trip(tmp_path): + """Train a tiny tokenizer via rustbpe and encode a sample string.""" + pytest.importorskip( + "tiktoken", reason="Nanochat tokenizer tests require tiktoken (nanochat extra)." + ) + from plato.processors.nanochat_tokenizer import NanochatTokenizerProcessor + + corpus = ["hello nanochat", "minimal tokenizer exercise"] + processor = NanochatTokenizerProcessor( + train_corpus=corpus, + vocab_size=300, + prepend_bos=False, + ) + encoded = processor.process("hello nanochat") + assert isinstance(encoded, list) + assert len(encoded) > 0 + + +def test_nanochat_trainer_smoke(temp_config, tmp_path): + """Run one training step with synthetic Nanochat data on CPU.""" + _ = pytest.importorskip( + "torch", reason="Nanochat trainer smoke requires torch (nanochat extra)." 
+ ) + ensure_nanochat_importable() + _ = pytest.importorskip( + "nanochat", + reason="Nanochat trainer smoke requires the nanochat package (nanochat extra).", + ) + from plato.datasources.nanochat import DataSource as NanochatDataSource + from plato.models.nanochat import Model as NanochatModel + from plato.trainers.nanochat import Trainer as NanochatTrainer + + cfg = Config() + cfg.trainer.type = "nanochat" + cfg.trainer.model_name = "nanochat_smoke" + cfg.trainer.batch_size = 2 + cfg.trainer.rounds = 1 + cfg.trainer.epochs = 1 + + cfg.parameters.model = { + "sequence_len": 16, + "vocab_size": 512, + "n_layer": 1, + "n_head": 2, + "n_kv_head": 2, + "n_embd": 128, + } + + datasource = NanochatDataSource( + batch_size=cfg.trainer.batch_size, + sequence_length=cfg.parameters.model["sequence_len"], + mode="synthetic", + max_train_batches=2, + max_val_batches=1, + device="cpu", + vocab_size=cfg.parameters.model["vocab_size"], + synthetic_seed=123, + ) + + model = NanochatModel.get( + sequence_len=cfg.parameters.model["sequence_len"], + vocab_size=cfg.parameters.model["vocab_size"], + n_layer=cfg.parameters.model["n_layer"], + n_head=cfg.parameters.model["n_head"], + n_kv_head=cfg.parameters.model["n_kv_head"], + n_embd=cfg.parameters.model["n_embd"], + init_weights=True, + ) + + trainer = NanochatTrainer(model=model) + trainset = datasource.get_train_set() + elapsed = trainer.train(trainset, sampler=None) + + assert isinstance(elapsed, float) + assert elapsed >= 0.0 + + model_dir = Config().params["model_path"] + checkpoint_name = f"{cfg.trainer.model_name}_{trainer.client_id}_{Config().params['run_id']}.safetensors" + assert os.path.exists(os.path.join(model_dir, checkpoint_name)) + + +def test_nanochat_trainer_selects_core_eval_strategy(temp_config, monkeypatch): + """Ensure evaluation config triggers the CORE testing strategy.""" + _ = pytest.importorskip( + "torch", reason="Nanochat trainer requires torch (nanochat extra)." + ) + ensure_nanochat_importable() + _ = pytest.importorskip( + "nanochat", + reason="Nanochat trainer requires the nanochat package (nanochat extra).", + ) + from plato.models.nanochat import Model as NanochatModel + from plato.trainers.nanochat import ( + NanochatCoreTestingStrategy, + ) + from plato.trainers.nanochat import ( + Trainer as NanochatTrainer, + ) + + monkeypatch.setattr( + "plato.evaluators.nanochat_core.run_core_evaluation", + lambda *args, **kwargs: { + "results": {}, + "centered_results": {}, + "core_metric": 0.0, + }, + ) + + cfg = Config() + cfg.trainer.type = "nanochat" + cfg.trainer.model_name = "nanochat_core" + cfg.trainer.batch_size = 1 + cfg.trainer.rounds = 1 + cfg.trainer.epochs = 1 + cfg.parameters.model = { + "sequence_len": 16, + "vocab_size": 512, + "n_layer": 1, + "n_head": 2, + "n_kv_head": 2, + "n_embd": 128, + } + cfg.evaluation = ConfigNode.from_object( + { + "type": "nanochat_core", + "max_per_task": 1, + } + ) + + model = NanochatModel.get( + sequence_len=cfg.parameters.model["sequence_len"], + vocab_size=cfg.parameters.model["vocab_size"], + n_layer=cfg.parameters.model["n_layer"], + n_head=cfg.parameters.model["n_head"], + n_kv_head=cfg.parameters.model["n_kv_head"], + n_embd=cfg.parameters.model["n_embd"], + init_weights=True, + ) + + trainer = NanochatTrainer(model=model) + assert isinstance(trainer.testing_strategy, NanochatCoreTestingStrategy)
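A minimal usage sketch tying the pieces added in this patch together (illustration only, not part of the diff): it calls the new Nanochat model factory and CORE evaluator directly, assuming the external/nanochat submodule is initialised, the nanochat extras are installed, and an eval_bundle/ directory is resolvable at the default bundle location.

from plato.models.nanochat import Model as NanochatModel
from plato.evaluators.nanochat_core import run_core_evaluation

# Build a small Nanochat GPT; unknown kwargs are filtered against GPTConfig, and
# vocab_size is reconciled with the tokenizer's vocabulary when one can be loaded.
model = NanochatModel.get(
    sequence_len=128,
    vocab_size=512,
    n_layer=2,
    n_head=4,
    n_kv_head=4,
    n_embd=256,
    init_weights=True,
)

# Run a capped CORE pass on CPU; the tokenizer attached by the factory is reused,
# and the returned dictionary carries per-task accuracies, centered scores, and the
# aggregate core_metric that the servers above log and write to the results CSV.
report = run_core_evaluation(model, max_per_task=16, device="cpu")
print(f"CORE metric: {report['core_metric']:.4f}")

The same path is exercised end to end through the federated configs, e.g. (assuming Plato's standard launcher) ./run -c configs/Nanochat/synthetic_micro.toml.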