Skip to content

Commit f1cc5fc

Browse files
committed
fix test errors
1 parent 798d879 commit f1cc5fc

File tree

1 file changed: +19 additions, −18 deletions

examples/example_llama3.py

Lines changed: 19 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import time
77
from dataclasses import dataclass
8+
from functools import partial
89
from typing import ClassVar
910

1011
import torch
@@ -16,6 +17,10 @@
1617
from torch.testing._internal.distributed.fake_pg import FakeStore
1718

1819
from autoparallel.api import AutoParallel
20+
from autoparallel.auto_bucketing import (
21+
simple_fsdp_autobucketing_reordering_pass,
22+
simplefsdp_autobucketing_config,
23+
)
1924

2025

2126
def has_cuda_capability(major: int, minor: int) -> bool:
@@ -558,14 +563,18 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
558563

559564
world_size = 256
560565

561-
backend = "fake"
562-
kwargs = {"rank": 0, "world_size": world_size}
563-
if True:
564-
backend = "nccl"
565-
fake_store = None
566-
kwargs = {}
567-
world_size = 8
568-
torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
566+
fake_store = FakeStore()
567+
torch.distributed.init_process_group(
568+
"fake", store=fake_store, rank=0, world_size=world_size
569+
)
570+
# backend = "fake"
571+
# kwargs = {"rank": 0, "world_size": world_size}
572+
# if True:
573+
# backend = "nccl"
574+
# fake_store = None
575+
# kwargs = {}
576+
# world_size = 8
577+
# torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
569578

570579
use_1d_mesh = False
571580

@@ -608,19 +617,11 @@ def input_fn():
608617
return x
609618

610619

611-
from functools import partial
612-
613-
from autoparallel.auto_bucketing import (
614-
simple_fsdp_autobucketing_reordering_pass,
615-
simplefsdp_autobucketing_config,
616-
)
617-
618620
torch._inductor.config.allow_buffer_reuse = False
619621
torch._inductor.config.reorder_for_peak_memory = False
620622
torch._inductor.config.reorder_for_compute_comm_overlap = True
621-
simplefsdp_autobucketing_config.save_estimation_path = (
622-
"/storage/home/fmassa/work/projects/autoparallel/estimation_mast.pkl"
623-
)
623+
simplefsdp_autobucketing_config.calibrate_number = 5
624+
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
624625
simple_fsdp_autobucketing_reordering_pass = partial(
625626
simple_fsdp_autobucketing_reordering_pass,
626627
configs=simplefsdp_autobucketing_config,

0 commit comments

Comments
 (0)