From 6f037816901ac95e5b457c10a5332c79d1744650 Mon Sep 17 00:00:00 2001 From: Haoci Zhang Date: Sun, 23 Jun 2024 11:59:19 -0700 Subject: [PATCH 1/4] added flexible pp support --- pippy/PipelineSchedule.py | 154 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index 4f79c1287..f11454787 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -776,3 +776,157 @@ def backward_stage_local_index(step): # Return losses if there is a container passed in self._update_losses(self._stages, losses) + +class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti): + def __init__( + self, + stages: List[PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Flexible Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + arg_mbs, kwarg_mbs = self._check_inputs( + arg_mbs, kwarg_mbs, target_mbs, losses + ) + warmup_steps = (self.n_local_stages - 1) * self.microbatches_per_round + warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank) + warmup_steps = min( + warmup_steps, self._n_microbatches * self.n_local_stages + ) + fwd_bwd_steps = ( + self.n_local_stages * self._n_microbatches + ) - warmup_steps + cooldown_steps = warmup_steps + total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps + def forward_stage_local_index(step): + return (step // self.microbatches_per_round) % self.n_local_stages + def backward_stage_local_index(step): + return ( + self.n_local_stages + - 1 + - ((step - warmup_steps) // self.microbatches_per_round) + % self.n_local_stages + ) + fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int) + bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int) + # Delay send waits + sends_to_wait: List[dist.Work] = [] + count = 0 + # Store ops (potentially across steps) + ops: List[dist.P2POp] = [] + # Warmup Phase (forward only) + for step in range(warmup_steps): + fwd_stage = self._stages[forward_stage_local_index(step)] + # This will assign the current microbatch index and update it for future steps + fwd_stage_mb_index[fwd_stage] = ( + mb_index := fwd_stage_mb_index[fwd_stage] + ) + 1 + logger.debug( + f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}" + ) + with record_function(f"Forward {step}"): + ops.extend(fwd_stage.get_fwd_recv_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + work.wait() + ops.clear() + output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index] + ops.extend(fwd_stage.get_fwd_send_ops()) + # If we are right before the fwd-bwd step, then we need to delay the send to the next step, + # This is because fwd-bwd send/recvs 
among ranks need to be aligned to prevent a hang. + # In the edge cases where there are no fwd_bwds and cooldown is immediate, then no delay is needed + if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0): + work = dist.batch_isend_irecv(ops).pop() + sends_to_wait.append(work) + ops.clear() + self._maybe_compute_loss( + fwd_stage, output, target_mbs, mb_index + ) + # 1F1B Phase (forward and backward) + for step in range(warmup_steps, warmup_steps + fwd_bwd_steps): + fwd_stage = self._stages[forward_stage_local_index(step)] + bwd_stage = self._stages[backward_stage_local_index(step)] + fwd_stage_mb_index[fwd_stage] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage] + ) + 1 + bwd_stage_mb_index[bwd_stage] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage] + ) + 1 + bwd_stage._configure_data_parallel_mode( + bwd_mb_index == self._n_microbatches - 1 + ) + logger.debug( + f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}" + ) + with record_function(f"1F1B {step}"): + ops.extend(fwd_stage.get_fwd_recv_ops()) + ops.extend(bwd_stage.get_bwd_recv_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + work.wait() + ops.clear() + # Forward + output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + ops.extend(fwd_stage.get_fwd_send_ops()) + self._maybe_compute_loss( + fwd_stage, output, target_mbs, fwd_mb_index + ) + # Backward + loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) + bwd_stage.backward_one_chunk(loss=loss) + ops.extend(bwd_stage.get_bwd_send_ops()) + # Cooldown Phase (backward only) + for step in range(warmup_steps + fwd_bwd_steps, total_steps): + bwd_stage = self._stages[backward_stage_local_index(step)] + bwd_stage_mb_index[bwd_stage] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage] + ) + 1 + bwd_stage._configure_data_parallel_mode( + bwd_mb_index == self._n_microbatches - 1 + ) + logger.debug( + f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}" + ) + with record_function(f"Cooldown {step}"): + ops.extend(bwd_stage.get_bwd_recv_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + work.wait() + ops.clear() + loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) + bwd_stage.backward_one_chunk(loss=loss) + ops.extend(bwd_stage.get_bwd_send_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + sends_to_wait.append(work) + ops.clear() + # Make sure all sends are finished + for work in sends_to_wait: + work.wait() + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) From 7802f61ad3a4a28acdf17d2258623ec209edb083 Mon Sep 17 00:00:00 2001 From: Haoci Zhang Date: Sun, 23 Jun 2024 11:59:19 -0700 Subject: [PATCH 2/4] added flexible pp support --- test/test_flex_pipeline_schedule_e2e.py | 325 ++++++++++++++++++++++++ test/test_pipeline_schedule_e2e.py | 69 ++++- 2 files changed, 383 insertions(+), 11 deletions(-) create mode 100644 test/test_flex_pipeline_schedule_e2e.py diff --git a/test/test_flex_pipeline_schedule_e2e.py b/test/test_flex_pipeline_schedule_e2e.py new file mode 100644 index 000000000..d8a401b83 --- /dev/null +++ b/test/test_flex_pipeline_schedule_e2e.py @@ -0,0 +1,325 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +""" +SINGLE HOST: + +python test_pipeline_schedule_e2e.py + +or + +with torchrun (1x2, 1 host with 2 processes): +torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=2 test_pipeline_schedule_e2e.py + +MULTIPLE HOSTS: + +torchrun --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS test_pipeline_schedule_e2e.py + +e.g. (2x2, 2 hosts with 2 processes) +torchrun --rdzv-backend=c10d --rdzv-endpoint=node1.example.com:29400 --nnodes=2 --nproc-per-node=2 test_pipeline_schedule_e2e.py +""" + +import argparse +import logging +import os +from contextlib import contextmanager, nullcontext + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from pippy import ( + ManualPipelineStage, + Schedule1F1B, + ScheduleGPipe, + ScheduleInterleaved1F1B, + ScheduleLoopedBFS, + ScheduleFlexibleInterleaved1F1B, +) + +from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.profiler import record_function + +logger = logging.getLogger(__name__) + + +# profiling context manager +@contextmanager +def maybe_run_profiler( + use_profiler, trace_dir, schedule, rank, *args, **kwargs +): + def trace_handler(prof): + if rank == 0: + (f"about to EXPORT traces for {schedule} to {trace_dir}") + prof.export_chrome_trace( + f"{trace_dir}/{schedule}_rank{rank}_trace.json" + ) + + if use_profiler: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + # schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), + on_trace_ready=trace_handler, + profile_memory=True, + with_stack=False, + record_shapes=True, + ) as torch_profiler: + yield torch_profiler + else: + torch_profiler = nullcontext() + yield None + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + out_dim: int, + ): + super().__init__() + self.wi = nn.Linear(dim, hidden_dim, bias=False) + self.wh1 = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.wh2 = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.wh3 = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.wo = nn.Linear(hidden_dim, out_dim, bias=False) + self.gelu_act = nn.GELU(approximate="tanh") + + # Apply Kaiming initialization + nn.init.kaiming_normal_(self.wi.weight) + nn.init.kaiming_normal_(self.wh1.weight) + nn.init.kaiming_normal_(self.wh3.weight) + nn.init.kaiming_normal_(self.wh3.weight) + nn.init.kaiming_normal_(self.wo.weight) + + def forward(self, x): + a = self.wi(x) + a = self.wh1(a) + a = self.wh2(a) + a = self.wh3(a) + b = self.gelu_act(a) + c = self.wo(b) + return c + + +def setup(local_rank, log_level, init_process_group=False): + # If this is a child process (i.e., its PID is not the same as the PID of the process that started this script) + if os.getppid() != os.getpid(): + set_up_logging(local_rank, log_level) + + # initialize the process group (not needed if using device_mesh) + if init_process_group: + logger.info(f"init for rank {local_rank}") + if torch.distributed.is_initialized(): + torch.cuda.set_device(local_rank) + logger.info(f"finish init for rank {local_rank}") + + +def main(**kwargs): + torch.manual_seed(42) + device = torch.device(kwargs["device"]) + + local_rank = kwargs["local_rank"] + world_size = kwargs["world_size"] + + def rank_print(msg): + if rank == 0: + print(f"{msg}") + + # use device mesh for cuda - create a device mesh based on the given world_size. 
+ device_mesh = None + rank = None + + if device.type == "cuda": + device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),) + ) + + rank = device_mesh.get_rank() + rank_print(f"using {device_mesh=}") + + if not device_mesh: + rank = kwargs["rank"] + + init_process_group = bool(device_mesh is not None) + + setup( + local_rank, kwargs["log_level"], init_process_group=init_process_group + ) + + print( + f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======" + ) + + print(f"kwargs are {kwargs}") + + microbatch_size = 8 + global_batch_size = 80 + num_stages_local = 2 + assert global_batch_size % microbatch_size == 0 + n_microbatches = int(global_batch_size / microbatch_size) + + input_dim = 900 + hidden_dim = 800 + output_dim = 900 + + model = torch.nn.Sequential( + *[MLP(input_dim, hidden_dim, output_dim) for _ in range(8)] + ) + + module_list = torch.nn.ModuleList( + modules=[model for i in range(world_size)] + ) + + print(f"{n_microbatches=}") + x = torch.randn([microbatch_size, input_dim]).to("cpu") + unused = torch.ones((1, 1), device="meta") + input_args = x + + if kwargs["stage_type"] == "manual": + stage_model = ManualPipelineStage( + module_list[rank], + stage_index=rank, + num_stages=world_size, + device=device, + input_args=input_args, + num_microbatches=n_microbatches, + ) + + stage_model_looped = [ + ManualPipelineStage( + module_list[rank], + stage_index=(world_size * i) + rank, + num_stages=num_stages_local * world_size, + device=device, + input_args=input_args, + num_microbatches=n_microbatches, + ) + for i in range(num_stages_local) + ] + elif kwargs["stage_type"] == "tracing": + pass + # TODO + # pipe = pipeline(model, n_microbatches, example_args=(input_args,)) + # stage = PipelineStage(pipe, rank, device) + + print(f"{[sm.stage_index for sm in stage_model_looped]}") + + x_cuda_empty = torch.empty_like(x, device="cuda") + + microbatches = [] + for i in range(n_microbatches): + unused = torch.ones((1, 1), device="cuda") + microbatches.append([torch.randn_like(x_cuda_empty)]) + + # Define a loss function + loss_fn = torch.nn.MSELoss(reduction="sum") + target_mbs = [ + torch.randn(microbatch_size, output_dim, device=device) + for _ in range(n_microbatches) + ] + + _run_profiler = kwargs["profiler"] + _trace_dir = kwargs["trace_dir"] + + my_schedule = ScheduleFlexibleInterleaved1F1B( + stage_model_looped, n_microbatches, loss_fn + ) + + if _run_profiler: + logger.info(f"====== Rank {rank} profile ======") + + with maybe_run_profiler( + _run_profiler, _trace_dir, schedule, rank + ) as _torch_profiler: + with record_function(schedule): + if rank == 0: + my_schedule._step_microbatches(microbatches) + elif rank == world_size - 1: + losses = [] + output = my_schedule._step_microbatches( + target_mbs=target_mbs, losses=losses + ) + else: + my_schedule._step_microbatches() + print(f"====== Rank {rank} finished {schedule} ======") + + +def main_wrapper(rank, local_rank, world_size, kwargs): + rank = int(rank) + world_size = int(world_size) + if local_rank is None: + local_rank = rank + local_rank = int(local_rank) + + os.environ["RANK"] = str(rank) + main(rank=rank, local_rank=local_rank, world_size=world_size, **kwargs) + + +def set_up_logging(rank, log_level): + """Set up logging""" + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError("Invalid log level: %s" % log_level) + logging.basicConfig(level=numeric_level) + + # TODO: 
seeing double logging due to global logging setup in + # - fx/passes/utils/matcher_utils.py + + # class FstringFormatter(logging.Formatter): + # def format(self, record): + # return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]:{record.getMessage()}" + + # formatter = FstringFormatter() + # handler.setFormatter(formatter) + # logger.addHandler(handler) + + +if __name__ == "__main__": + rank = os.environ.get("RANK", None) + local_rank = os.environ.get("LOCAL_RANK", None) + world_size = os.environ.get("WORLD_SIZE", None) + master_addr = os.environ.get("MASTER_ADDR", None) + master_port = os.environ.get("MASTER_PORT", None) + + parser = argparse.ArgumentParser(description="Pipeline Stages Runner") + parser.add_argument( + "-l", "--log_level", help="Set the logging level.", default="DEBUG" + ) + parser.add_argument("--profiler", type=bool, default=True) + parser.add_argument("--trace_dir", type=str, default="./traces") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--stage_type", type=str, default="manual") + args = parser.parse_args() + kwargs = vars(args) + + if ( + rank is None + or local_rank is None + or world_size is None + or master_addr is None + ): + # single host code path + master_port = "23456" + master_addr = "localhost" + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + n_gpus = 2 + world_size = n_gpus + os.environ["WORLD_SIZE"] = str(world_size) + print( + f"Torchrun was not used. Spawning {world_size} processes on {master_addr}:{master_port}" + ) + mp.spawn( + main_wrapper, + args=( + None, + world_size, + kwargs, + ), + nprocs=world_size, + ) + else: + # multihost code path (ran with torchrun) + main_wrapper(rank, local_rank, world_size, kwargs) diff --git a/test/test_pipeline_schedule_e2e.py b/test/test_pipeline_schedule_e2e.py index 4efd22455..ce21089c9 100644 --- a/test/test_pipeline_schedule_e2e.py +++ b/test/test_pipeline_schedule_e2e.py @@ -10,12 +10,16 @@ with torchrun (1x2, 1 host with 2 processes): torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=2 test_pipeline_schedule_e2e.py +for testing flexible pp schedules, use the (1x4, 1 host with 4 processes) config otherwise will fall back to interleaved 1f1b +torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=4 test_pipeline_schedule_e2e.py -- --schedules flexible_interleaved_1f1b + MULTIPLE HOSTS: torchrun --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS test_pipeline_schedule_e2e.py e.g. 
(2x2, 2 hosts with 2 processes) torchrun --rdzv-backend=c10d --rdzv-endpoint=node1.example.com:29400 --nnodes=2 --nproc-per-node=2 test_pipeline_schedule_e2e.py + """ import argparse @@ -32,6 +36,7 @@ ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleLoopedBFS, + ScheduleFlexibleInterleaved1F1B, ) from torch.distributed._tensor.device_mesh import init_device_mesh @@ -153,6 +158,19 @@ def rank_print(msg): rank_print(f"kwargs are {kwargs}") + microbatch_size = 8 + global_batch_size = 64 + # Flexible PP has to be tested on a different n_microbatches value + flex_pp_global_batch_size = 80 + + num_stages_local = 2 + + assert global_batch_size % microbatch_size == 0 + assert flex_pp_global_batch_size % microbatch_size == 0 + + n_microbatches = int(global_batch_size / microbatch_size) + n_microbatches_flex_pp = int(flex_pp_global_batch_size / microbatch_size) + input_dim = 900 hidden_dim = 800 output_dim = 900 @@ -164,10 +182,6 @@ def rank_print(msg): module_list = torch.nn.ModuleList( modules=[model for i in range(world_size)] ) - microbatch_size = 8 - global_batch_size = 64 - assert global_batch_size % microbatch_size == 0 - n_microbatches = int(global_batch_size / microbatch_size) x = torch.randn([microbatch_size, input_dim]).to("cpu") unused = torch.ones((1, 1), device="meta") @@ -176,7 +190,7 @@ def rank_print(msg): if kwargs["stage_type"] == "manual": stage_model = ManualPipelineStage( module_list[rank], - stage_id=rank, + stage_index=rank, num_stages=world_size, device=device, input_args=input_args, @@ -186,14 +200,27 @@ def rank_print(msg): stage_model_looped = [ ManualPipelineStage( module_list[rank], - stage_id=(world_size * i) + rank, - num_stages=world_size * world_size, + stage_index=(world_size * i) + rank, + num_stages=num_stages_local * world_size, device=device, input_args=input_args, num_microbatches=n_microbatches, ) - for i in range(world_size) + for i in range(num_stages_local) + ] + + flex_pp_stage_model_looped = [ + ManualPipelineStage( + module_list[rank], + stage_index=(world_size * i) + rank, + num_stages=num_stages_local * world_size, + device=device, + input_args=input_args, + num_microbatches=n_microbatches_flex_pp, + ) + for i in range(num_stages_local) ] + elif kwargs["stage_type"] == "tracing": pass # TODO @@ -205,20 +232,34 @@ def rank_print(msg): x_cuda_empty = torch.empty_like(x, device="cuda") microbatches = [] + flex_pp_microbatches = [] + for i in range(n_microbatches): unused = torch.ones((1, 1), device="cuda") microbatches.append([torch.randn_like(x_cuda_empty)]) + for i in range(n_microbatches_flex_pp): + unused = torch.ones((1, 1), device="cuda") + flex_pp_microbatches.append([torch.randn_like(x_cuda_empty)]) + # Define a loss function loss_fn = torch.nn.MSELoss(reduction="sum") target_mbs = [ torch.randn(microbatch_size, output_dim, device=device) for _ in range(n_microbatches) ] + target_mbs_flex_pp = [ + torch.randn(microbatch_size, output_dim, device=device) + for _ in range(n_microbatches_flex_pp) + ] _run_profiler = kwargs["profiler"] _trace_dir = kwargs["trace_dir"] for schedule in kwargs["schedules"]: + + current_microbatches = microbatches + current_target_mbs = target_mbs + logger.info(f"====== Rank {rank} running schedule {schedule} ======") if schedule == "gpipe": my_schedule = ScheduleGPipe(stage_model, n_microbatches, loss_fn) @@ -232,6 +273,12 @@ def rank_print(msg): my_schedule = ScheduleInterleaved1F1B( stage_model_looped, n_microbatches, loss_fn ) + elif schedule == "flexible_interleaved_1f1b": + my_schedule = 
ScheduleFlexibleInterleaved1F1B( + flex_pp_stage_model_looped, n_microbatches_flex_pp, loss_fn + ) + current_microbatches = flex_pp_microbatches + current_target_mbs = target_mbs_flex_pp if _run_profiler: logger.info(f"====== Rank {rank} profile ======") @@ -241,11 +288,11 @@ def rank_print(msg): ) as _torch_profiler: with record_function(schedule): if rank == 0: - my_schedule._step_microbatches(microbatches) + my_schedule._step_microbatches(current_microbatches) elif rank == world_size - 1: losses = [] output = my_schedule._step_microbatches( - target_mbs=target_mbs, losses=losses + target_mbs=current_target_mbs, losses=losses ) else: my_schedule._step_microbatches() @@ -299,7 +346,7 @@ def set_up_logging(rank, log_level): "--schedules", type=str, nargs="+", - choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"], + choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b", "flexible_interleaved_1f1b"], default=["interleaved_1f1b"], ) parser.add_argument("--device", type=str, default="cuda") From b1a186fe53efba4ee33f720205cd2919e3a2d1af Mon Sep 17 00:00:00 2001 From: Haoci Zhang Date: Sun, 23 Jun 2024 12:41:30 -0700 Subject: [PATCH 3/4] Modified init file Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- pippy/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pippy/__init__.py b/pippy/__init__.py index 3edd41913..76c98d6f1 100644 --- a/pippy/__init__.py +++ b/pippy/__init__.py @@ -20,6 +20,7 @@ ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleLoopedBFS, + ScheduleFlexibleInterleaved1F1B, ) @@ -37,6 +38,7 @@ "ScheduleGPipe", "ScheduleInterleaved1F1B", "ScheduleLoopedBFS", + "ScheduleFlexibleInterleaved1F1B", "ManualPipelineStage", "ArgsChunkSpec", "KwargsChunkSpec", From c526f0f0d41cbcfe1083fb7891a9ad6981b3bd90 Mon Sep 17 00:00:00 2001 From: Haoci Zhang Date: Sun, 23 Jun 2024 12:45:52 -0700 Subject: [PATCH 4/4] Delete test_flex_pipeline_schedule_e2e.py --- test/test_flex_pipeline_schedule_e2e.py | 325 ------------------------ 1 file changed, 325 deletions(-) delete mode 100644 test/test_flex_pipeline_schedule_e2e.py diff --git a/test/test_flex_pipeline_schedule_e2e.py b/test/test_flex_pipeline_schedule_e2e.py deleted file mode 100644 index d8a401b83..000000000 --- a/test/test_flex_pipeline_schedule_e2e.py +++ /dev/null @@ -1,325 +0,0 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -""" -SINGLE HOST: - -python test_pipeline_schedule_e2e.py - -or - -with torchrun (1x2, 1 host with 2 processes): -torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=2 test_pipeline_schedule_e2e.py - -MULTIPLE HOSTS: - -torchrun --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS test_pipeline_schedule_e2e.py - -e.g. 
(2x2, 2 hosts with 2 processes) -torchrun --rdzv-backend=c10d --rdzv-endpoint=node1.example.com:29400 --nnodes=2 --nproc-per-node=2 test_pipeline_schedule_e2e.py -""" - -import argparse -import logging -import os -from contextlib import contextmanager, nullcontext - -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from pippy import ( - ManualPipelineStage, - Schedule1F1B, - ScheduleGPipe, - ScheduleInterleaved1F1B, - ScheduleLoopedBFS, - ScheduleFlexibleInterleaved1F1B, -) - -from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.profiler import record_function - -logger = logging.getLogger(__name__) - - -# profiling context manager -@contextmanager -def maybe_run_profiler( - use_profiler, trace_dir, schedule, rank, *args, **kwargs -): - def trace_handler(prof): - if rank == 0: - (f"about to EXPORT traces for {schedule} to {trace_dir}") - prof.export_chrome_trace( - f"{trace_dir}/{schedule}_rank{rank}_trace.json" - ) - - if use_profiler: - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - # schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), - on_trace_ready=trace_handler, - profile_memory=True, - with_stack=False, - record_shapes=True, - ) as torch_profiler: - yield torch_profiler - else: - torch_profiler = nullcontext() - yield None - - -class MLP(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - out_dim: int, - ): - super().__init__() - self.wi = nn.Linear(dim, hidden_dim, bias=False) - self.wh1 = nn.Linear(hidden_dim, hidden_dim, bias=False) - self.wh2 = nn.Linear(hidden_dim, hidden_dim, bias=False) - self.wh3 = nn.Linear(hidden_dim, hidden_dim, bias=False) - self.wo = nn.Linear(hidden_dim, out_dim, bias=False) - self.gelu_act = nn.GELU(approximate="tanh") - - # Apply Kaiming initialization - nn.init.kaiming_normal_(self.wi.weight) - nn.init.kaiming_normal_(self.wh1.weight) - nn.init.kaiming_normal_(self.wh3.weight) - nn.init.kaiming_normal_(self.wh3.weight) - nn.init.kaiming_normal_(self.wo.weight) - - def forward(self, x): - a = self.wi(x) - a = self.wh1(a) - a = self.wh2(a) - a = self.wh3(a) - b = self.gelu_act(a) - c = self.wo(b) - return c - - -def setup(local_rank, log_level, init_process_group=False): - # If this is a child process (i.e., its PID is not the same as the PID of the process that started this script) - if os.getppid() != os.getpid(): - set_up_logging(local_rank, log_level) - - # initialize the process group (not needed if using device_mesh) - if init_process_group: - logger.info(f"init for rank {local_rank}") - if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) - logger.info(f"finish init for rank {local_rank}") - - -def main(**kwargs): - torch.manual_seed(42) - device = torch.device(kwargs["device"]) - - local_rank = kwargs["local_rank"] - world_size = kwargs["world_size"] - - def rank_print(msg): - if rank == 0: - print(f"{msg}") - - # use device mesh for cuda - create a device mesh based on the given world_size. 
- device_mesh = None - rank = None - - if device.type == "cuda": - device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),) - ) - - rank = device_mesh.get_rank() - rank_print(f"using {device_mesh=}") - - if not device_mesh: - rank = kwargs["rank"] - - init_process_group = bool(device_mesh is not None) - - setup( - local_rank, kwargs["log_level"], init_process_group=init_process_group - ) - - print( - f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======" - ) - - print(f"kwargs are {kwargs}") - - microbatch_size = 8 - global_batch_size = 80 - num_stages_local = 2 - assert global_batch_size % microbatch_size == 0 - n_microbatches = int(global_batch_size / microbatch_size) - - input_dim = 900 - hidden_dim = 800 - output_dim = 900 - - model = torch.nn.Sequential( - *[MLP(input_dim, hidden_dim, output_dim) for _ in range(8)] - ) - - module_list = torch.nn.ModuleList( - modules=[model for i in range(world_size)] - ) - - print(f"{n_microbatches=}") - x = torch.randn([microbatch_size, input_dim]).to("cpu") - unused = torch.ones((1, 1), device="meta") - input_args = x - - if kwargs["stage_type"] == "manual": - stage_model = ManualPipelineStage( - module_list[rank], - stage_index=rank, - num_stages=world_size, - device=device, - input_args=input_args, - num_microbatches=n_microbatches, - ) - - stage_model_looped = [ - ManualPipelineStage( - module_list[rank], - stage_index=(world_size * i) + rank, - num_stages=num_stages_local * world_size, - device=device, - input_args=input_args, - num_microbatches=n_microbatches, - ) - for i in range(num_stages_local) - ] - elif kwargs["stage_type"] == "tracing": - pass - # TODO - # pipe = pipeline(model, n_microbatches, example_args=(input_args,)) - # stage = PipelineStage(pipe, rank, device) - - print(f"{[sm.stage_index for sm in stage_model_looped]}") - - x_cuda_empty = torch.empty_like(x, device="cuda") - - microbatches = [] - for i in range(n_microbatches): - unused = torch.ones((1, 1), device="cuda") - microbatches.append([torch.randn_like(x_cuda_empty)]) - - # Define a loss function - loss_fn = torch.nn.MSELoss(reduction="sum") - target_mbs = [ - torch.randn(microbatch_size, output_dim, device=device) - for _ in range(n_microbatches) - ] - - _run_profiler = kwargs["profiler"] - _trace_dir = kwargs["trace_dir"] - - my_schedule = ScheduleFlexibleInterleaved1F1B( - stage_model_looped, n_microbatches, loss_fn - ) - - if _run_profiler: - logger.info(f"====== Rank {rank} profile ======") - - with maybe_run_profiler( - _run_profiler, _trace_dir, schedule, rank - ) as _torch_profiler: - with record_function(schedule): - if rank == 0: - my_schedule._step_microbatches(microbatches) - elif rank == world_size - 1: - losses = [] - output = my_schedule._step_microbatches( - target_mbs=target_mbs, losses=losses - ) - else: - my_schedule._step_microbatches() - print(f"====== Rank {rank} finished {schedule} ======") - - -def main_wrapper(rank, local_rank, world_size, kwargs): - rank = int(rank) - world_size = int(world_size) - if local_rank is None: - local_rank = rank - local_rank = int(local_rank) - - os.environ["RANK"] = str(rank) - main(rank=rank, local_rank=local_rank, world_size=world_size, **kwargs) - - -def set_up_logging(rank, log_level): - """Set up logging""" - numeric_level = getattr(logging, log_level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError("Invalid log level: %s" % log_level) - logging.basicConfig(level=numeric_level) - - # TODO: 
seeing double logging due to global logging setup in - # - fx/passes/utils/matcher_utils.py - - # class FstringFormatter(logging.Formatter): - # def format(self, record): - # return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]:{record.getMessage()}" - - # formatter = FstringFormatter() - # handler.setFormatter(formatter) - # logger.addHandler(handler) - - -if __name__ == "__main__": - rank = os.environ.get("RANK", None) - local_rank = os.environ.get("LOCAL_RANK", None) - world_size = os.environ.get("WORLD_SIZE", None) - master_addr = os.environ.get("MASTER_ADDR", None) - master_port = os.environ.get("MASTER_PORT", None) - - parser = argparse.ArgumentParser(description="Pipeline Stages Runner") - parser.add_argument( - "-l", "--log_level", help="Set the logging level.", default="DEBUG" - ) - parser.add_argument("--profiler", type=bool, default=True) - parser.add_argument("--trace_dir", type=str, default="./traces") - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--stage_type", type=str, default="manual") - args = parser.parse_args() - kwargs = vars(args) - - if ( - rank is None - or local_rank is None - or world_size is None - or master_addr is None - ): - # single host code path - master_port = "23456" - master_addr = "localhost" - os.environ["MASTER_ADDR"] = master_addr - os.environ["MASTER_PORT"] = master_port - n_gpus = 2 - world_size = n_gpus - os.environ["WORLD_SIZE"] = str(world_size) - print( - f"Torchrun was not used. Spawning {world_size} processes on {master_addr}:{master_port}" - ) - mp.spawn( - main_wrapper, - args=( - None, - world_size, - kwargs, - ), - nprocs=world_size, - ) - else: - # multihost code path (ran with torchrun) - main_wrapper(rank, local_rank, world_size, kwargs)
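
For reference, the step accounting introduced in PATCH 1/4 can be sanity-checked standalone. The sketch below is illustrative only: `flex_1f1b_step_counts` is a hypothetical helper, not part of the PiPPy API. It re-derives the warmup / 1F1B / cooldown step counts from the same formulas used in `ScheduleFlexibleInterleaved1F1B`, evaluated for the 4-rank, 2-local-stage, 10-microbatch (80 / 8) configuration exercised by the `flexible_interleaved_1f1b` path in test_pipeline_schedule_e2e.py.

# Illustrative sketch only: re-derives the warmup/1F1B/cooldown step counts used by
# ScheduleFlexibleInterleaved1F1B in PATCH 1/4. The helper name is hypothetical and
# not part of the PiPPy API; it exists purely to inspect the schedule shape per rank.

def flex_1f1b_step_counts(pp_group_size, n_local_stages, n_microbatches, rank):
    # Same derivation as ScheduleFlexibleInterleaved1F1B.__init__ / _step_microbatches.
    number_of_rounds = max(1, n_microbatches // pp_group_size)
    assert n_microbatches % number_of_rounds == 0, (
        "n_microbatches must be a multiple of the number of rounds"
    )
    microbatches_per_round = n_microbatches // number_of_rounds

    # Warmup: forward-only steps, larger for earlier ranks, capped at total forwards.
    warmup_steps = (n_local_stages - 1) * microbatches_per_round
    warmup_steps += 2 * ((pp_group_size - 1) - rank)
    warmup_steps = min(warmup_steps, n_microbatches * n_local_stages)

    # Steady state runs one forward and one backward per step; cooldown mirrors warmup.
    fwd_bwd_steps = n_local_stages * n_microbatches - warmup_steps
    cooldown_steps = warmup_steps
    return warmup_steps, fwd_bwd_steps, cooldown_steps


if __name__ == "__main__":
    # 4 ranks, 2 local stages per rank, global_batch_size 80 / microbatch_size 8 = 10 microbatches.
    for rank in range(4):
        w, f, c = flex_1f1b_step_counts(
            pp_group_size=4, n_local_stages=2, n_microbatches=10, rank=rank
        )
        print(f"rank {rank}: warmup={w} fwd_bwd={f} cooldown={c} total={w + f + c}")

Running this prints, for example, warmup=11 / fwd_bwd=9 / cooldown=11 on rank 0 and warmup=5 / fwd_bwd=15 / cooldown=5 on rank 3, which matches the intent that earlier ranks spend more steps in forward-only warmup before the 1F1B steady state begins.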