Grad scaling parity between both pp and non-pp #251
Changes from all commits
```diff
@@ -93,6 +93,7 @@ def build_pipeline_schedule(
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=False,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
```

Contributor (on `scale_grads=False`): are you proposing to land this change, or just using it while doing numerics tests?

Author (Member): i'll change this to only be for numerics test mode maybe? it's either turn this off or emulate the scaling for non-pp
|
|
```diff
@@ -166,9 +167,9 @@ def run_test(
         # This is the spmd mesh to be used for tracing
         mesh = world_mesh[("dp_mod_ep", "ep")]

-        global_batch_size = 32 * dp_degree
         # Batch size that will be supplied to the schedule and will be broken down into microbatches
-        local_batch_size = global_batch_size // dp_degree
+        local_batch_size = 32
+        # global_batch_size = local_batch_size * dp_degree
         n_microbatches = 16
         # Batch size with which the spmd graphs will actually be executed
         microbatch_size = local_batch_size // n_microbatches
```
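This change fixes the batch size per rank instead of deriving it from a global batch size (the commented-out line records the old relationship). The resulting arithmetic, as a quick sanity check with the values from the diff:

```python
local_batch_size = 32
n_microbatches = 16
assert local_batch_size % n_microbatches == 0, "must split evenly into microbatches"
microbatch_size = local_batch_size // n_microbatches  # -> 2 samples per microbatch
```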
|
|
```diff
@@ -472,10 +473,6 @@ def last_stage_inp_with_loss_fn():
         world_size = torch.distributed.get_world_size()
         num_world_stages = world_size * len(stage_mods)
-        if rng_seed is not None:
-            NumericsLogger(logs_dir).log_pp_model_weights(
-                model, stage_mods, num_world_stages, ranks=[0, 4]
-            )

         stages = []
         # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
```
|
|
```diff
@@ -500,6 +497,7 @@ def last_stage_inp_with_loss_fn():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
```
|
```diff
@@ -511,9 +509,32 @@ def last_stage_inp_with_loss_fn():
         backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)

+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, ranks=[0, 4]
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            return
+
+        rank = torch.distributed.get_rank()
+        if rank == 4:
+            numerics_logger.log_diff(
+                output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
```
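The hook is threaded through with `functools.partial`, which pre-binds keyword arguments onto the registered callable. (Note that `action` is annotated `str` but accessed as `action.microbatch_index`, so it is presumably the schedule's action object rather than a plain string.) A generic illustration of the binding pattern, using hypothetical stand-ins for the repo's `stage_forward` and hook:

```python
import functools

def stage_forward_stub(action, *, numerics_logs=None, forward_hook=None):
    # Stand-in for the real stage_forward: run the stage's forward,
    # then let an optional hook observe the output.
    output = f"output-for-{action}"
    if forward_hook is not None:
        forward_hook(output)
    return output

# What the schedule ends up calling: a callable taking only the action.
bound = functools.partial(stage_forward_stub, numerics_logs=None, forward_hook=print)
bound("mb0")  # prints "output-for-mb0"
```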
|
|
```diff
@@ -542,6 +563,10 @@ def last_stage_inp_with_loss_fn():
         )
         if pp_rank == 0:
             x = runtime_input_fn_first_stage()
+            if rng_seed:
+                numerics_logger.log_diff(
+                    x.to(torch.float32), prefix="full batch input"
+                )
         graph_pp_runner.step(
             x, target=target, losses=losses, return_outputs=False
         )
```
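`NumericsLogger.log_diff` is repo-internal; the general idea behind this kind of numerics logging is to reduce a tensor to a few stable scalars that can be compared line-by-line between the pp and non-pp runs. A rough sketch of that idea (hypothetical helper, not the PR's implementation):

```python
import torch

def tensor_signature(t: torch.Tensor) -> str:
    # Collapse a tensor into a reproducible one-line summary so two runs
    # can be diffed textually; float64 keeps the reduction itself stable.
    t = t.detach().to(torch.float64)
    return (
        f"mean={t.mean().item():.17g} "
        f"absmax={t.abs().max().item():.17g} "
        f"l2={t.norm().item():.17g}"
    )
```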
|
|
```diff
@@ -556,6 +581,8 @@ def last_stage_inp_with_loss_fn():
             payload_fn=lambda: f"losses: {losses}",
         )

+    numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])
+
     print("All good!")

     if torch.distributed.is_initialized():
```
|
|
||
Reviewer: what's the difference?

Author: no longer needed after rebase