|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import time |
| 7 | +from functools import partial |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch.distributed.fsdp import MixedPrecisionPolicy |
| 11 | +from torch.distributed.tensor.placement_types import Partial, Replicate, Shard |
| 12 | +from torch.testing._internal.distributed.fake_pg import FakeStore |
| 13 | + |
| 14 | +from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph |
| 15 | +from autoparallel._passes.split_di_dw_graph import split_di_dw_graph |
| 16 | +from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch |
| 17 | +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs |
| 18 | +from autoparallel.api import AutoParallel |
| 19 | +from autoparallel.auto_bucketing import ( |
| 20 | + aten_autobucketing_config, |
| 21 | + aten_autobucketing_reordering_pass, |
| 22 | + simple_fsdp_autobucketing_reordering_pass, |
| 23 | + simplefsdp_autobucketing_config, |
| 24 | +) |
| 25 | + |
| 26 | +world_size = 64 |
| 27 | + |
| 28 | +fake_store = FakeStore() |
| 29 | +torch.distributed.init_process_group( |
| 30 | + "fake", store=fake_store, rank=0, world_size=world_size |
| 31 | +) |
| 32 | + |
| 33 | +use_1d_mesh = False |
| 34 | + |
| 35 | +if use_1d_mesh: |
| 36 | + mesh = torch.distributed.device_mesh.init_device_mesh( |
| 37 | + "cuda", (world_size,), mesh_dim_names=("dp",) |
| 38 | + ) |
| 39 | +else: |
| 40 | + mesh = torch.distributed.device_mesh.init_device_mesh( |
| 41 | + "cuda", |
| 42 | + (world_size // 8, 8), |
| 43 | + mesh_dim_names=( |
| 44 | + "dp", |
| 45 | + "tp", |
| 46 | + ), |
| 47 | + ) |
| 48 | + |
| 49 | +batch_size = 2 * mesh.shape[0] |
| 50 | +seqlen = 2048 * 4 |
| 51 | +vocab_size = 128256 |
| 52 | +use_vocab_parallel = not use_1d_mesh |
| 53 | +device = torch.device("cuda") |
| 54 | + |
| 55 | +model_type = "8b" |
| 56 | +enable_asynctp = False |
| 57 | + |
| 58 | + |
| 59 | +def model_fn(): |
| 60 | + if model_type == "8b": |
| 61 | + model_args = TransformerModelArgs( |
| 62 | + dim=4096, |
| 63 | + n_layers=1, |
| 64 | + n_heads=32, |
| 65 | + n_kv_heads=8, |
| 66 | + ffn_dim_multiplier=1.3, |
| 67 | + multiple_of=1024, |
| 68 | + rope_theta=500000, |
| 69 | + vocab_size=vocab_size, |
| 70 | + max_seq_len=seqlen, |
| 71 | + ) |
| 72 | + elif model_type == "70b": |
| 73 | + model_args = TransformerModelArgs( |
| 74 | + dim=8192, |
| 75 | + n_layers=80, |
| 76 | + n_heads=64, |
| 77 | + n_kv_heads=8, |
| 78 | + ffn_dim_multiplier=1.3, |
| 79 | + multiple_of=4096, |
| 80 | + rope_theta=500000, |
| 81 | + vocab_size=vocab_size, |
| 82 | + max_seq_len=seqlen, |
| 83 | + ) |
| 84 | + else: |
| 85 | + raise ValueError(f"{model_type} not available") |
| 86 | + m = Transformer(model_args) |
| 87 | + # I turned of the tok_embeddings layer because: |
| 88 | + # - I want the input to my joint graph to require grad, |
| 89 | + # so I can generate separate dI and dW gradients to carve out subgraphs for |
| 90 | + # - The input to tok_embeddings is an integral tensor which can't require grad. |
| 91 | + # Another option would be to manually apply autoparallel to a single transformer block layer. |
| 92 | + m.tok_embeddings = None |
| 93 | + return m |
| 94 | + |
| 95 | + |
| 96 | +def input_fn(): |
| 97 | + # 8192 for 70B |
| 98 | + x = torch.randn(batch_size, seqlen, 4096, device=device, requires_grad=True) |
| 99 | + return x |
| 100 | + |
| 101 | + |
| 102 | +autobucketing_level = "aten" |
| 103 | + |
| 104 | +if autobucketing_level == "aten": |
| 105 | + # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960 |
| 106 | + torch._inductor.config.reorder_for_peak_memory = False |
| 107 | + torch._inductor.config.reorder_for_compute_comm_overlap = False |
| 108 | + aten_autobucketing_reordering_pass = partial( |
| 109 | + aten_autobucketing_reordering_pass, |
| 110 | + configs=aten_autobucketing_config, |
| 111 | + ) |
| 112 | + torch._inductor.config.post_grad_custom_post_pass = ( |
| 113 | + aten_autobucketing_reordering_pass |
| 114 | + ) |
| 115 | +elif autobucketing_level == "inductor": |
| 116 | + torch._inductor.config.allow_buffer_reuse = False |
| 117 | + torch._inductor.config.reorder_for_peak_memory = False |
| 118 | + torch._inductor.config.reorder_for_compute_comm_overlap = True |
| 119 | + simplefsdp_autobucketing_config.calibrate_number = 5 |
| 120 | + simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl" |
| 121 | + simple_fsdp_autobucketing_reordering_pass = partial( |
| 122 | + simple_fsdp_autobucketing_reordering_pass, |
| 123 | + configs=simplefsdp_autobucketing_config, |
| 124 | + ) |
| 125 | + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ |
| 126 | + simple_fsdp_autobucketing_reordering_pass |
| 127 | + ] |
| 128 | +else: |
| 129 | + raise ValueError(f"Unknown autobucketing_level {autobucketing_level}") |
| 130 | + |
| 131 | + |
| 132 | +# parallelize the model |
| 133 | +with torch.device("meta"): |
| 134 | + model = model_fn() |
| 135 | + |
| 136 | +mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) |
| 137 | + |
| 138 | + |
| 139 | +def group_mm_nodes_with_its_gradients(nodes): |
| 140 | + fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta] |
| 141 | + bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta] |
| 142 | + assert len(fwd_nodes) * 2 == len(bwd_nodes) |
| 143 | + res = {} |
| 144 | + for fwd_node in fwd_nodes: |
| 145 | + o = [] |
| 146 | + for bwd_node in bwd_nodes: |
| 147 | + if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]: |
| 148 | + o.append(bwd_node) |
| 149 | + assert len(o) == 2 |
| 150 | + res[fwd_node] = o |
| 151 | + return res |
| 152 | + |
| 153 | + |
| 154 | +def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False): |
| 155 | + # out = x @ w - S(0)R, RS(1) -> S(0)S(1) |
| 156 | + # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0) |
| 157 | + # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P |
| 158 | + |
| 159 | + add_node_constraint = autop.sharding_optimizer.add_node_constraint |
| 160 | + fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes) |
| 161 | + fwd_nodes = list(fwd_bwd_groups.keys()) |
| 162 | + dim1 = 0 if feat_dim == 1 else 1 |
| 163 | + dim2 = 1 if feat_dim == 1 else 0 |
| 164 | + # assume there are 7 mm nodes per transformer block |
| 165 | + # skip last mm as it's the final projection layer |
| 166 | + assert ( |
| 167 | + len(fwd_nodes) - 1 |
| 168 | + ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}" |
| 169 | + for block in range(0, len(fwd_nodes) - 1, 7): |
| 170 | + fwd_nodes_block = fwd_nodes[block : block + 7] |
| 171 | + # force the first 3 mm nodes to be S(0)S(1) |
| 172 | + the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6] |
| 173 | + for n in the_nodes: |
| 174 | + add_node_constraint(n, (Shard(0), Shard(feat_dim))) |
| 175 | + add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate())) |
| 176 | + add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1))) |
| 177 | + |
| 178 | + if bwd_constraint: |
| 179 | + bwd_nodes = fwd_bwd_groups[n] |
| 180 | + # first is g_w, second is g_x |
| 181 | + add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1))) |
| 182 | + add_node_constraint(bwd_nodes[1], (Shard(0), Partial())) |
| 183 | + |
| 184 | + # add reduction to finish TP, yielding S(0)P |
| 185 | + the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7] |
| 186 | + for n in the_nodes: |
| 187 | + add_node_constraint(n, (Shard(0), Partial())) |
| 188 | + add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim))) |
| 189 | + add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0))) |
| 190 | + |
| 191 | + if bwd_constraint: |
| 192 | + bwd_nodes = fwd_bwd_groups[n] |
| 193 | + # first is g_w, second is g_x |
| 194 | + add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2))) |
| 195 | + add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim))) |
| 196 | + |
| 197 | + |
| 198 | +def add_tp_constraints(autop): |
| 199 | + mm_nodes = autop.gm.graph.find_nodes( |
| 200 | + op="call_function", target=torch.ops.aten.mm.default |
| 201 | + ) |
| 202 | + einsum_nodes = autop.gm.graph.find_nodes( |
| 203 | + op="call_function", target=torch.ops.aten.einsum.default |
| 204 | + ) |
| 205 | + assert (len(mm_nodes) > 0) ^ ( |
| 206 | + len(einsum_nodes) > 0 |
| 207 | + ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}" |
| 208 | + feat_dim = 1 if len(mm_nodes) > 0 else 2 |
| 209 | + tgt_nodes = mm_nodes + einsum_nodes |
| 210 | + force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True) |
| 211 | + |
| 212 | + if einsum_nodes: |
| 213 | + # add sequence parallelism if we have einsum nodes |
| 214 | + autop.sharding_optimizer.add_node_constraint( |
| 215 | + list(tgt_nodes[3].users)[0], (Shard(0), Shard(1)) |
| 216 | + ) |
| 217 | + autop.sharding_optimizer.add_node_constraint( |
| 218 | + list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1)) |
| 219 | + ) |
| 220 | + |
| 221 | + |
| 222 | +# parallelize the model |
| 223 | +with AutoParallel( |
| 224 | + model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True |
| 225 | +) as autop: |
| 226 | + autop.add_parameter_memory_constraint(low=None, high=None) |
| 227 | + |
| 228 | + x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) |
| 229 | + out_sharding = x_sharding |
| 230 | + if use_vocab_parallel: |
| 231 | + # add vocab parallel constraint |
| 232 | + assert mesh.ndim == 2, "Only 2d mesh supported here" |
| 233 | + out_sharding = (Shard(0), Shard(2)) |
| 234 | + |
| 235 | + autop.add_input_constraints([x_sharding]) |
| 236 | + autop.add_output_constraints([out_sharding]) |
| 237 | + |
| 238 | + enable_manual_constraint = False |
| 239 | + if enable_manual_constraint and not use_1d_mesh: |
| 240 | + add_tp_constraints(autop) |
| 241 | + |
| 242 | + if enable_asynctp: |
| 243 | + from torch.distributed._symmetric_memory import enable_symm_mem_for_group |
| 244 | + |
| 245 | + enable_symm_mem_for_group(mesh["dp"].get_group().group_name) |
| 246 | + enable_symm_mem_for_group(mesh["tp"].get_group().group_name) |
| 247 | + torch._inductor.config._micro_pipeline_tp = False |
| 248 | + from autoparallel.asynctp import micro_pipeline_tp_pass |
| 249 | + |
| 250 | + existing_post_grad_custom_post_pass = ( |
| 251 | + torch._inductor.config.post_grad_custom_post_pass |
| 252 | + ) |
| 253 | + |
| 254 | + def _pass(graph): |
| 255 | + if existing_post_grad_custom_post_pass is not None: |
| 256 | + existing_post_grad_custom_post_pass(graph) |
| 257 | + micro_pipeline_tp_pass(graph) |
| 258 | + |
| 259 | + torch._inductor.config.post_grad_custom_post_pass = _pass |
| 260 | + |
| 261 | + t = time.time() |
| 262 | + sharding_placement = autop.optimize_placement(verbose=True) |
| 263 | + print(f"Took {time.time() - t:.2f} s") |
| 264 | + parallel_mod = autop.apply_placement(sharding_placement) |
| 265 | + multiplex_graph = True |
| 266 | + if multiplex_graph: |
| 267 | + f_gm = autop.fw_module |
| 268 | + b_gm = autop.bw_module |
| 269 | + print("Original Fwd Graph:") |
| 270 | + print(f_gm.graph) |
| 271 | + print("Original Bwd Graph:") |
| 272 | + print(b_gm.graph) |
| 273 | + prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm) |
| 274 | + print("Main Fwd Graph:") |
| 275 | + print(main_f_gm.graph) |
| 276 | + print("Prefetch Fwd Graph:") |
| 277 | + print(prefetch_f_gm.graph) |
| 278 | + prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm) |
| 279 | + print("Main Bwd Graph:") |
| 280 | + print(main_b_gm.graph) |
| 281 | + print("Prefetch Bwd Graph:") |
| 282 | + print(prefetch_b_gm.graph) |
| 283 | + multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm) |
| 284 | + print("Multiplexed Graph:") |
| 285 | + print(multiplexed_gm.graph) |
| 286 | + # in the AOTAutograd bw graph, the first K outputs correspond |
| 287 | + # to gradients for any params/buffers in the user model |
| 288 | + num_weight_gradients = autop.joint_with_descriptors._aot_state.aot_config.num_params_buffers |
| 289 | + main_b_gm_di, main_b_gm_dw = split_di_dw_graph(main_b_gm, num_weight_gradients=num_weight_gradients) |
| 290 | + print("Gradient w.r.t inputs Graph:") |
| 291 | + print(main_b_gm_di) |
| 292 | + print("Gradient w.r.t weights Graph:") |
| 293 | + print(main_b_gm_dw) |
| 294 | + # Test just to show that input/output calling conventions are correct |
| 295 | + # the pipeline runtime will need to do this |
| 296 | + num_total_gradients = len(main_b_gm.graph.find_nodes(op='output')[0].args[0]) |
| 297 | + num_input_gradients = num_total_gradients - num_weight_gradients |
| 298 | + bw_args = [x.meta['val'] for x in main_b_gm.graph.find_nodes(op="placeholder")] |
| 299 | + input_grads_and_activations = main_b_gm_di(*bw_args) |
| 300 | + input_grads, activations = input_grads_and_activations[:num_input_gradients], input_grads_and_activations[num_input_gradients:] |
| 301 | + weight_grads = main_b_gm_dw(*activations) |
| 302 | + print(f'num input grads: {len(input_grads)}') |
| 303 | + print(f'num weight grads: {len(weight_grads)}') |
| 304 | + |
| 305 | +# run weight init on our sharded DTensor params |
| 306 | +parallel_mod.to_empty(device="cuda") |
| 307 | +parallel_mod.init_weights() |
| 308 | + |
| 309 | +# now let's run it |
| 310 | +x = ( |
| 311 | + torch.randint( |
| 312 | + 0, |
| 313 | + vocab_size, |
| 314 | + (batch_size // mesh.shape[0], seqlen), |
| 315 | + device=torch.device("cuda"), |
| 316 | + ), |
| 317 | +) |
| 318 | +#out = parallel_mod(*x) |
| 319 | +#out.backward(torch.randn_like(out)) |
| 320 | +print("All good!") |
0 commit comments