diff --git a/autoparallel/api.py b/autoparallel/api.py
index 621ee0f..226ff32 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import functools
 import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
 )
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .utils import _get_device_from_mesh
+from .utils import (
+    NumericsLogger,
+    _get_device_from_mesh,
+    debug_boxed_nop_preserve_node_meta,
+)
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
@@ -212,6 +217,7 @@ def __init__(
         reshard_after_forward: bool = True,
         dynamic: bool = False,
         loss_fn: Optional[Callable] = None,
+        numerics_logger: NumericsLogger | None = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -239,7 +245,14 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
-        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
+        if compile:
+            self.compiler_fn = compile_fx_inner
+        elif numerics_logger:
+            self.compiler_fn = functools.partial(
+                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
+            )
+        else:
+            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py
index cc3f5a2..83194e4 100644
--- a/autoparallel/graph_pp_runner.py
+++ b/autoparallel/graph_pp_runner.py
@@ -190,12 +190,12 @@ def _accumulate_stage_grads(
 ) -> None:
     assert len(unsharded_grads) == len(grads_to_accumulate)
     assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
-    for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
-        if grad_to_accumulate is not None:
-            if unsharded_grad is None:
-                unsharded_grad = grad_to_accumulate
+    for i in range(len(unsharded_grads)):
+        if grads_to_accumulate[i] is not None:
+            if unsharded_grads[i] is None:
+                unsharded_grads[i] = grads_to_accumulate[i]
             else:
-                unsharded_grad += grad_to_accumulate
+                unsharded_grads[i] += grads_to_accumulate[i]
 
 
 def _run_forward_microbatch(
@@ -240,6 +240,7 @@ def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
     numerics_logs: Optional[list[str]] = None,
+    forward_hook: Callable | None = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -305,6 +306,8 @@ def stage_forward(
     stage.output_chunks.append(output)
     if ctx.target_mbs is not None:
         ctx.schedule_ref._internal_losses.append(output)
+    if forward_hook:
+        forward_hook(stage, action, output)
 
     stage.fwd_cache[mb_index] = (
         output_tuple,  # stage_output
@@ -371,6 +374,7 @@ def stage_full_backward(
         # next stage
         # TODO(sanketpurandare)
         # HACK till we have loss function, we populate the tangents here manually
+        assert len(stage_output) == 1
        bwd_kwargs = {
             "stage_output": loss,
             "tangents": [torch.ones_like(stage_output[0])],
diff --git a/autoparallel/utils.py b/autoparallel/utils.py
index b96f4db..e67e8db 100644
--- a/autoparallel/utils.py
+++ b/autoparallel/utils.py
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
 
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )
 
     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,6 +429,20 @@ def log_model_weights(self, parallel_mod):
 
         print(f"Weight hashes written to {path}")
 
+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_diff(self, t, rank=0, prefix="?"):
+        if self.rank == rank:
+            path = self.dir / "diff.log"
+            if isinstance(t, torch.distributed.tensor.DTensor):
+                t = t.to_local()
+            with open(path, "a") as f:
+                f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")
+
     def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
         path = self.dir / "pp_weights.log"
 
@@ -463,3 +477,40 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         if self.rank == 0:
             print(f"Weight hashes written to {path}")
+
+    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
+        path = self.dir / "diff.log"
+
+        torch.distributed.barrier()
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                grad_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    grad = real_params[name].grad
+                    if grad is None:
+                        grad_logs.append(f"[grad {name}] None")
+                    else:
+                        grad = grad.to_local()
+                        grad_logs.append(
+                            f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
+                        )
+                with open(path, "a") as f:
+                    f.write("\n".join(grad_logs) + "\n")
+            torch.distributed.barrier()
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
+
+    run._boxed_call = True
+    return run
diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py
index d4e6b91..45fdc2f 100644
--- a/examples/example_ds3_local_map.py
+++ b/examples/example_ds3_local_map.py
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
         mscale=0.70,
     )
 
-    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    local_batch_size = 2
+    global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
     device = torch.device(f"cuda:{local_rank}")
 
     # parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
         return torch.randint(
             0,
             config.vocab_size,
-            (bs, seq_len),
+            (global_batch_size, seq_len),
             device=device,
         )
 
-    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+    numerics_logger = None
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+    with AutoParallel(
+        model, input_fn, mesh, dynamic=True, numerics_logger=None
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
    # )
    # maybe not correct value
    parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
    if rng_seed is not None:
-        numerics_logger = NumericsLogger(logs_dir)
        numerics_logger.log_model_weights(parallel_mod)
-
-    x = (
-        torch.randint(
-            0,
-            config.vocab_size,
-            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
-            device=device,
-        ),
+        torch.manual_seed(rng_seed)
+
+    n_microbatches = 16
+    full_batch = torch.randint(
+        0,
+        config.vocab_size,
+        (local_batch_size * n_microbatches, seq_len),
+        device=device,
     )
+    microbatches = torch.split(full_batch, local_batch_size, dim=0)
+    assert len(microbatches) == n_microbatches
+    if rng_seed:
+        numerics_logger.log_diff(
+            full_batch.to(torch.float32), prefix="full batch input"
+        )
    # Symbolically evaluate in case you want to test running a graph bigger than your gpu
    if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
             allow_non_fake_inputs=True,
             shape_env=shape_env,
         ):
-            # # now let's run it
-            out = parallel_mod(*x)
-            out.backward(torch.randn_like(out))
+            # now let's run it
+            for x in microbatches:
+                out = parallel_mod(x)
+                out.backward(torch.ones_like(out))
     else:
-        out = parallel_mod(*x)
-        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        for i, x in enumerate(microbatches):
+            assert x.shape[0] == 2
+            out = parallel_mod(x)
+            assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+            out.backward(torch.ones_like(out))
+            if rng_seed is not None:
+                numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")
+
     if rng_seed is not None:
-        numerics_logger.log_forward_output(out)
-        out.backward(torch.randn_like(out))
+        for k, v in parallel_mod.named_parameters():
+            numerics_logger.log_diff(v.grad, prefix=f"grad {k}")
 
     print("All good!")
diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py
index 8f75460..0726fd3 100644
--- a/examples/example_ds3_pp.py
+++ b/examples/example_ds3_pp.py
@@ -93,6 +93,7 @@ def build_pipeline_schedule(
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=False,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
@@ -166,9 +167,9 @@ def run_test(
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]
 
-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
@@ -472,10 +473,6 @@ def last_stage_inp_with_loss_fn():
 
     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )
     stages = []
 
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -500,6 +497,7 @@ def last_stage_inp_with_loss_fn():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -511,9 +509,32 @@
         backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, ranks=[0, 4]
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            return
+
+        rank = torch.distributed.get_rank()
+        if rank == 4:
+            numerics_logger.log_diff(
+                output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
@@ -542,6 +563,10 @@ def last_stage_inp_with_loss_fn():
         )
         if pp_rank == 0:
             x = runtime_input_fn_first_stage()
+            if rng_seed:
+                numerics_logger.log_diff(
+                    x.to(torch.float32), prefix="full batch input"
+                )
             graph_pp_runner.step(
                 x, target=target, losses=losses, return_outputs=False
             )
@@ -556,6 +581,8 @@ def last_stage_inp_with_loss_fn():
                 payload_fn=lambda: f"losses: {losses}",
             )
 
+    numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])
+
     print("All good!")
 
     if torch.distributed.is_initialized():
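
Note (not part of the patch): when both examples run with the same `rng_seed`, `NumericsLogger.log_diff` and `log_pp_grads` append `[prefix] hash=..., norm=...` lines for the full batch input, the per-microbatch forward outputs, and the per-parameter gradients to `diff.log`, which is what makes the local_map run and the pipeline-parallel run comparable. Below is a minimal sketch of an offline comparison, assuming the two runs wrote their logs to separate directories; the helper name `compare_diff_logs` and the paths are illustrative only, and the lines may need filtering if the two runs do not emit entries in exactly the same order.

    # Hypothetical helper: compare the diff.log files written by NumericsLogger
    # for two runs (e.g. example_ds3_local_map.py vs. example_ds3_pp.py).
    from pathlib import Path

    def compare_diff_logs(log_a: str, log_b: str, max_print: int = 10) -> int:
        """Return the number of mismatching lines between two diff.log files."""
        lines_a = Path(log_a).read_text().splitlines()
        lines_b = Path(log_b).read_text().splitlines()
        mismatches = 0
        for i, (a, b) in enumerate(zip(lines_a, lines_b)):
            if a != b:
                mismatches += 1
                if mismatches <= max_print:
                    print(f"line {i}:\n  A: {a}\n  B: {b}")
        if len(lines_a) != len(lines_b):
            print(f"length mismatch: {len(lines_a)} vs {len(lines_b)} lines")
            mismatches += abs(len(lines_a) - len(lines_b))
        return mismatches

    if __name__ == "__main__":
        # Point these at the logs_dir used by each example run.
        n = compare_diff_logs("logs/local_map/diff.log", "logs/pp/diff.log")
        print("numerics match" if n == 0 else f"{n} mismatching entries")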