
Commit 8175160

bdhirsh authored and xmfan committed
Add split dI/dW graph pass example
1 parent b4a76e9 commit 8175160

File tree

2 files changed: +370 -0 lines
autoparallel/_passes/split_di_dw_graph.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx
from functorch.compile import default_partition

# We are running the default partitioner on the bw graph, which requires the AC (activation checkpointing) tags to be removed.
# At this stage we have already finished running AC anyway, since we already have a bw graph.
def remove_recompute_tags(bw_gm):
    for n in bw_gm.graph.nodes:
        if 'recompute' in n.meta:
            del n.meta['recompute']

# We are using the default partitioner to split our backward into dI and dW subgraphs.
# We want to generate the dI subgraph *first*, because:
# - in pipelining we generally want to schedule dI compute before dW
# - the dI compute will potentially compute more activations that we need to plumb into the dW compute
# Today, the default partitioner requires that you split on the first K outputs of your combined graph.
# So here, we reorder the outputs of the backward so grad_inputs come first.
def reorder_output_grads(bw_gm, num_weight_gradients):
    outputs = bw_gm.graph.find_nodes(op='output')
    assert len(outputs) == 1
    output = outputs[0]
    assert isinstance(output.args[0], tuple)
    grad_weights, grad_inputs = output.args[0][:num_weight_gradients], output.args[0][num_weight_gradients:]
    new_out_tuple = grad_inputs + grad_weights
    with bw_gm.graph.inserting_after(output):
        # TODO: also set the new node's meta properly
        new_out = bw_gm.graph.output(new_out_tuple)
    output.replace_all_uses_with(new_out)
    bw_gm.graph.erase_node(output)
    return len(grad_inputs)

# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(bw_gm: fx.GraphModule, *, num_weight_gradients) -> tuple[fx.GraphModule, fx.GraphModule]:
    # we could consider doing this in a non-mutating way
    bw_gm = copy.deepcopy(bw_gm)
    remove_recompute_tags(bw_gm)
    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
    bw_gm.recompile()

    args = [x.meta['val'] for x in bw_gm.graph.find_nodes(op="placeholder")]

    bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients)
    return bw_inputs, bw_weights
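
A hypothetical usage sketch of the calling convention for the two GraphModules returned by split_di_dw_graph; bw_gm and num_weight_gradients are assumed to come from an AOTAutograd backward graph (as in examples/example_llama3_di_dw.py below), and the driver code itself is illustrative:

# bw_gm: backward fx.GraphModule whose first `num_weight_gradients` outputs
# are the parameter/buffer gradients (AOTAutograd's convention)
di_gm, dw_gm = split_di_dw_graph(bw_gm, num_weight_gradients=num_weight_gradients)

num_total_grads = len(bw_gm.graph.find_nodes(op="output")[0].args[0])
num_input_grads = num_total_grads - num_weight_gradients

# dI runs first: it takes the original backward inputs and returns the input
# gradients followed by the intermediates that the dW subgraph still needs
bw_args = [n.meta["val"] for n in bw_gm.graph.find_nodes(op="placeholder")]
outs = di_gm(*bw_args)
input_grads, saved = outs[:num_input_grads], outs[num_input_grads:]

# dW can run later (e.g. scheduled separately by a pipeline runtime)
weight_grads = dw_gm(*saved)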

examples/example_llama3_di_dw.py

Lines changed: 320 additions & 0 deletions
@@ -0,0 +1,320 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import time
from functools import partial

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch
from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
from autoparallel.api import AutoParallel
from autoparallel.auto_bucketing import (
    aten_autobucketing_config,
    aten_autobucketing_reordering_pass,
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

world_size = 64

fake_store = FakeStore()
torch.distributed.init_process_group(
    "fake", store=fake_store, rank=0, world_size=world_size
)

use_1d_mesh = False

if use_1d_mesh:
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda", (world_size,), mesh_dim_names=("dp",)
    )
else:
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda",
        (world_size // 8, 8),
        mesh_dim_names=(
            "dp",
            "tp",
        ),
    )

batch_size = 2 * mesh.shape[0]
seqlen = 2048 * 4
vocab_size = 128256
use_vocab_parallel = not use_1d_mesh
device = torch.device("cuda")

model_type = "8b"
enable_asynctp = False


def model_fn():
    if model_type == "8b":
        model_args = TransformerModelArgs(
            dim=4096,
            n_layers=1,
            n_heads=32,
            n_kv_heads=8,
            ffn_dim_multiplier=1.3,
            multiple_of=1024,
            rope_theta=500000,
            vocab_size=vocab_size,
            max_seq_len=seqlen,
        )
    elif model_type == "70b":
        model_args = TransformerModelArgs(
            dim=8192,
            n_layers=80,
            n_heads=64,
            n_kv_heads=8,
            ffn_dim_multiplier=1.3,
            multiple_of=4096,
            rope_theta=500000,
            vocab_size=vocab_size,
            max_seq_len=seqlen,
        )
    else:
        raise ValueError(f"{model_type} not available")
    m = Transformer(model_args)
    # I turned off the tok_embeddings layer because:
    # - I want the input to my joint graph to require grad,
    #   so I can generate separate dI and dW gradients to carve out subgraphs for
    # - The input to tok_embeddings is an integral tensor, which can't require grad.
    # Another option would be to manually apply autoparallel to a single transformer block layer.
    m.tok_embeddings = None
    return m


def input_fn():
    # 8192 for 70B
    x = torch.randn(batch_size, seqlen, 4096, device=device, requires_grad=True)
    return x


autobucketing_level = "aten"

if autobucketing_level == "aten":
    # this is from the stacked PR in https://github.com/pytorch/pytorch/pull/163960
    torch._inductor.config.reorder_for_peak_memory = False
    torch._inductor.config.reorder_for_compute_comm_overlap = False
    aten_autobucketing_reordering_pass = partial(
        aten_autobucketing_reordering_pass,
        configs=aten_autobucketing_config,
    )
    torch._inductor.config.post_grad_custom_post_pass = (
        aten_autobucketing_reordering_pass
    )
elif autobucketing_level == "inductor":
    torch._inductor.config.allow_buffer_reuse = False
    torch._inductor.config.reorder_for_peak_memory = False
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    simplefsdp_autobucketing_config.calibrate_number = 5
    simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
    simple_fsdp_autobucketing_reordering_pass = partial(
        simple_fsdp_autobucketing_reordering_pass,
        configs=simplefsdp_autobucketing_config,
    )
    torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
        simple_fsdp_autobucketing_reordering_pass
    ]
else:
    raise ValueError(f"Unknown autobucketing_level {autobucketing_level}")


# parallelize the model
with torch.device("meta"):
    model = model_fn()

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)


def group_mm_nodes_with_its_gradients(nodes):
    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
    assert len(fwd_nodes) * 2 == len(bwd_nodes)
    res = {}
    for fwd_node in fwd_nodes:
        o = []
        for bwd_node in bwd_nodes:
            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
                o.append(bwd_node)
        assert len(o) == 2
        res[fwd_node] = o
    return res


def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P
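    # Notation in the comments above: S(i) = Shard(i), R = Replicate,
    # P = Partial (a pending reduction), i.e. the DTensor placements imported
    # at the top; each placement pair follows the mesh dim order (dp, tp).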

    add_node_constraint = autop.sharding_optimizer.add_node_constraint
    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
    fwd_nodes = list(fwd_bwd_groups.keys())
    dim1 = 0 if feat_dim == 1 else 1
    dim2 = 1 if feat_dim == 1 else 0
    # assume there are 7 mm nodes per transformer block
    # skip last mm as it's the final projection layer
    assert (
        len(fwd_nodes) - 1
    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
    for block in range(0, len(fwd_nodes) - 1, 7):
        fwd_nodes_block = fwd_nodes[block : block + 7]
        # force the first 3 mm nodes to be S(0)S(1)
        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))

        # add reduction to finish TP, yielding S(0)P
        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Partial()))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))


def add_tp_constraints(autop):
    mm_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.mm.default
    )
    einsum_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.einsum.default
    )
    assert (len(mm_nodes) > 0) ^ (
        len(einsum_nodes) > 0
    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
    feat_dim = 1 if len(mm_nodes) > 0 else 2
    tgt_nodes = mm_nodes + einsum_nodes
    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)

    if einsum_nodes:
        # add sequence parallelism if we have einsum nodes
        autop.sharding_optimizer.add_node_constraint(
            list(tgt_nodes[3].users)[0], (Shard(0), Shard(1))
        )
        autop.sharding_optimizer.add_node_constraint(
            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
        )


# parallelize the model
with AutoParallel(
    model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
) as autop:
    autop.add_parameter_memory_constraint(low=None, high=None)

    x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
    out_sharding = x_sharding
    if use_vocab_parallel:
        # add vocab parallel constraint
        assert mesh.ndim == 2, "Only 2d mesh supported here"
        out_sharding = (Shard(0), Shard(2))

    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([out_sharding])

    enable_manual_constraint = False
    if enable_manual_constraint and not use_1d_mesh:
        add_tp_constraints(autop)

    if enable_asynctp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
        torch._inductor.config._micro_pipeline_tp = False
        from autoparallel.asynctp import micro_pipeline_tp_pass

        existing_post_grad_custom_post_pass = (
            torch._inductor.config.post_grad_custom_post_pass
        )

        def _pass(graph):
            if existing_post_grad_custom_post_pass is not None:
                existing_post_grad_custom_post_pass(graph)
            micro_pipeline_tp_pass(graph)

        torch._inductor.config.post_grad_custom_post_pass = _pass

    t = time.time()
    sharding_placement = autop.optimize_placement(verbose=True)
    print(f"Took {time.time() - t:.2f} s")
    parallel_mod = autop.apply_placement(sharding_placement)
    multiplex_graph = True
    if multiplex_graph:
        f_gm = autop.fw_module
        b_gm = autop.bw_module
        print("Original Fwd Graph:")
        print(f_gm.graph)
        print("Original Bwd Graph:")
        print(b_gm.graph)
        prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm)
        print("Main Fwd Graph:")
        print(main_f_gm.graph)
        print("Prefetch Fwd Graph:")
        print(prefetch_f_gm.graph)
        prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm)
        print("Main Bwd Graph:")
        print(main_b_gm.graph)
        print("Prefetch Bwd Graph:")
        print(prefetch_b_gm.graph)
        multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm)
        print("Multiplexed Graph:")
        print(multiplexed_gm.graph)
        # in the AOTAutograd bw graph, the first K outputs correspond
        # to gradients for any params/buffers in the user model
        num_weight_gradients = autop.joint_with_descriptors._aot_state.aot_config.num_params_buffers
        main_b_gm_di, main_b_gm_dw = split_di_dw_graph(main_b_gm, num_weight_gradients=num_weight_gradients)
        print("Gradient w.r.t. inputs Graph:")
        print(main_b_gm_di)
        print("Gradient w.r.t. weights Graph:")
        print(main_b_gm_dw)
        # Test just to show that the input/output calling conventions are correct;
        # the pipeline runtime will need to do this.
        num_total_gradients = len(main_b_gm.graph.find_nodes(op='output')[0].args[0])
        num_input_gradients = num_total_gradients - num_weight_gradients
        bw_args = [x.meta['val'] for x in main_b_gm.graph.find_nodes(op="placeholder")]
        input_grads_and_activations = main_b_gm_di(*bw_args)
        input_grads, activations = input_grads_and_activations[:num_input_gradients], input_grads_and_activations[num_input_gradients:]
        weight_grads = main_b_gm_dw(*activations)
        print(f'num input grads: {len(input_grads)}')
        print(f'num weight grads: {len(weight_grads)}')

# run weight init on our sharded DTensor params
parallel_mod.to_empty(device="cuda")
parallel_mod.init_weights()

# now let's run it
x = (
    torch.randint(
        0,
        vocab_size,
        (batch_size // mesh.shape[0], seqlen),
        device=torch.device("cuda"),
    ),
)
# out = parallel_mod(*x)
# out.backward(torch.randn_like(out))
print("All good!")
