Changes from all commits
33 commits
3355698
Initial support for the Nanochat model and its evaluation benchmark (…
baochunli Oct 28, 2025
822a1cb
Added support for vendoring the external Nanochat repo as a git submo…
baochunli Oct 28, 2025
e039b2c
ruff check --fix & ruff format.
baochunli Oct 28, 2025
ac5beba
Added benchmark configuration ([evaluation]) support in config.py.
Jasmine-Yuting-Zhang Oct 29, 2025
4501781
Added test to verify that [evaluation] configuration is properly loaded.
Jasmine-Yuting-Zhang Oct 29, 2025
a8efbea
Fixed tensor contiguity issue in datasource.
Jasmine-Yuting-Zhang Oct 29, 2025
f0bc22d
Fixed KeyError: 'train_loss'.
Jasmine-Yuting-Zhang Oct 29, 2025
4810080
Fixed train_loss aggregation in FedAvg server to handle None values.
Jasmine-Yuting-Zhang Oct 30, 2025
eb736eb
Added evaluation configs for nanochat CORE metric.
Jasmine-Yuting-Zhang Oct 30, 2025
d9fe94a
Added automatic download of nanochat CORE evaluation bundle.
Jasmine-Yuting-Zhang Oct 30, 2025
6f34950
Using tokenizer's vocab_size to match between model and tokenizer.
Jasmine-Yuting-Zhang Oct 30, 2025
35f25eb
Added outputs for Nanochat CORE evaluation in FedAvg server.
Jasmine-Yuting-Zhang Oct 30, 2025
e4ae761
Added specific logging output for CORE benchmark metrics.
Jasmine-Yuting-Zhang Oct 31, 2025
432fe50
Typed the Nanochat datasource/optimizer plumbing and enforced valid C…
baochunli Nov 7, 2025
da04815
All nanochat tests now pass.
baochunli Nov 7, 2025
279d05e
Updated nanochat README with setup and troubleshooting notes.
Nov 8, 2025
af0bafa
Added configuration file for NanoChat Parquet mode.
Jasmine-Yuting-Zhang Nov 13, 2025
2b7cf3d
Formatted code with Ruff and applied autofixes.
Jasmine-Yuting-Zhang Nov 13, 2025
736b29e
Added two configuration files for PatchTSMixer model.
Jasmine-Yuting-Zhang Nov 13, 2025
1fe0e22
Added MSE metric output for time series models.
Jasmine-Yuting-Zhang Nov 13, 2025
205043d
Added GitHub dataset handling (ETT datasets) for PatchTSMixer model.
Jasmine-Yuting-Zhang Nov 14, 2025
12721f1
Added ETT datasource to the registry.
Jasmine-Yuting-Zhang Nov 14, 2025
73d19de
Added TimeSeriesDatasetWrapper support for time-series datasets in da…
Jasmine-Yuting-Zhang Nov 14, 2025
3bb745c
Added PatchTSMixer model support to HuggingFace model factory.
Jasmine-Yuting-Zhang Nov 14, 2025
03abe2f
Added timeseries_utils module with is_timeseries_model function.
Jasmine-Yuting-Zhang Nov 14, 2025
c30dafc
Added time-series support to the HuggingFace trainer.
Jasmine-Yuting-Zhang Nov 14, 2025
dc90468
Added documentation for time series model PatchTSMixer.
Jasmine-Yuting-Zhang Nov 14, 2025
ed13def
Added links to time series model in docs.
Jasmine-Yuting-Zhang Nov 14, 2025
7bc7f43
Revised dataset split to improve training performance.
Jasmine-Yuting-Zhang Nov 21, 2025
173095a
Added a larger PatchTSMixer config file with extended hyperparameters.
Jasmine-Yuting-Zhang Dec 1, 2025
5562465
Revised MSE evaluation logs for time series models.
Jasmine-Yuting-Zhang Dec 1, 2025
b9778ec
Used uv ruff format .
Jasmine-Yuting-Zhang Dec 1, 2025
20ab574
Refactored ETT data splitting and normalization for consistency with …
Jasmine-Yuting-Zhang Dec 2, 2025
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "plato/models/t2tvit"]
path = plato/models/t2tvit
url = https://github.com/yitu-opensource/T2T-ViT
[submodule "external/nanochat"]
path = external/nanochat
url = https://github.com/karpathy/nanochat.git
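The entry above vendors Karpathy's nanochat repository under external/nanochat (fetched with `git submodule update --init external/nanochat`). The following is a hypothetical sketch of one way the vendored checkout could be made importable from Plato code; the helper and its location are assumptions, and the PR may wire this up differently.

```python
import sys
from pathlib import Path

# Hypothetical helper: place the vendored checkout on sys.path so that
# nanochat modules can be imported. The repository layout is assumed here.
REPO_ROOT = Path(__file__).resolve().parents[1]
NANOCHAT_DIR = REPO_ROOT / "external" / "nanochat"

if NANOCHAT_DIR.is_dir() and str(NANOCHAT_DIR) not in sys.path:
    sys.path.insert(0, str(NANOCHAT_DIR))
```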
8 changes: 2 additions & 6 deletions cleanup.py
@@ -146,18 +146,14 @@ def main() -> None:
             continue

         cleared = clean_directory(runtime_dir)
-        print(
-            f"Failed to delete {runtime_dir}; cleared {cleared} items instead."
-        )
+        print(f"Failed to delete {runtime_dir}; cleared {cleared} items instead.")
         fallback_dirs += 1
         fallback_items += cleared

     if runtime_total == 0:
         print("No runtime directories found.")
     else:
-        print(
-            f"Removed {runtime_removed} of {runtime_total} runtime directories."
-        )
+        print(f"Removed {runtime_removed} of {runtime_total} runtime directories.")
     if fallback_dirs:
         print(
             f"Cleared {fallback_items} items in "
53 changes: 53 additions & 0 deletions configs/Nanochat/parquet_micro.toml
@@ -0,0 +1,53 @@
[clients]
type = "simple"
total_clients = 10
per_round = 3
do_test = true

[server]
address = "127.0.0.1"
port = 8000
simulate_wall_time = false
checkpoint_path = "checkpoints/nanochat/parquet"
model_path = "models/nanochat/parquet"

[data]
datasource = "Nanochat"
sampler = "iid"
partition_size = 1
random_seed = 1
mode = "parquet"
max_train_batches = 16
max_val_batches = 1
tokenizer_threads = 2
tokenizer_batch_size = 32
device = "cuda"
vocab_size = 512
synthetic_seed = 123

[evaluation]
type = "nanochat_core"
# bundle_dir = "~/nanochat"
max_per_task = 16

[trainer]
type = "nanochat"
rounds = 10000
epochs = 5
batch_size = 1
model_name = "nanochat"
optimizer = "nanochat"

[algorithm]
type = "fedavg"

[parameters.model]
sequence_len = 256
vocab_size = 50304
n_layer = 4
n_head = 4
n_kv_head = 4
n_embd = 256

[results]
types = "round, elapsed_time, core_metric, train_loss"
54 changes: 54 additions & 0 deletions configs/Nanochat/synthetic_micro.toml
@@ -0,0 +1,54 @@
[clients]

type = "simple"
total_clients = 1
per_round = 1
do_test = false

[server]
address = "127.0.0.1"
port = 8000
simulate_wall_time = false
checkpoint_path = "checkpoints/nanochat/synthetic"
model_path = "models/nanochat/synthetic"

[data]
datasource = "Nanochat"
sampler = "iid"
partition_size = 1
random_seed = 1
mode = "synthetic"
max_train_batches = 4
max_val_batches = 1
tokenizer_threads = 2
tokenizer_batch_size = 64
device = "cpu"
vocab_size = 512
synthetic_seed = 123

[evaluation]
type = "nanochat_core"
# bundle_dir = "~/nanochat" # Optional, defaults to nanochat base dir or Plato's data directory
max_per_task = 16 # Optional, -1 means run all examples

[trainer]
type = "nanochat"
rounds = 1
epochs = 1
batch_size = 2
model_name = "nanochat"
optimizer = "nanochat"

[algorithm]
type = "fedavg"

[parameters.model]
sequence_len = 128
vocab_size = 512
n_layer = 2
n_head = 4
n_kv_head = 4
n_embd = 256

[results]
types = "round, elapsed_time, core_metric, train_loss"
67 changes: 67 additions & 0 deletions configs/TimeSeries/patchtsmixer_custom.toml
@@ -0,0 +1,67 @@
# Federated Learning with PatchTSMixer for Time Series Forecasting
# This configuration demonstrates using the IBM Granite PatchTSMixer model
# with time series data from HuggingFace datasets

[clients]
type = "simple"
total_clients = 1
per_round = 1
do_test = false

[server]
address = "127.0.0.1"
port = 8000
simulate_wall_time = false
checkpoint_path = "checkpoints/timeseries/patchtsmixer"
model_path = "models/timeseries/patchtsmixer"

[data]
# ETTh1: Electricity Transformer Temperature dataset (7 features)
datasource = "ETTh1"

partition_size = 100 # Number of training samples
sampler = "iid"
random_seed = 1

[trainer]
type = "HuggingFace"
rounds = 3
max_concurrency = 2
model_type = "huggingface"

# Train from scratch - simpler for testing
model_name = "custom_patchtsmixer"

# Task type: forecasting, classification, regression, or pretraining
task_type = "forecasting"

# PatchTSMixer specific parameters (smaller model for testing)
context_length = 64
prediction_length = 24
num_input_channels = 7 # ETTh1 has 7 features (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT)
patch_length = 8
patch_stride = 8
d_model = 32 # Hidden dimension of the model; recommended to be a multiple of patch_length (e.g., 2-8x patch_length). Larger values mean a more complex model.
num_layers = 3 # Number of mixer layers. Recommended range is 3-15; larger values mean a more complex model.
expansion_factor = 2 # Expansion factor used inside the MLP. Recommended range is 2-5; larger values mean a more complex model.
dropout = 0.5
head_dropout = 0.7
mode = "common_channel"
gated_attn = true
scaling = "std"

# Training parameters
epochs = 2
batch_size = 8
optimizer = "Adam"

[algorithm]
type = "fedavg"

[parameters]
[parameters.optimizer]
lr = 0.001
weight_decay = 0.0

[results]
types = "round, elapsed_time, accuracy"
67 changes: 67 additions & 0 deletions configs/TimeSeries/patchtsmixer_large.toml
@@ -0,0 +1,67 @@
# Federated Learning with Large PatchTSMixer for Time Series Forecasting
# This configuration matches the PatchTSMixer paper parameters for ETTh1

[clients]
type = "simple"
total_clients = 1
per_round = 1
do_test = true # Enable testing to evaluate model on test set

[server]
address = "127.0.0.1"
port = 8000
simulate_wall_time = false
checkpoint_path = "checkpoints/timeseries/patchtsmixer"
model_path = "models/timeseries/patchtsmixer"

[data]
# ETTh1: Electricity Transformer Temperature dataset (7 features)
datasource = "ETTh1"

partition_size = 6960 # Full ETTh1 training set
sampler = "iid"
random_seed = 1

[trainer]
type = "HuggingFace"
rounds = 1000
max_concurrency = 10
model_type = "huggingface"
model_name = "custom_patchtsmixer"

# Task type: forecasting, classification, regression, or pretraining
task_type = "forecasting"

# PatchTSMixer specific parameters
context_length = 512 # Paper uses 512 context length
prediction_length = 96 # Standard benchmark (paper tests 96, 192, 336, 720)
num_input_channels = 7 # ETTh1 has 7 features (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT)
patch_length = 16
patch_stride = 8

d_model = 128
num_layers = 8
expansion_factor = 2

dropout = 0.3 # Increase regularization to prevent overfitting
head_dropout = 0.3 # Increase regularization to prevent overfitting

# Model configuration
mode = "common_channel"
gated_attn = true
scaling = "std"

epochs = 100
batch_size = 64
optimizer = "Adam"

[algorithm]
type = "fedavg"

[parameters]
[parameters.optimizer]
lr = 0.0001
weight_decay = 0.001

[results]
types = "round, elapsed_time, mse"
68 changes: 68 additions & 0 deletions configs/TimeSeries/patchtsmixer_pretrained.toml
@@ -0,0 +1,68 @@
# Federated Learning with PatchTSMixer for Time Series Forecasting
# This configuration demonstrates using the IBM Granite PatchTSMixer model
# with time series data from HuggingFace datasets

[clients]
type = "simple"
total_clients = 1
per_round = 1
do_test = false

[server]
address = "127.0.0.1"
port = 8000
simulate_wall_time = false
checkpoint_path = "checkpoints/timeseries/patchtsmixer"
model_path = "models/timeseries/patchtsmixer"

[data]
# ETTh1: Electricity Transformer Temperature dataset (7 features)
datasource = "ETTh1"

partition_size = 100 # Number of training samples
sampler = "iid"
random_seed = 1

[trainer]
type = "HuggingFace"
rounds = 3
max_concurrency = 2
model_type = "huggingface"

# Use pre-trained IBM Granite model
# When loading a pre-trained model, some settings below must match the pretrained checkpoint
model_name = "ibm-granite/granite-timeseries-patchtsmixer"

# Task type: forecasting, classification, regression, or pretraining
task_type = "forecasting"

# PatchTSMixer specific parameters (matching pretrained model)
context_length = 512
prediction_length = 96
num_input_channels = 7
patch_length = 16
patch_stride = 8
d_model = 64
num_layers = 8
expansion_factor = 2 # Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model.
dropout = 0.5
head_dropout = 0.7
mode = "common_channel"
gated_attn = true
scaling = "std"

# Training parameters
epochs = 2 # Reduced for testing
batch_size = 8 # Reduced for testing
optimizer = "Adam"

[algorithm]
type = "fedavg"

[parameters]
[parameters.optimizer]
lr = 0.001
weight_decay = 0.0

[results]
types = "round, elapsed_time, accuracy"
5 changes: 5 additions & 0 deletions docs/docs/examples/Getting Started.md
@@ -45,6 +45,11 @@ Plato supports both Linux with NVIDIA GPUs and macOS with M1/M2/M4/M4 GPUs. It w

- [Model Pruning Algorithms](algorithms/13.%20Model%20Pruning%20Algorithms.md)

- [Gradient Leakage Attacks and Defences](algorithms/14.%20Gradient%20Leakage%20Attacks%20and%20Defences.md)

- [Time Series Models](algorithms/15.%20Time%20Series%20Models.md)


## Case Studies

- [Federated LoRA Fine-Tuning](case-studies/1.%20LoRA.md)
15 changes: 15 additions & 0 deletions docs/docs/examples/algorithms/15. Time Series Models.md
@@ -0,0 +1,15 @@
### PatchTSMixer

PatchTSMixer is a lightweight time-series modeling approach based on the MLP-Mixer architecture. The model can be pretrained and subsequently used for various downstream tasks such as forecasting, classification and regression.

```bash
uv run python plato.py -c configs/TimeSeries/patchtsmixer_pretrained.toml
```

For custom model configurations without using pretrained weights:

```bash
uv run python plato.py -c configs/TimeSeries/patchtsmixer_custom.toml
```

**Reference:** V. Ekambaram, A. Jati, N. Nguyen, S. Sinthong, K. Kalagnanam. "[TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://dl.acm.org/doi/abs/10.1145/3580305.3599533)," in Proc. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2023. – [[Code available]](https://github.com/ibm-granite/granite-tsfm)
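As context for the configs above, an illustrative sketch (not the PR's TimeSeriesDatasetWrapper) of how an ETTh1 CSV could be windowed into the past_values/future_values pairs that PatchTSMixerForPrediction consumes; the column names follow the standard ETT files, and the function name is hypothetical.

```python
import pandas as pd
import torch

def ett_windows(csv_path: str, context_length: int = 512, prediction_length: int = 96):
    """Slide a window over the ETTh1 series to build forecasting samples."""
    frame = pd.read_csv(csv_path)
    features = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
    values = torch.tensor(frame[features].to_numpy(), dtype=torch.float32)

    samples = []
    last_start = len(values) - context_length - prediction_length
    for start in range(last_start + 1):
        past = values[start : start + context_length]
        future = values[start + context_length : start + context_length + prediction_length]
        samples.append({"past_values": past, "future_values": future})
    return samples
```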
1 change: 1 addition & 0 deletions docs/docs/index.md
@@ -32,6 +32,7 @@ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, a
- **[Poisoning Detection](examples/algorithms/12.%20Poisoning%20Detection%20Algorithms.md)**
- **[Model Pruning](examples/algorithms/13.%20Model%20Pruning%20Algorithms.md)**
- **[Gradient Leakage Attacks and Defences](examples/algorithms/14.%20Gradient%20Leakage%20Attacks%20and%20Defences.md)**
- **[Time Series Models](examples/algorithms/15.%20Time%20Series%20Models.md)**

## Configuration Settings

1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -66,6 +66,7 @@ nav:
- Poisoning Detection: examples/algorithms/12. Poisoning Detection Algorithms.md
- Model Pruning: examples/algorithms/13. Model Pruning Algorithms.md
- Gradient Leakage Attacks and Defences: examples/algorithms/14. Gradient Leakage Attacks and Defences.md
- Time Series Models: examples/algorithms/15. Time Series Models.md
- Case Studies:
- Federated LoRA Fine-Tuning: examples/case-studies/1. LoRA.md
- Composable Trainer API: examples/case-studies/2. Composable Trainer.md