From 26416503f4f3b02f1912240a87fcbab65977bef5 Mon Sep 17 00:00:00 2001 From: AlexEisie <1987460907@qq.com> Date: Sat, 15 Mar 2025 01:24:01 +0800 Subject: [PATCH] Update standalone.py --- fedlab_benchmarks/fedavg_v1.2.0/standalone.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/fedlab_benchmarks/fedavg_v1.2.0/standalone.py b/fedlab_benchmarks/fedavg_v1.2.0/standalone.py index af7468b7..835faac0 100644 --- a/fedlab_benchmarks/fedavg_v1.2.0/standalone.py +++ b/fedlab_benchmarks/fedavg_v1.2.0/standalone.py @@ -18,6 +18,8 @@ from fedlab.utils.functional import evaluate from fedlab.utils.functional import get_best_gpu, load_dict +from fedlab.utils.dataset import MNISTPartitioner + # configuration parser = argparse.ArgumentParser(description="Standalone training example") parser.add_argument("--total_client", type=int, default=100) @@ -81,7 +83,11 @@ def forward(self, x): aggregator = Aggregators.fedavg_aggregate total_client_num = args.total_client # client总数 -data_indices = load_dict("mnist_partition.pkl") +#data_indices = load_dict("mnist_partition.pkl") +data_indices = MNISTPartitioner(trainset.targets, + args.total_client, + partition="iid", + seed=2025) # fedlab setup trainer = SubsetSerialTrainer(model=model, @@ -91,7 +97,8 @@ def forward(self, x): "batch_size": args.batch_size, "epochs": args.epochs, "lr": args.lr - }) + }, + cuda=args.cuda) # train procedure to_select = [i for i in range(total_client_num)]