diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index 1252c5bb04d5..fbb15ddc6717 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -54,7 +54,12 @@ You can run your Flower project in both _simulation_ and _deployment_ mode witho When you run the project, the strategy will create a directory structure in the form of `outputs/date/time` and store two `JSON` files: `config.json` containing the `run-config` that the `ServerApp` receives; and `results.json` containing the results (accuracies, losses) that are generated at the strategy. -By default, the metrics: {`centralized_accuracy`, `centralized_loss`, `federated_evaluate_accuracy`, `federated_evaluate_loss`} will be logged to Weights & Biases (they are also stored to the `results.json` previously mentioned). Upon executing `flwr run` you'll see a URL linking to your Weight&Biases dashboard where you can see the metrics. +By default, the strategy logs train and evaluation metrics to Weights & Biases (these are also stored in `results.json`). This includes: + +- ClientApp-side metrics such as `train_loss`, `eval_loss`, `eval_acc`, `eval_acc_top3`, and `eval_acc_class_0` ... `eval_acc_class_9` +- ServerApp-side metrics such as `loss`, `accuracy`, `accuracy_top3`, and `accuracy_class_0` ... `accuracy_class_9` + +Upon executing `flwr run` you'll see a URL linking to your Weights & Biases dashboard where you can inspect these metrics. ![](_static/wandb_plots.png) @@ -69,10 +74,18 @@ The `results.json` would look along the lines of: }, "evaluate_metrics_clientapp": { "eval_loss": 2.303316633324679, - "eval_acc": 0.11882631674867869 + "eval_acc": 0.11882631674867869, + "eval_acc_top3": 0.35142118863049095, + "eval_acc_class_0": 0.0, + "eval_acc_class_1": 0.021739130434782608, + "eval_acc_class_2": 0.0 }, "evaluate_metrics_serverapp": { "accuracy": 0.1, + "accuracy_top3": 0.3012, + "accuracy_class_0": 0.0, + "accuracy_class_1": 0.0, + "accuracy_class_2": 0.998, "loss": 2.3280856304656203 } }, @@ -83,10 +96,18 @@ The `results.json` would look along the lines of: }, "evaluate_metrics_clientapp": { "eval_loss": 2.1314486836388467, - "eval_acc": 0.19826539462272333 + "eval_acc": 0.19826539462272333, + "eval_acc_top3": 0.4854715191260137, + "eval_acc_class_0": 0.06976744186046512, + "eval_acc_class_1": 0.10204081632653061, + "eval_acc_class_2": 0.0 }, "evaluate_metrics_serverapp": { "accuracy": 0.1, + "accuracy_top3": 0.3006, + "accuracy_class_0": 0.0, + "accuracy_class_1": 0.0, + "accuracy_class_2": 1.0, "loss": 2.2980988307501944 } }, diff --git a/examples/advanced-pytorch/pytorch_example/client_app.py b/examples/advanced-pytorch/pytorch_example/client_app.py index ba751470d116..180008c01462 100644 --- a/examples/advanced-pytorch/pytorch_example/client_app.py +++ b/examples/advanced-pytorch/pytorch_example/client_app.py @@ -3,8 +3,7 @@ import torch from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict from flwr.clientapp import ClientApp - -from pytorch_example.task import Net, load_data +from pytorch_example.task import NUM_CLASSES, Net, load_data from pytorch_example.task import test as test_fn from pytorch_example.task import train as train_fn @@ -88,7 +87,7 @@ def evaluate(msg: Message, context: Context): _, valloader = load_data(partition_id, num_partitions) # Call the evaluation function - eval_loss, eval_acc = test_fn( + eval_loss, eval_results = test_fn( model, valloader, device, @@ -97,9 +96,14 @@ def evaluate(msg: Message, context: Context): # Construct and return reply Message metrics = { "eval_loss": eval_loss, - "eval_acc": eval_acc, + "eval_acc": eval_results["accuracy"], + "eval_acc_top3": eval_results["top3_accuracy"], "num-examples": len(valloader.dataset), } + for class_idx in range(NUM_CLASSES): + metrics[f"eval_acc_class_{class_idx}"] = eval_results[ + f"class_accuracy_{class_idx}" + ] metric_record = MetricRecord(metrics) content = RecordDict({"metrics": metric_record}) return Message(content=content, reply_to=msg) diff --git a/examples/advanced-pytorch/pytorch_example/server_app.py b/examples/advanced-pytorch/pytorch_example/server_app.py index 1e52a659340d..4184d32fa3c4 100644 --- a/examples/advanced-pytorch/pytorch_example/server_app.py +++ b/examples/advanced-pytorch/pytorch_example/server_app.py @@ -1,13 +1,19 @@ """pytorch-example: A Flower / PyTorch app.""" import torch -from datasets import load_dataset from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord from flwr.serverapp import Grid, ServerApp +from pytorch_example.strategy import CustomFedAvg +from pytorch_example.task import ( + NUM_CLASSES, + Net, + apply_eval_transforms, + create_run_dir, + test, +) from torch.utils.data import DataLoader -from pytorch_example.strategy import CustomFedAvg -from pytorch_example.task import Net, apply_eval_transforms, create_run_dir, test +from datasets import load_dataset # Create ServerApp app = ServerApp() @@ -70,7 +76,16 @@ def global_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord: net = Net() net.load_state_dict(arrays.to_torch_state_dict()) net.to(device) - loss, accuracy = test(net, testloader, device=device) - return MetricRecord({"accuracy": accuracy, "loss": loss}) + loss, test_results = test(net, testloader, device=device) + metrics = { + "accuracy": test_results["accuracy"], + "accuracy_top3": test_results["top3_accuracy"], + "loss": loss, + } + for class_idx in range(NUM_CLASSES): + metrics[f"accuracy_class_{class_idx}"] = test_results[ + f"class_accuracy_{class_idx}" + ] + return MetricRecord(metrics) return global_evaluate diff --git a/examples/advanced-pytorch/pytorch_example/task.py b/examples/advanced-pytorch/pytorch_example/task.py index b4ddb0a645f7..ff11dedb7565 100644 --- a/examples/advanced-pytorch/pytorch_example/task.py +++ b/examples/advanced-pytorch/pytorch_example/task.py @@ -29,6 +29,7 @@ Normalize(*FM_NORMALIZATION), ] ) +NUM_CLASSES = 10 class Net(nn.Module): @@ -74,18 +75,46 @@ def train(net, trainloader, epochs, lr, device): def test(net, testloader, device): """Validate the model on the test set.""" net.to(device) + was_training = net.training + net.eval() criterion = torch.nn.CrossEntropyLoss() - correct, loss = 0, 0.0 + top1_correct, top3_correct, loss = 0, 0, 0.0 + class_correct = torch.zeros(NUM_CLASSES, device=device, dtype=torch.long) + class_total = torch.zeros(NUM_CLASSES, device=device, dtype=torch.long) with torch.no_grad(): for batch in testloader: images = batch["image"].to(device) labels = batch["label"].to(device) outputs = net(images) - loss += criterion(outputs, labels).item() - correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() - accuracy = correct / len(testloader.dataset) - loss = loss / len(testloader) - return loss, accuracy + loss += criterion(outputs, labels).item() * labels.size(0) + top1_preds = torch.max(outputs, 1)[1] + top1_correct += (top1_preds == labels).sum().item() + + top3_preds = torch.topk(outputs, k=3, dim=1).indices + top3_correct += (top3_preds == labels.unsqueeze(1)).any(dim=1).sum().item() + + class_total += torch.bincount(labels, minlength=NUM_CLASSES) + class_correct += torch.bincount( + labels[top1_preds == labels], minlength=NUM_CLASSES + ) + + top1_accuracy = top1_correct / len(testloader.dataset) + top3_accuracy = top3_correct / len(testloader.dataset) + loss = loss / len(testloader.dataset) + net.train(was_training) + class_correct_cpu = class_correct.cpu().tolist() + class_total_cpu = class_total.cpu().tolist() + class_accuracies = { + f"class_accuracy_{class_idx}": ( + class_correct_cpu[class_idx] / class_total_cpu[class_idx] + if class_total_cpu[class_idx] > 0 + else 0.0 + ) + for class_idx in range(NUM_CLASSES) + } + metrics = {"accuracy": top1_accuracy, "top3_accuracy": top3_accuracy} + metrics.update(class_accuracies) + return loss, metrics def apply_train_transforms(batch):