diff --git a/fedlab_benchmarks/fedavg_v1.2.0/standalone.py b/fedlab_benchmarks/fedavg_v1.2.0/standalone.py index af7468b..835faac 100644 --- a/fedlab_benchmarks/fedavg_v1.2.0/standalone.py +++ b/fedlab_benchmarks/fedavg_v1.2.0/standalone.py @@ -18,6 +18,8 @@ from fedlab.utils.functional import evaluate from fedlab.utils.functional import get_best_gpu, load_dict +from fedlab.utils.dataset import MNISTPartitioner + # configuration parser = argparse.ArgumentParser(description="Standalone training example") parser.add_argument("--total_client", type=int, default=100) @@ -81,7 +83,11 @@ def forward(self, x): aggregator = Aggregators.fedavg_aggregate total_client_num = args.total_client # client总数 -data_indices = load_dict("mnist_partition.pkl") +#data_indices = load_dict("mnist_partition.pkl") +data_indices = MNISTPartitioner(trainset.targets, + args.total_client, + partition="iid", + seed=2025) # fedlab setup trainer = SubsetSerialTrainer(model=model, @@ -91,7 +97,8 @@ def forward(self, x): "batch_size": args.batch_size, "epochs": args.epochs, "lr": args.lr - }) + }, + cuda=args.cuda) # train procedure to_select = [i for i in range(total_client_num)]