-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcifar_script.py
More file actions
48 lines (44 loc) · 1.48 KB
/
cifar_script.py
File metadata and controls
48 lines (44 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
from FL_models import *
import constants
data = np.load("./cifar_resnet56.npz")
train_imgs = data['arr_0']
train_labels = data['arr_1']
test_imgs = data['arr_3']
test_labels = data['arr_4']
train_imgs = torch.tensor(train_imgs, dtype=torch.float)
test_imgs = torch.tensor(test_imgs, dtype=torch.float)
train_labels = torch.tensor(train_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)
print(f"Data loaded, training images: {train_imgs.size(0)}, testing images: {test_imgs.size(0)}")
print("Initializing...")
num_iter = 201
Ph = 50
hidden = 1024
malicious_factor = 0.3
for att_mode in constants.targeted_att:
for exp in [constants.p_trust, constants.np_cos, constants.p_cos]:
cgd = FL_torch(
num_iter=num_iter,
train_imgs=train_imgs,
train_labels=train_labels,
test_imgs=test_imgs,
test_labels=test_labels,
Ph=Ph,
malicious_factor=malicious_factor,
defender=exp['defender'],
n_H=hidden,
dataset="CIFAR",
start_attack=exp['start'],
attack_mode=att_mode,
k_nearest=35,
p_kernel=2,
local_epoch=2
)
cgd.shuffle_data()
cgd.federated_init()
cgd.grad_reset()
cgd.data_distribution()
print(f"Start {att_mode} attack to {exp['defender']}...")
cgd.eq_train()
print(f"{att_mode} attack to {exp['defender']} complete")