-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_dry_optimization.py
More file actions
195 lines (166 loc) · 6.86 KB
/
run_dry_optimization.py
File metadata and controls
195 lines (166 loc) · 6.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import argparse
import pandas as pd
import networkx as nx
import copy
from collections import OrderedDict
from comp_graph.comp_graph import write_text_graph
from comp_graph.pytorch_graph import get_nx_node_attr, get_nx_edge_attr
from device_topo import DeviceTopo
from optimizer import Optimizer
from simulator import Simulator
from experiments.generate_testcase import load_comp_graph
def _build_parser():
    """Create the command-line parser shared by every stage of this script.

    All three options are plain strings:

    * ``--arch``: architecture name (default ``res152_64``).
    * ``--input_root``: directory of the text graph fed to the optimizer.
    * ``--output_root``: directory receiving the optimized text graph.
    """
    cli = argparse.ArgumentParser()
    option_defaults = (
        ("--arch", "res152_64"),
        ("--input_root", "misc/text_graph_in/res152_64"),
        ("--output_root", "misc/text_graph_out/res152_64"),
    )
    for flag, default in option_defaults:
        cli.add_argument(flag, type=str, default=default)
    return cli


# Module-level parser; the ``__main__`` entry point calls ``parser.parse_args()``.
parser = _build_parser()
def write_text_graph_in(args):
    """Dump the computation graph of ``args.arch`` as text files.

    The graph is obtained from ``load_comp_graph(args.arch)`` (the ``.g``
    attribute of the first returned element) and handed to
    ``write_text_graph`` with ``clear_partition=True`` and
    ``write_rawid_dict=True``, rooted at ``args.input_root``.

    Args:
        args: Namespace providing ``arch`` and ``input_root``.

    Raises:
        AssertionError: If ``args.arch`` is not a supported architecture.
    """
    supported_archs = ("res152_64", "bertl_32")
    # Raise explicitly rather than using ``assert`` so the check still runs
    # under ``python -O``. AssertionError is kept so callers observe the same
    # exception type as before.
    if args.arch not in supported_archs:
        raise AssertionError(
            "Unsupported arch %r; expected one of %s"
            % (args.arch, list(supported_archs)))
    comp_graph = load_comp_graph(args.arch)[0]
    write_text_graph(
        comp_graph.g, root=args.input_root, arch=args.arch,
        clear_partition=True, write_rawid_dict=True)
def convert_text_graph_to_comp_graph(arch, input_root) -> nx.MultiDiGraph:
    """Rebuild a computation graph from its CSV text representation.

    Reads ``<input_root>/<arch>_nodes.csv`` (indexed by ``node_id``) and
    ``<input_root>/<arch>_edges.csv`` (indexed by ``link``). Only the
    attributes the optimizer needs are restored.

    The nodes file must provide the columns ``kind``, ``params_mb``,
    ``fore_mem``, ``back_mem``, ``fore_time``, ``back_time``, ``stage_id``,
    ``is_input`` and ``is_output``; the edges file must provide ``data_mb``
    and ``tensor_type``. Each ``link`` is four space-separated integers.

    Args:
        arch: Architecture name; used as file prefix, as the prefix of the
            placeholder ``scopeName`` and as the graph's ``graph_type``.
        input_root: Directory containing the two CSV files.

    Returns:
        A ``networkx.MultiDiGraph`` whose ``graph`` dict carries ``inputs``,
        ``outputs`` and ``graph_type``.
    """
    nodes_file = os.path.join(input_root, "%s_%s" % (arch, "nodes.csv"))
    edges_file = os.path.join(input_root, "%s_%s" % (arch, "edges.csv"))

    # ---- Nodes -----------------------------------------------------------
    nodes_df = pd.read_csv(nodes_file, index_col="node_id")
    nodes_dict = {}
    for node_id in nodes_df.index:
        record = {}
        for column in nodes_df.columns:
            record[column] = nodes_df.loc[node_id, column]
        # Placeholders for attributes absent from the text format; they
        # replace any same-named CSV column.
        record["hasMultipleOutputs"] = False
        record["scopeName"] = arch + "/Unknown"
        nodes_dict[node_id] = record

    # ---- Edges -----------------------------------------------------------
    edges_df = pd.read_csv(edges_file, index_col="link")
    edges_dict = {}
    for link_text in edges_df.index:
        link_key = tuple(int(token) for token in link_text.split(" "))
        record = {
            column: edges_df.loc[link_text, column]
            for column in edges_df.columns
        }
        # Same placeholder policy as for nodes.
        record["tensor_size"] = "Unknown"
        record["tensor_req_grad"] = "Unknown"
        edges_dict[link_key] = record

    # A positive second link field marks the source node as multi-output
    # (presumably the field is the source's output index; confirm against
    # the code that writes the edges file).
    for src, src_slot, _dst, _dst_slot in edges_dict:
        if src_slot > 0:
            nodes_dict[src]["hasMultipleOutputs"] = True

    # ---- Assemble the graph ----------------------------------------------
    g = nx.MultiDiGraph()
    inputs, outputs = [], []
    profiled_attrs = (
        "params_mb", "fore_mem", "back_mem", "fore_time", "back_time")
    for this_uid, raw_node in nodes_dict.items():
        node_kwargs = get_nx_node_attr(
            uid=this_uid,
            node_type="OP",
            is_task=True,
            can_profile=True,
            kind=raw_node["kind"],
            hasMultipleOutputs=raw_node["hasMultipleOutputs"],
            scopeName=raw_node["scopeName"],
        )
        g.add_node(this_uid, **node_kwargs)
        node_data = g.nodes[this_uid]
        for attr in profiled_attrs:
            node_data[attr] = raw_node[attr]
        # The CSV column is called "stage_id"; the graph stores it as "part_id".
        node_data["part_id"] = raw_node["stage_id"]
        if raw_node["is_input"]:
            inputs.append(this_uid)
        if raw_node["is_output"]:
            outputs.append(this_uid)

    for link, raw_edge in edges_dict.items():
        src, dst = link[0], link[2]
        edge_kwargs = get_nx_edge_attr(
            data_mb=raw_edge["data_mb"],
            tensor_size=raw_edge["tensor_size"],
            tensor_type=raw_edge["tensor_type"],
            tensor_req_grad=raw_edge["tensor_req_grad"],
            link=link,
        )
        g.add_edge(src, dst, **edge_kwargs)

    g.graph["inputs"] = inputs
    g.graph["outputs"] = outputs
    g.graph["graph_type"] = arch
    return g
def run_dry_optimization(args, s=None, r=None):
    """Run the ensemble optimization on the text graph and write the results.

    The graph is rebuilt from ``args.input_root`` with
    ``convert_text_graph_to_comp_graph``, optimized with
    ``Optimizer.optimization(opt_method="ensemble", run_mapping=True)``, and
    then written to ``args.output_root`` together with
    ``<arch>_mapping_conf.json``.

    Args:
        args: Namespace providing ``arch``, ``input_root`` and ``output_root``.
        s: Optional integer selecting ``./config/dp_cfg_s<s>_r<r>.yaml``.
            Must be given together with ``r``.
        r: Optional integer; see ``s``. When both are ``None`` the default
            ``./config/dp_cfg.yaml`` is used.

    Raises:
        NotImplementedError: If exactly one of ``s`` and ``r`` is given.
    """
    if s is None and r is None:
        dp_cfg_path = "./config/dp_cfg.yaml"
    elif s is not None and r is not None:
        dp_cfg_path = "./config/dp_cfg_s%d_r%d.yaml" % (s, r)
    else:
        raise NotImplementedError("Error! Please check s and r!")

    g = convert_text_graph_to_comp_graph(args.arch, args.input_root)
    device_topo = DeviceTopo()
    sim = Simulator(g, device_topo)
    opt = Optimizer(g, device_topo)

    # Per-method parameters for the "ensemble" optimizer.
    ensemble_params = {
        "parmesan": {
            "k": 48, "comm_coef": 1e-3, "num_chunks": 4,
            "max_time_coeff": 1.5, "new_flow": True,
            "batch_size": 128 // 4, "dp_cfg_path": dp_cfg_path,
            "profile_pos": "", "use_refinement": True,
            "visualization": False, "verbose": True,
        },
        "one_cut": {
            "visualization": False, "comm_coef": 1e-3,
            "num_chunks": 4, "dp_cfg_path": dp_cfg_path,
        },
    }
    sim_params = {"concurrent_copy_comp": True, "debug": True}
    mapping_params = {
        "visualization": False,
        "parallel_degree": 160, "parallel_same_t": 40,
        "verbose": False,
    }
    place_sol = opt.optimization(
        opt_method="ensemble", sim=sim,
        opt_method_to_params=ensemble_params, sim_params=sim_params,
        mapping_params=mapping_params,
        dump_name=None, run_mapping=True,
    )
    # NOTE: to evaluate the placement without mapping, one can call
    # sim.simulate_sync_pipe(place_sol, **sim_params) here.

    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(args.output_root, exist_ok=True)
    # g is written after optimization; presumably the optimizer annotates it
    # in place (confirm in Optimizer).
    write_text_graph(g, root=args.output_root, arch=args.arch)
    opt.generate_hybrid_conf(
        os.path.join(args.output_root, "%s_mapping_conf.json" % args.arch))
    print("Finish dry run optimization")
def read_text_graph_out(args, generate_file=True):
    """Apply the optimized stage assignment to the original computation graph.

    Reads ``<output_root>/<arch>_nodes.csv`` and copies each row's
    ``stage_id`` into the ``part_id`` attribute of the matching node of the
    graph returned by ``load_comp_graph(args.arch)``. Integer node ids listed
    in ``<input_root>/<arch>_rawid_dict.txt`` are translated to their string
    ids first; ids not listed there are used unchanged.

    Args:
        args: Namespace providing ``arch``, ``input_root`` and ``output_root``.
        generate_file: When true, also call
            ``generate_module_by_device_placement`` for
            ``<output_root>/runtime/stage`` and copy
            ``<arch>_mapping_conf.json`` to ``runtime/hybrid_conf.json``.

    Returns:
        The computation-graph object (first element returned by
        ``load_comp_graph``) with updated ``part_id`` attributes.

    Raises:
        OSError: If the rawid dictionary cannot be read or the mapping
            configuration cannot be copied.
    """
    # Local import so this function does not depend on an edit to the
    # file-level import block.
    import shutil

    comp_graph = load_comp_graph(args.arch)[0]
    nodes_file = os.path.join(args.output_root, "%s_nodes.csv" % (args.arch))
    nodes_df = pd.read_csv(nodes_file, index_col="node_id")

    # Each line of the dictionary file is "<int id> <str id>".
    rawid_dict_file = os.path.join(args.input_root, "%s_rawid_dict.txt" % (args.arch))
    int_id2str_id = {}
    with open(rawid_dict_file, 'r') as f:
        for line in f:
            # Strip only the newline: slicing off the last character would
            # truncate the final id when the file lacks a trailing newline.
            line = line.rstrip("\n")
            if not line.strip():
                continue  # tolerate blank lines
            fields = line.split(" ")
            int_id2str_id[int(fields[0])] = fields[1]

    for node_id, node in nodes_df.iterrows():
        graph_node_id = int_id2str_id.get(node_id, node_id)
        comp_graph.g.nodes[graph_node_id]["part_id"] = node["stage_id"]

    if generate_file:
        fileroot = os.path.join(args.output_root, "runtime")
        comp_graph.generate_module_by_device_placement(
            file_name="%s/stage" % fileroot, check_behavior=True)
        hybrid_conf_file = os.path.join(fileroot, "hybrid_conf.json")
        mapping_conf_file = os.path.join(args.output_root, "%s_mapping_conf.json" % args.arch)
        # shutil.copy replaces a shell "cp": no shell-string building, works
        # with paths containing spaces, and a failed copy raises instead of
        # being silently ignored.
        shutil.copy(mapping_conf_file, hybrid_conf_file)
    return comp_graph
if __name__ == "__main__":
    # Full dry-run pipeline, executed in order with the same parsed arguments:
    #   1. dump the input text graph,
    #   2. optimize it and write the output text graph,
    #   3. read the result back into the original computation graph.
    args = parser.parse_args()
    for stage in (write_text_graph_in, run_dry_optimization, read_text_graph_out):
        stage(args)