|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder |
| 9 | + |
| 10 | + |
| 11 | +class simplefsdp_autobucketing_config: |
| 12 | + """ |
| 13 | + Config for simplefsdp's autobucketing pass, which by default would give good performance. |
| 14 | + To make the results tunable, we expose the following parameters: |
| 15 | + - relax_ratio: relax comp time to include more comm in one bucket |
| 16 | + with this config, comp is updated as comp * (1 + relax_ratio) |
| 17 | + - peak_memory_offset: relax peak_memory to include more comm in one bucket |
| 18 | + with this config, peak_memory is updated as (peak_memory + peak_memory_offset) |
| 19 | + - load_cache: set to True to load cache from save_estimation_path |
| 20 | + - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter |
| 21 | + - enable_reorder_ir: set to True to reorder all_gather/reduce_satter |
| 22 | + - calibrate_number: number of samples to calibrate during comm estimation |
| 23 | + """ |
| 24 | + |
| 25 | + relax_ratio = 0 |
| 26 | + peak_memory_offset = 0 |
| 27 | + load_cache = False |
| 28 | + save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl" |
| 29 | + enable_bucket_ir = True |
| 30 | + enable_reorder_ir = True |
| 31 | + calibrate_number = 40 |
| 32 | + |
| 33 | + |
| 34 | +def simple_fsdp_autobucketing_reordering_pass( |
| 35 | + snodes: list["torch._inductor.scheduler.BaseSchedulerNode"], |
| 36 | + configs: "simplefsdp_autobucketing_config", |
| 37 | +) -> list["torch._inductor.scheduler.BaseSchedulerNode"]: |
| 38 | + scheduler = snodes[0].scheduler |
| 39 | + bucketable_nodes = bucket_utils.get_bucketable_ir_nodes( |
| 40 | + snodes, scheduler.name_to_fused_node, scheduler.name_to_buf |
| 41 | + ) |
| 42 | + |
| 43 | + assert ( |
| 44 | + not torch._inductor.config.allow_buffer_reuse |
| 45 | + ), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False" |
| 46 | + |
| 47 | + if configs.enable_bucket_ir: |
| 48 | + all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan( |
| 49 | + scheduler, |
| 50 | + snodes, |
| 51 | + scheduler.name_to_buf, |
| 52 | + scheduler.name_to_fused_node, |
| 53 | + bucketable_nodes, |
| 54 | + configs, |
| 55 | + ) |
| 56 | + |
| 57 | + snodes = bucket_func.bucket_fsdp_all_gather_with_plan( |
| 58 | + scheduler, |
| 59 | + snodes, |
| 60 | + scheduler.name_to_buf, |
| 61 | + scheduler.name_to_fused_node, |
| 62 | + all_gather_plan, |
| 63 | + bucketable_nodes, |
| 64 | + ) |
| 65 | + if len(reduce_scatter_plan) > 0: |
| 66 | + snodes = bucket_func.bucket_fsdp_reduce_scatter_with_plan( |
| 67 | + scheduler, |
| 68 | + snodes, |
| 69 | + scheduler.name_to_buf, |
| 70 | + scheduler.name_to_fused_node, |
| 71 | + reduce_scatter_plan, |
| 72 | + bucketable_nodes, |
| 73 | + ) |
| 74 | + |
| 75 | + if configs.enable_reorder_ir: |
| 76 | + print("Reorder scheduler nodes with autobucketing algroithm") |
| 77 | + node_length = len(snodes) |
| 78 | + snodes = reorder.reorder_all_gather( |
| 79 | + snodes, bucketable_nodes, all_gather_before_last_wait=False |
| 80 | + ) |
| 81 | + assert node_length == len( |
| 82 | + snodes |
| 83 | + ), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}" |
| 84 | + snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes) |
| 85 | + assert node_length == len( |
| 86 | + snodes |
| 87 | + ), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}" |
| 88 | + |
| 89 | + return snodes |
0 commit comments