
Commit 5ae0426

Migrated all examples to use the new composable trainer architecture. (#389)
1 parent 2a6edf5 commit 5ae0426


63 files changed, +2797 −13388 lines


docs/docs/examples/algorithms/8. Personalized Federated Learning Algorithms.md

Lines changed: 16 additions & 16 deletions
@@ -3,8 +3,8 @@
 FedRep learns a shared data representation (the global layers) across clients and a unique, personalized local "head" (the local layers) for each client. In this implementation, after each round of local training, only the representation on each client is retrieved and uploaded to the server for aggregation.
 
 ```bash
-cd examples/personalized_fl/fedrep
-uv run fedrep.py -c ../configs/fedrep_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run fedrep/fedrep.py -c configs/fedrep_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Collins et al., "[Exploiting Shared Representations for Personalized Federated Learning](http://proceedings.mlr.press/v139/collins21a/collins21a.pdf)," in Proc. International Conference on Machine Learning (ICML), 2021.
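
As a rough illustration of this split, the PyTorch sketch below partitions a toy model's weights into a shared representation and a local head; the model and layer names are hypothetical and not taken from Plato's implementation.

```python
import torch.nn as nn

# Toy model: the first Linear layer plays the role of the shared
# representation ("global" layers), the last one the personalized head.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(32 * 32 * 3, 256),  # shared representation
    nn.ReLU(),
    nn.Linear(256, 10),           # personalized local head
)
head_prefixes = ("3.",)  # parameter names belonging to the head in this toy model

# Only the representation is retrieved and uploaded for aggregation.
representation = {
    name: weight
    for name, weight in model.state_dict().items()
    if not name.startswith(head_prefixes)
}
print(sorted(representation))  # ['1.bias', '1.weight']
```
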
@@ -16,8 +16,8 @@ uv run fedrep.py -c ../configs/fedrep_CIFAR10_resnet18.yml
 FedBABU only updates the global layers of the model during FL training. The local layers are frozen at the beginning of each local training epoch.
 
 ```bash
-cd examples/personalized_fl/fedbabu
-uv run fedbabu.py -c ../configs/fedbabu_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run fedbabu/fedbabu.py -c configs/fedbabu_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Oh et al., "[FedBABU: Towards Enhanced Representation for Federated Image Classification](https://openreview.net/forum?id=HuaYQfggn5u)," in Proc. International Conference on Learning Representations (ICLR), 2022.
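
The freezing step can be pictured with a small PyTorch sketch; the toy model, and the choice of which layer counts as the local head, are assumptions for illustration only.

```python
import torch.nn as nn

# Toy model; the last Linear layer stands in for the local "head" layers.
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# Freeze the local layers so that local training only updates the global ones.
for name, param in model.named_parameters():
    if name.startswith("2."):  # head parameters in this toy model
        param.requires_grad = False

print([name for name, p in model.named_parameters() if p.requires_grad])
# ['0.weight', '0.bias']
```
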
@@ -29,8 +29,8 @@ uv run fedbabu.py -c ../configs/fedbabu_CIFAR10_resnet18.yml
 APFL jointly optimizes the global model and personalized models by interpolating between local and personalized models. Once the global model is received, each client will carry out a regular local update, and then conduct a personalized optimization to acquire a trained personalized model. The trained global model and the personalized model will subsequently be combined using the parameter "alpha," which can be dynamically updated.
 
 ```bash
-cd examples/personalized_fl/apfl
-uv run apfl.py -c ../configs/apfl_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run apfl/apfl.py -c configs/apfl_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Deng et al., "[Adaptive Personalized Federated Learning](https://arxiv.org/abs/2003.13461)," in Arxiv, 2021.
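
The combination controlled by "alpha" amounts to a parameter-wise interpolation, roughly as in the sketch below; the toy models are illustrative only, and the sketch does not show how alpha itself is updated.

```python
import copy
import torch
import torch.nn as nn

alpha = 0.5  # mixing coefficient; APFL can also adapt it during training

global_model = nn.Linear(10, 2)               # received from the server
personal_model = copy.deepcopy(global_model)  # optimized on local data

# Combine the two models parameter by parameter.
mixed_model = copy.deepcopy(personal_model)
with torch.no_grad():
    for mixed, personal, shared in zip(
        mixed_model.parameters(),
        personal_model.parameters(),
        global_model.parameters(),
    ):
        mixed.copy_(alpha * personal + (1.0 - alpha) * shared)
```
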
@@ -42,8 +42,8 @@ uv run apfl.py -c ../configs/apfl_CIFAR10_resnet18.yml
 FedPer learns a global representation and personalized heads, but makes simultaneous local updates for both sets of parameters, therefore makes the same number of local updates for the head and the representation on each local round.
 
 ```bash
-cd examples/personalized_fl/fedper
-uv run fedper.py -c ../configs/fedper_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run fedper/fedper.py -c configs/fedper_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Arivazhagan et al., "[Federated learning with personalization layers](https://arxiv.org/abs/1912.00818)," in Arxiv, 2019.
@@ -55,8 +55,8 @@ uv run fedper.py -c ../configs/fedper_CIFAR10_resnet18.yml
 With LG-FedAvg only the global layers of a model are sent to the server for aggregation, while each client keeps local layers to itself.
 
 ```bash
-cd examples/personalized_fl/lgfedavg
-uv run lgfedavg.py -c ../configs/lgfedavg_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run lgfedavg/lgfedavg.py -c configs/lgfedavg_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Liang et al., "[Think Locally, Act Globally: Federated Learning with Local and Global Representations](https://arxiv.org/abs/2001.01523)," in Proc. NeurIPS, 2019.
@@ -68,8 +68,8 @@ uv run lgfedavg.py -c ../configs/lgfedavg_CIFAR10_resnet18.yml
 Ditto jointly optimizes the global model and personalized models by learning local models that are encouraged to be close together by global regularization. In this example, once the global model is received, each client will carry out a regular local update and then optimizes the personalized model.
 
 ```bash
-cd examples/personalized_fl/ditto
-uv run ditto.py -c ../configs/ditto_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run ditto/ditto.py -c configs/ditto_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Li et al., "[Ditto: Fair and robust federated learning through personalization](https://proceedings.mlr.press/v139/li21h.html)," in Proc ICML, 2021.
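
The global regularization can be pictured as a proximal penalty added to the personalized objective, roughly as sketched below; `proximal_penalty`, the weight `lam`, and the toy models are illustrative names only, not Plato's API.

```python
import torch
import torch.nn as nn

def proximal_penalty(personal_model, global_model, lam=0.1):
    """(lam / 2) * ||v - w_global||^2, keeping the global weights fixed."""
    penalty = 0.0
    for v, w in zip(personal_model.parameters(), global_model.parameters()):
        penalty = penalty + torch.sum((v - w.detach()) ** 2)
    return 0.5 * lam * penalty

personal_model = nn.Linear(4, 2)
global_model = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# The personalized update minimizes the task loss plus the proximal penalty.
loss = nn.CrossEntropyLoss()(personal_model(x), y) + proximal_penalty(
    personal_model, global_model
)
loss.backward()
```
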
@@ -81,8 +81,8 @@ uv run ditto.py -c ../configs/ditto_CIFAR10_resnet18.yml
 Per-FedAvg uses the Model-Agnostic Meta-Learning (MAML) framework to perform local training during the regular training rounds. It performs two forward and backward passes with fixed learning rates in each iteration.
 
 ```bash
-cd examples/personalized_fl/perfedavg
-uv run perfedavg.py -c ../configs/perfedavg_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run perfedavg/perfedavg.py -c configs/perfedavg_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Fallah et al., "[Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach](https://proceedings.neurips.cc/paper/2020/hash/24389bfe4fe2eba8bf9aa9203a44cdad-Abstract.html)," in Proc NeurIPS, 2020.
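
A rough, first-order sketch of the two-pass update with fixed step sizes (toy data and a toy model, not the example's actual trainer):

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
alpha, beta = 0.01, 0.001  # fixed inner and outer learning rates

support_x, support_y = torch.randn(8, 4), torch.randint(0, 2, (8,))
query_x, query_y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# First pass: a temporary gradient step on the support batch.
adapted = copy.deepcopy(model)
adapted_params = list(adapted.parameters())
inner_loss = criterion(adapted(support_x), support_y)
inner_grads = torch.autograd.grad(inner_loss, adapted_params)
with torch.no_grad():
    for p, g in zip(adapted_params, inner_grads):
        p -= alpha * g

# Second pass: evaluate the adapted weights on the query batch and apply
# that gradient to the original model (first-order approximation of MAML).
outer_loss = criterion(adapted(query_x), query_y)
outer_grads = torch.autograd.grad(outer_loss, adapted_params)
with torch.no_grad():
    for p, g in zip(model.parameters(), outer_grads):
        p -= beta * g
```
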
@@ -94,8 +94,8 @@ uv run perfedavg.py -c ../configs/perfedavg_CIFAR10_resnet18.yml
 Hermes utilizes structured pruning to improve both communication efficiency and inference efficiency of federated learning. It prunes channels with the lowest magnitudes in each local model and adjusts the pruning amount based on each local model's test accuracy and its previous pruning amount. When the server aggregates pruned updates, it only averages parameters that were not pruned on all clients.
 
 ```bash
-cd examples/personalized_fl/hermes
-uv run hermes.py -c ../configs/hermes_CIFAR10_resnet18.yml
+cd examples/personalized_fl
+uv run hermes/hermes.py -c configs/hermes_CIFAR10_resnet18.yml
 ```
 
 **Reference:** Li et al., "[Hermes: An Efficient Federated Learning Framework for Heterogeneous Mobile Clients](https://sites.duke.edu/angli/files/2021/10/2021_Mobicom_Hermes_v1.pdf)," in Proc. 27th Annual International Conference on Mobile Computing and Networking (MobiCom), 2021.
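
The aggregation rule, averaging each position only over the clients that kept it, can be sketched with toy tensors as follows; the masks, values, and flat layout are illustrative only.

```python
import torch

# Three clients, each with a binary mask (1 = kept, 0 = pruned) over the
# same parameter tensor. Values at pruned positions are zero.
updates = [torch.tensor([1.0, 2.0, 0.0]),
           torch.tensor([3.0, 0.0, 0.0]),
           torch.tensor([5.0, 6.0, 0.0])]
masks = [torch.tensor([1.0, 1.0, 0.0]),
         torch.tensor([1.0, 0.0, 0.0]),
         torch.tensor([1.0, 1.0, 0.0])]

# Average each position over the clients that kept it; positions pruned on
# every client are left untouched (zero here).
kept_counts = torch.stack(masks).sum(dim=0)
summed = torch.stack(updates).sum(dim=0)
aggregated = torch.where(kept_counts > 0, summed / kept_counts.clamp(min=1), summed)
print(aggregated)  # tensor([3., 4., 0.])
```
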

docs/docs/references/trainers.md

Lines changed: 54 additions & 84 deletions
@@ -96,6 +96,7 @@ Plato's trainer system uses a **composition-based architecture** built on the **
 | **LRSchedulerStrategy** | LR scheduling | Custom schedules, warmup |
 | **ModelUpdateStrategy** | State management | Control variates, personalization (SCAFFOLD, Ditto) |
 | **DataLoaderStrategy** | Data loading | Custom sampling, augmentation |
+| **TestingStrategy** | Model evaluation | Custom model evaluation and testing |
 
 ---
 
@@ -193,85 +194,10 @@ class ComposableTrainer(base.Trainer):
 | `data_loader_strategy` | `DataLoaderStrategy` | `DefaultDataLoaderStrategy()` | Strategy for data loading |
 | `testing_strategy` | `TestingStrategy` | `DefaultTestingStrategy()` | Strategy for model evaluation |
 
-#### Key Methods
+#### Methods
 
-!!! note "`train(trainset, sampler, **kwargs) -> float`"
+Here is a list of methods in `ComposableTrainer` that can be called.
 
-    Train the model on the given dataset and sampler.
-
-    **Parameters:**
-
-    - `trainset`: Training dataset
-    - `sampler`: Data sampler for this client
-    - `**kwargs`: Additional arguments
-
-    **Returns:**
-
-    - Training time in seconds
-
-    **Example:**
-
-    ```python
-    training_time = trainer.train(trainset, sampler)
-    ```
-
-!!! note "`test(testset, sampler=None, **kwargs) -> float`"
-    Test the model on the given dataset.
-
-    **Parameters:**
-
-    - `testset`: Test dataset
-    - `sampler`: Optional data sampler
-    - `**kwargs`: Additional arguments
-
-    **Returns:**
-
-    - Test accuracy (0.0 to 1.0)
-
-    **Example:**
-
-    ```python
-    accuracy = trainer.test(testset)
-    print(f"Accuracy: {accuracy * 100:.2f}%")
-    ```
-
-!!! note "`train_model(config, trainset, sampler, **kwargs)`"
-    Main training loop implementation. Called internally by `train()`.
-
-    **Parameters:**
-
-    - `config`: Configuration dictionary
-    - `trainset`: Training dataset
-    - `sampler`: Data sampler
-    - `**kwargs`: Additional arguments
-
-!!! note "`save_model(filename=None, location=None)`"
-    Save model weights and training history.
-
-    **Parameters:**
-
-    - `filename`: Optional custom filename
-    - `location`: Optional custom directory
-
-    **Example:**
-
-    ```python
-    trainer.save_model("my_model.pth")
-    ```
-
-!!! note "`load_model(filename=None, location=None)`"
-    Load model weights and training history.
-
-    **Parameters:**
-
-    - `filename`: Optional custom filename
-    - `location`: Optional custom directory
-
-    **Example:**
-
-    ```python
-    trainer.load_model("my_model.pth")
-    ```
 
 #### Attributes
 
@@ -662,9 +588,9 @@ class ModelUpdateStrategy(Strategy):
 
 #### When to Implement
 
-- Control variates (SCAFFOLD)
-- Dynamic regularization state (FedDyn)
-- Personalization (FedPer, FedRep, Ditto)
+- Control variates (e.g., SCAFFOLD)
+- Dynamic regularization state (e.g., FedDyn)
+- Personalization (e.g., FedPer, FedRep, Ditto)
 - Layer freezing/unfreezing
 - Custom state management
 
@@ -955,6 +881,22 @@ trainer = ComposableTrainer(
 )
 ```
 
+### Testing Strategies
+
+**Location**: `plato.trainers.strategies.testing`
+
+| Strategy | Description | Parameters |
+|----------|-------------|------------|
+| `DefaultTestingStrategy` | Standard testing | Uses config settings |
+
+**Example:**
+```python
+from plato.trainers.strategies import DefaultTestingStrategy
+
+trainer = ComposableTrainer(
+    testing_strategy=DefaultTestingStrategy()
+)
+```
 ---
 
 ## Algorithm-Specific Strategies
@@ -1212,7 +1154,7 @@ algorithm:
 
 ---
 
-### FedMos
+### FedMoS
 
 **Location**: `plato.trainers.strategies.algorithms.fedmos_strategy`
 
@@ -2241,7 +2183,7 @@ Here is a list of all the methods available in the `RunHistory` class:
 
 When using the strategy pattern is no longer feasible, it is also possible to customize the training or testing procedure using subclassing, and overriding hook methods. To customize the training loop using subclassing, subclass the `basic.Trainer` class in `plato.trainers`, and override the following hook methods:
 
-!!! example "train_model()"
+!!! note "`train_model()`"
     **`def train_model(self, config, trainset, sampler, **kwargs):`**
 
     Override this method to provide a custom training loop.
@@ -2252,8 +2194,8 @@ When using the strategy pattern is no longer feasible, it is also possible to cu
 
     **Example:** A complete example can be found in the Hugging Face trainer, located at `plato/trainers/huggingface.py`.
 
-!!! example "test_model()"
-    **`test_model(self, config, testset, sampler=None, **kwargs):`**
+!!! note "`test_model()`"
+    **`def test_model(self, config, testset, sampler=None, **kwargs):`**
 
     Override this method to provide a custom testing loop.
 
@@ -2262,6 +2204,34 @@ When using the strategy pattern is no longer feasible, it is also possible to cu
 
     **Example:** A complete example can be found in `plato/trainers/huggingface.py`.
 
+!!! note "`save_model(filename=None, location=None)`"
+    Save model weights and training history.
+
+    **Parameters:**
+
+    - `filename`: Optional custom filename
+    - `location`: Optional custom directory
+
+    **Example:**
+
+    ```python
+    trainer.save_model("my_model.pth")
+    ```
+
+!!! note "`load_model(filename=None, location=None)`"
+    Load model weights and training history.
+
+    **Parameters:**
+
+    - `filename`: Optional custom filename
+    - `location`: Optional custom directory
+
+    **Example:**
+
+    ```python
+    trainer.load_model("my_model.pth")
+    ```
+
 ---
 
 ## Import Guide
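
As a minimal sketch of this subclassing approach, using the hook signatures quoted above (the class name is hypothetical and the method bodies are placeholders):

```python
from plato.trainers import basic

class CustomTrainer(basic.Trainer):
    """Sketch of a trainer that overrides the hook methods directly."""

    def train_model(self, config, trainset, sampler, **kwargs):
        # Custom training loop: build a data loader from `trainset` and
        # `sampler`, then run the configured number of epochs here.
        ...

    def test_model(self, config, testset, sampler=None, **kwargs):
        # Custom testing loop: evaluate the model on `testset` and
        # return the resulting accuracy.
        ...
```
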
