Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions examples/advanced-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ You can run your Flower project in both _simulation_ and _deployment_ mode witho

When you run the project, the strategy will create a directory structure in the form of `outputs/date/time` and store two `JSON` files: `config.json` containing the `run-config` that the `ServerApp` receives; and `results.json` containing the results (accuracies, losses) that are generated at the strategy.

By default, the metrics: {`centralized_accuracy`, `centralized_loss`, `federated_evaluate_accuracy`, `federated_evaluate_loss`} will be logged to Weights & Biases (they are also stored to the `results.json` previously mentioned). Upon executing `flwr run` you'll see a URL linking to your Weight&Biases dashboard where you can see the metrics.
By default, the strategy logs training and evaluation metrics to Weights & Biases (these are also stored in `results.json`). These include:

- ClientApp-side metrics such as `train_loss`, `eval_loss`, `eval_acc`, `eval_acc_top3`, and `eval_acc_class_0` ... `eval_acc_class_9`
- ServerApp-side metrics such as `loss`, `accuracy`, `accuracy_top3`, and `accuracy_class_0` ... `accuracy_class_9`

Upon executing `flwr run`, you'll see a URL linking to your Weights & Biases dashboard where you can inspect these metrics.

![](_static/wandb_plots.png)

Expand All @@ -69,10 +74,18 @@ The `results.json` would look along the lines of:
},
"evaluate_metrics_clientapp": {
"eval_loss": 2.303316633324679,
"eval_acc": 0.11882631674867869
"eval_acc": 0.11882631674867869,
"eval_acc_top3": 0.35142118863049095,
"eval_acc_class_0": 0.0,
"eval_acc_class_1": 0.021739130434782608,
"eval_acc_class_2": 0.0
},
"evaluate_metrics_serverapp": {
"accuracy": 0.1,
"accuracy_top3": 0.3012,
"accuracy_class_0": 0.0,
"accuracy_class_1": 0.0,
"accuracy_class_2": 0.998,
"loss": 2.3280856304656203
}
},
Expand All @@ -83,10 +96,18 @@ The `results.json` would look along the lines of:
},
"evaluate_metrics_clientapp": {
"eval_loss": 2.1314486836388467,
"eval_acc": 0.19826539462272333
"eval_acc": 0.19826539462272333,
"eval_acc_top3": 0.4854715191260137,
"eval_acc_class_0": 0.06976744186046512,
"eval_acc_class_1": 0.10204081632653061,
"eval_acc_class_2": 0.0
},
"evaluate_metrics_serverapp": {
"accuracy": 0.1,
"accuracy_top3": 0.3006,
"accuracy_class_0": 0.0,
"accuracy_class_1": 0.0,
"accuracy_class_2": 1.0,
"loss": 2.2980988307501944
}
},
Expand Down
12 changes: 8 additions & 4 deletions examples/advanced-pytorch/pytorch_example/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp

from pytorch_example.task import Net, load_data
from pytorch_example.task import NUM_CLASSES, Net, load_data
from pytorch_example.task import test as test_fn
from pytorch_example.task import train as train_fn

Expand Down Expand Up @@ -88,7 +87,7 @@ def evaluate(msg: Message, context: Context):
_, valloader = load_data(partition_id, num_partitions)

# Call the evaluation function
eval_loss, eval_acc = test_fn(
eval_loss, eval_results = test_fn(
model,
valloader,
device,
Expand All @@ -97,9 +96,14 @@ def evaluate(msg: Message, context: Context):
# Construct and return reply Message
metrics = {
"eval_loss": eval_loss,
"eval_acc": eval_acc,
"eval_acc": eval_results["accuracy"],
"eval_acc_top3": eval_results["top3_accuracy"],
"num-examples": len(valloader.dataset),
}
for class_idx in range(NUM_CLASSES):
metrics[f"eval_acc_class_{class_idx}"] = eval_results[
f"class_accuracy_{class_idx}"
]
metric_record = MetricRecord(metrics)
content = RecordDict({"metrics": metric_record})
return Message(content=content, reply_to=msg)
25 changes: 20 additions & 5 deletions examples/advanced-pytorch/pytorch_example/server_app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""pytorch-example: A Flower / PyTorch app."""

import torch
from datasets import load_dataset
from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
from flwr.serverapp import Grid, ServerApp
from pytorch_example.strategy import CustomFedAvg
from pytorch_example.task import (
NUM_CLASSES,
Net,
apply_eval_transforms,
create_run_dir,
test,
)
from torch.utils.data import DataLoader

from pytorch_example.strategy import CustomFedAvg
from pytorch_example.task import Net, apply_eval_transforms, create_run_dir, test
from datasets import load_dataset

# Create ServerApp
app = ServerApp()
Expand Down Expand Up @@ -70,7 +76,16 @@ def global_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
net = Net()
net.load_state_dict(arrays.to_torch_state_dict())
net.to(device)
loss, accuracy = test(net, testloader, device=device)
return MetricRecord({"accuracy": accuracy, "loss": loss})
loss, test_results = test(net, testloader, device=device)
metrics = {
"accuracy": test_results["accuracy"],
"accuracy_top3": test_results["top3_accuracy"],
"loss": loss,
}
for class_idx in range(NUM_CLASSES):
metrics[f"accuracy_class_{class_idx}"] = test_results[
f"class_accuracy_{class_idx}"
]
return MetricRecord(metrics)

return global_evaluate
41 changes: 35 additions & 6 deletions examples/advanced-pytorch/pytorch_example/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Normalize(*FM_NORMALIZATION),
]
)
NUM_CLASSES = 10


class Net(nn.Module):
Expand Down Expand Up @@ -74,18 +75,46 @@ def train(net, trainloader, epochs, lr, device):
def test(net, testloader, device):
"""Validate the model on the test set."""
net.to(device)
was_training = net.training
net.eval()
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
top1_correct, top3_correct, loss = 0, 0, 0.0
class_correct = torch.zeros(NUM_CLASSES, device=device, dtype=torch.long)
class_total = torch.zeros(NUM_CLASSES, device=device, dtype=torch.long)
with torch.no_grad():
for batch in testloader:
images = batch["image"].to(device)
labels = batch["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
loss = loss / len(testloader)
return loss, accuracy
loss += criterion(outputs, labels).item() * labels.size(0)
top1_preds = torch.max(outputs, 1)[1]
top1_correct += (top1_preds == labels).sum().item()

top3_preds = torch.topk(outputs, k=3, dim=1).indices
top3_correct += (top3_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

class_total += torch.bincount(labels, minlength=NUM_CLASSES)
class_correct += torch.bincount(
labels[top1_preds == labels], minlength=NUM_CLASSES
)

top1_accuracy = top1_correct / len(testloader.dataset)
top3_accuracy = top3_correct / len(testloader.dataset)
loss = loss / len(testloader.dataset)
net.train(was_training)
class_correct_cpu = class_correct.cpu().tolist()
class_total_cpu = class_total.cpu().tolist()
class_accuracies = {
f"class_accuracy_{class_idx}": (
class_correct_cpu[class_idx] / class_total_cpu[class_idx]
if class_total_cpu[class_idx] > 0
else 0.0
)
for class_idx in range(NUM_CLASSES)
}
metrics = {"accuracy": top1_accuracy, "top3_accuracy": top3_accuracy}
metrics.update(class_accuracies)
return loss, metrics


def apply_train_transforms(batch):
Expand Down