-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathgen_data.py
More file actions
79 lines (63 loc) · 2.63 KB
/
gen_data.py
File metadata and controls
79 lines (63 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
'''
Date: 2024-03-11 18:00:06
LastEditors: jackeymiao
LastEditTime: 2024-04-07 09:48:19
FilePath: /ADNet/gen_data.py
'''
import argparse
import os
import numpy as np
import torch
from utils.data_utils import check_extension, save_dataset
def generate_MultiPM_data(n_samples, n_users, pk, radius):
    """Generate ``n_samples`` random MultiPM instances.

    Args:
        n_samples: number of instances to generate.
        n_users: number of users in every instance.
        pk: sequence of ints. It is stored as an int32 tensor, and its
            length sets how many rows ``c`` has.
        radius: stored unchanged in every instance.

    Returns:
        A list of dicts, one per instance, with keys:
            loc:      (n_users, 2) float tensor, coordinates drawn from U(0, 1).
            radius:   the ``radius`` argument.
            pk:       ``pk`` as a 1-D int32 tensor.
            c:        (max(len(pk), 1), n_users) float tensor. Row 0 is drawn
                      from U(2, 4). Each later row is the previous row scaled
                      element-wise by a fresh factor drawn from U(0.8, 0.88).
            combined: (n_users, 3) tensor, ``loc`` with ``c[0]`` appended as
                      a third column.
    """
    instances = []
    n_levels = len(pk)
    for _ in range(n_samples):
        # Keep the draw order fixed (base row, then one decay row per extra
        # level, then locations) so a given seed reproduces the same data.
        base = torch.FloatTensor(1, n_users).uniform_(2, 4)
        levels = [base[0]]
        for _ in range(n_levels - 1):
            decay = torch.FloatTensor(1, n_users).uniform_(0.8, 0.88)
            levels.append(levels[-1] * decay[0])
        c = torch.stack(levels)

        loc = torch.FloatTensor(n_users, 2).uniform_(0, 1)
        # Append the first row of c as a per-user feature column next to loc.
        combined = torch.cat((loc, c[0].unsqueeze(1)), dim=1)

        instances.append({
            'loc': loc,
            'radius': radius,
            'pk': torch.IntTensor(pk),
            'c': c,
            'combined': combined,
        })
    return instances
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--filename", help="Filename of the dataset to create (ignores datadir)")
parser.add_argument("--data_dir", default='data', help="Create datasets in data_dir/problem (default 'data')")
parser.add_argument("--problem", type=str, default='MultiPM',
help="Problem, 'MultiPM' to generate")
parser.add_argument("--dataset_size", type=int, default=1000, help="Size of the dataset")
parser.add_argument('--graph_size', type=int, default=2000,
help="number of users")
parser.add_argument("-f", action='store_true', help="Set true to overwrite")
parser.add_argument('--seed', type=int, default=1234, help="Random seed")
parser.add_argument('--pk', nargs='+', type=int, default=[2, 4, 7, 9, 10, 13, 15])
opts = parser.parse_args()
assert opts.filename is None or (len(opts.problems) == 1 and len(opts.graph_sizes) == 1), \
"Can only specify filename when generating a single dataset"
torch.manual_seed(1234)
problem = opts.problem
n_users = opts.graph_size
datadir = os.path.join(opts.data_dir, problem)
os.makedirs(datadir, exist_ok=True)
if problem == 'MultiPM':
if n_users == 20:
radius = 0.32
elif n_users == 50:
radius = 0.24
elif n_users ==100:
radius = 0.16
else:
radius = 0.16
pk = opts.pk
filename = os.path.join(datadir, f"{problem}_{n_users}_{radius}_{pk}.pkl")
dataset = generate_MultiPM_data(opts.dataset_size, n_users, pk, radius)
else:
assert False, "Unknown problem: {}".format(problem)
save_dataset(dataset, filename)