Commit d2ba202

[Autobucketing] Add simplefsdp's autobucketing pass to autoparallel (#141)
* Update (base update) [ghstack-poisoned]
* Update (base update) [ghstack-poisoned]
* Update (base update) [ghstack-poisoned]
* Update (base update) [ghstack-poisoned]
* Update (#134) [ghstack-poisoned]
1 parent 90332a3 commit d2ba202

File tree

7 files changed: +2776 -0 lines changed


autoparallel/auto_bucketing.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder


class simplefsdp_autobucketing_config:
    """
    Config for simplefsdp's autobucketing pass, which by default gives good performance.
    To make the results tunable, we expose the following parameters:
    - relax_ratio: relax comp time to include more comm in one bucket;
        with this config, comp is updated as comp * (1 + relax_ratio)
    - peak_memory_offset: relax peak_memory to include more comm in one bucket;
        with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
    - load_cache: set to True to load cache from save_estimation_path
    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
    - calibrate_number: number of samples to calibrate during comm estimation
    """

    relax_ratio = 0
    peak_memory_offset = 0
    load_cache = False
    save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
    enable_bucket_ir = True
    enable_reorder_ir = True
    calibrate_number = 40


def simple_fsdp_autobucketing_reordering_pass(
    snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
    configs: "simplefsdp_autobucketing_config",
) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
    scheduler = snodes[0].scheduler
    bucketable_nodes = bucket_utils.get_bucketable_ir_nodes(
        snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
    )

    assert (
        not torch._inductor.config.allow_buffer_reuse
    ), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False"

    if configs.enable_bucket_ir:
        all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan(
            scheduler,
            snodes,
            scheduler.name_to_buf,
            scheduler.name_to_fused_node,
            bucketable_nodes,
            configs,
        )

        snodes = bucket_func.bucket_fsdp_all_gather_with_plan(
            scheduler,
            snodes,
            scheduler.name_to_buf,
            scheduler.name_to_fused_node,
            all_gather_plan,
            bucketable_nodes,
        )
        if len(reduce_scatter_plan) > 0:
            snodes = bucket_func.bucket_fsdp_reduce_scatter_with_plan(
                scheduler,
                snodes,
                scheduler.name_to_buf,
                scheduler.name_to_fused_node,
                reduce_scatter_plan,
                bucketable_nodes,
            )

    if configs.enable_reorder_ir:
        print("Reorder scheduler nodes with autobucketing algorithm")
        node_length = len(snodes)
        snodes = reorder.reorder_all_gather(
            snodes, bucketable_nodes, all_gather_before_last_wait=False
        )
        assert node_length == len(
            snodes
        ), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
        snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes)
        assert node_length == len(
            snodes
        ), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"

    return snodes
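A minimal usage sketch (not part of this commit) showing how the pass and config above might be wired into Inductor's compute/comm overlap reordering. The registration via torch._inductor.config.reorder_for_compute_comm_overlap_passes and the functools.partial binding of configs are assumptions about how a caller would hook the pass up, and the tuned values are illustrative only; only the import path and the allow_buffer_reuse requirement come from the diff itself.

# Hypothetical usage sketch; registration details and tuned values are assumptions.
import functools

import torch

from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

# Tune the class-level knobs documented in the config docstring.
configs = simplefsdp_autobucketing_config
configs.relax_ratio = 0.1        # comp budget becomes comp * (1 + 0.1)
configs.peak_memory_offset = 1   # illustrative; units follow the estimator's memory accounting
configs.calibrate_number = 20    # fewer samples for faster comm estimation

# The pass asserts that Inductor buffer reuse is disabled.
torch._inductor.config.allow_buffer_reuse = False

# Assumed registration point: comm-overlap reordering passes take and return a
# list of scheduler nodes, so bind `configs` with functools.partial.
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    functools.partial(simple_fsdp_autobucketing_reordering_pass, configs=configs),
]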

0 commit comments
