|
5 | 5 |
|
6 | 6 | import time |
7 | 7 | from dataclasses import dataclass |
| 8 | +from functools import partial |
8 | 9 | from typing import ClassVar |
9 | 10 |
|
10 | 11 | import torch |
|
16 | 17 | from torch.testing._internal.distributed.fake_pg import FakeStore |
17 | 18 |
|
18 | 19 | from autoparallel.api import AutoParallel |
| 20 | +from autoparallel.auto_bucketing import ( |
| 21 | + simple_fsdp_autobucketing_reordering_pass, |
| 22 | + simplefsdp_autobucketing_config, |
| 23 | +) |
19 | 24 |
|
20 | 25 |
|
21 | 26 | def has_cuda_capability(major: int, minor: int) -> bool: |
@@ -558,14 +563,18 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None) |
558 | 563 |
|
559 | 564 | world_size = 256 |
560 | 565 |
|
561 | | -backend = "fake" |
562 | | -kwargs = {"rank": 0, "world_size": world_size} |
563 | | -if True: |
564 | | - backend = "nccl" |
565 | | - fake_store = None |
566 | | - kwargs = {} |
567 | | - world_size = 8 |
568 | | -torch.distributed.init_process_group(backend, store=fake_store, **kwargs) |
| 566 | +fake_store = FakeStore() |
| 567 | +torch.distributed.init_process_group( |
| 568 | + "fake", store=fake_store, rank=0, world_size=world_size |
| 569 | +) |
| 570 | +# backend = "fake" |
| 571 | +# kwargs = {"rank": 0, "world_size": world_size} |
| 572 | +# if True: |
| 573 | +# backend = "nccl" |
| 574 | +# fake_store = None |
| 575 | +# kwargs = {} |
| 576 | +# world_size = 8 |
| 577 | +# torch.distributed.init_process_group(backend, store=fake_store, **kwargs) |
569 | 578 |
|
570 | 579 | use_1d_mesh = False |
571 | 580 |
|
@@ -608,19 +617,11 @@ def input_fn(): |
608 | 617 | return x |
609 | 618 |
|
610 | 619 |
|
611 | | -from functools import partial |
612 | | - |
613 | | -from autoparallel.auto_bucketing import ( |
614 | | - simple_fsdp_autobucketing_reordering_pass, |
615 | | - simplefsdp_autobucketing_config, |
616 | | -) |
617 | | - |
618 | 620 | torch._inductor.config.allow_buffer_reuse = False |
619 | 621 | torch._inductor.config.reorder_for_peak_memory = False |
620 | 622 | torch._inductor.config.reorder_for_compute_comm_overlap = True |
621 | | -simplefsdp_autobucketing_config.save_estimation_path = ( |
622 | | - "/storage/home/fmassa/work/projects/autoparallel/estimation_mast.pkl" |
623 | | -) |
| 623 | +simplefsdp_autobucketing_config.calibrate_number = 5 |
| 624 | +simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl" |
624 | 625 | simple_fsdp_autobucketing_reordering_pass = partial( |
625 | 626 | simple_fsdp_autobucketing_reordering_pass, |
626 | 627 | configs=simplefsdp_autobucketing_config, |
|
0 commit comments