diff --git a/autoparallel/api.py b/autoparallel/api.py
index c9995aa..e902132 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import copy
+import functools
 import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
 )
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .utils import _get_device_from_mesh
+from .utils import (
+    NumericsLogger,
+    _get_device_from_mesh,
+    debug_boxed_nop_preserve_node_meta,
+)

 _APPLY_VIEW_MM_VIEW_PATTERN = False

@@ -212,6 +217,7 @@ def __init__(
         reshard_after_forward: bool = True,
         dynamic: bool = False,
         loss_fn: Optional[Callable] = None,
+        numerics_logger: NumericsLogger | None = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -239,7 +245,14 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
-        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
+        if compile:
+            self.compiler_fn = compile_fx_inner
+        elif numerics_logger:
+            self.compiler_fn = functools.partial(
+                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
+            )
+        else:
+            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py
index 3ae965c..ed3a782 100644
--- a/autoparallel/graph_pp_runner.py
+++ b/autoparallel/graph_pp_runner.py
@@ -22,7 +22,7 @@
 )
 from torch.distributed.tensor import DTensor

-from autoparallel.utils import DebugInterpreter
+from autoparallel.utils import DebugInterpreter, NumericsLogger

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -89,6 +89,8 @@ def __init__(
         output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
         group: Optional[torch.distributed.ProcessGroup] = None,
         dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
+        numerics_logger: Optional[NumericsLogger] = None,
+        should_log_fw_outs: bool = False,
     ):
         super().__init__(
             submodule=submodule,
@@ -100,6 +102,8 @@ def __init__(
             group=group,
             dw_builder=dw_builder,
         )
+        self.numerics_logger = numerics_logger
+        self.should_log_fw_outs = should_log_fw_outs
         self.graph_callables = graph_callables
         self.graph_meta = graph_meta
         self.state: dict[str, list[Any]] = {
@@ -376,6 +380,7 @@ def _prepare_fwd_args(


 def _post_fwd_common(
+    action: _Action,
     stage: GraphPipelineStage,
     mb_index: int,
     output: Any,
@@ -411,6 +416,14 @@ def _post_fwd_common(
         if ctx.target_mbs is not None:
             ctx.schedule_ref._internal_losses.append(output)

+    if stage.should_log_fw_outs:
+        assert stage.numerics_logger is not None
+        stage.numerics_logger.log_diff(
+            output,
+            rank=torch.distributed.get_rank(),
+            prefix=f"mb{action.microbatch_index} fwd out",
+        )
+
     stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates)  # type: ignore[assignment]

     stage._validate_fwd_outputs(output_tuple)
@@ -450,6 +463,7 @@ def stage_forward(
     )

     _post_fwd_common(
+        action,
         stage,
         mb_index,
         output,
@@ -793,6 +807,7 @@ def overlap_fw_bw(
     bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)

     _post_fwd_common(
+        action,
         fw_stage,
         fw_mb_index,
         output,
diff --git a/autoparallel/utils.py b/autoparallel/utils.py
index b96f4db..6dfe88b 100644
--- a/autoparallel/utils.py
+++ b/autoparallel/utils.py
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )

     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,13 +429,27 @@ def log_model_weights(self, parallel_mod):

         print(f"Weight hashes written to {path}")

-    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_diff(self, t, rank=0, prefix="?"):
+        if self.rank == rank:
+            path = self.dir / "diff.log"
+            if isinstance(t, torch.distributed.tensor.DTensor):
+                t = t.to_local()
+            with open(path, "a") as f:
+                f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")
+
+    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_log):
         path = self.dir / "pp_weights.log"
         torch.distributed.barrier()
         # First print the params of every stage
         for i in range(num_world_stages):
-            if self.rank in ranks and i in stage_mods:
+            if should_log and i in stage_mods:
                 param_logs = []
                 real_params = dict(stage_mods[i].named_parameters())
                 for name, _ in orig_mod.named_parameters():
@@ -449,7 +463,7 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):

         # Then print the buffers of every stage
         for i in range(num_world_stages):
-            if self.rank in ranks and i in stage_mods:
+            if should_log and i in stage_mods:
                 buffer_logs = []
                 real_buffers = dict(stage_mods[i].named_buffers())
                 for name, _ in orig_mod.named_buffers():
@@ -463,3 +477,39 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):

         if self.rank == 0:
             print(f"Weight hashes written to {path}")
+
+    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log):
+        path = self.dir / "diff.log"
+
+        for i in range(num_world_stages):
+            if should_log and i in stage_mods:
+                grad_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    grad = real_params[name].grad
+                    if grad is None:
+                        grad_logs.append(f"[grad {name}] None")
+                    else:
+                        grad = grad.to_local()
+                        grad_logs.append(
+                            f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
+                        )
+                with open(path, "a") as f:
+                    f.write("\n".join(grad_logs) + "\n")
+            torch.distributed.barrier()
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
+
+    run._boxed_call = True
+    return run
diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py
index d4e6b91..45fdc2f 100644
--- a/examples/example_ds3_local_map.py
+++ b/examples/example_ds3_local_map.py
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
         mscale=0.70,
     )

-    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    local_batch_size = 2
+    global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
     device = torch.device(f"cuda:{local_rank}")

     # parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
         return torch.randint(
             0,
             config.vocab_size,
-            (bs, seq_len),
+            (global_batch_size, seq_len),
             device=device,
         )

-    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+    numerics_logger = None
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+    with AutoParallel(
+        model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)

         # x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
     # )  # maybe not correct value
     parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
     if rng_seed is not None:
-        numerics_logger = NumericsLogger(logs_dir)
         numerics_logger.log_model_weights(parallel_mod)
-
-    x = (
-        torch.randint(
-            0,
-            config.vocab_size,
-            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
-            device=device,
-        ),
+        torch.manual_seed(rng_seed)
+
+    n_microbatches = 16
+    full_batch = torch.randint(
+        0,
+        config.vocab_size,
+        (local_batch_size * n_microbatches, seq_len),
+        device=device,
     )
+    microbatches = torch.split(full_batch, local_batch_size, dim=0)
+    assert len(microbatches) == n_microbatches
+    if rng_seed:
+        numerics_logger.log_diff(
+            full_batch.to(torch.float32), prefix="full batch input"
+        )

     # Symbolically evaluate in case you want to test running a graph bigger than your gpu
     if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
             allow_non_fake_inputs=True,
             shape_env=shape_env,
         ):
-            # # now let's run it
-            out = parallel_mod(*x)
-            out.backward(torch.randn_like(out))
+            # now let's run it
+            for x in microbatches:
+                out = parallel_mod(x)
+                out.backward(torch.ones_like(out))
     else:
-        out = parallel_mod(*x)
-        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        for i, x in enumerate(microbatches):
+            assert x.shape[0] == 2
+            out = parallel_mod(x)
+            assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+            out.backward(torch.ones_like(out))
+            if rng_seed is not None:
+                numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")
+
         if rng_seed is not None:
-            numerics_logger.log_forward_output(out)
-            out.backward(torch.randn_like(out))
+            for k, v in parallel_mod.named_parameters():
+                numerics_logger.log_diff(v.grad, prefix=f"grad {k}")

     print("All good!")
diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py
index fbb7d0f..1f682fd 100644
--- a/examples/example_ds3_pp.py
+++ b/examples/example_ds3_pp.py
@@ -93,6 +93,7 @@ def build_pipeline_schedule(
     local_batch_size: int,
     pipeline_parallel_degree: int,
     backward_requires_autograd: bool = False,
+    scale_grads: bool = True,
 ) -> _PipelineSchedule:
     """Builds a pipeline schedule for the given configuration and stages."""
     schedule_class = get_schedule_class(pipeline_parallel_schedule)
@@ -118,6 +119,7 @@ def build_pipeline_schedule(
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=scale_grads,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
@@ -195,9 +197,9 @@ def run_test(
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]

-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
@@ -382,6 +384,31 @@ def last_stage_inp_with_loss_fn():
     for stages in pp_rank_to_stage_indices.values():
         assert len(stages) * pp_degree == len(virtual_pp_stages)
     stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank]
+    should_log_fw_outs = False
+    should_log_weights = False
+    if rng_seed:
+        # Compute the ranks to log from
+        # 1. for fw_outs, log from coord [pp_rank_containing_last_stage, 0, 0]
+        last_stage_idx = total_pp_stages - 1
+        pp_rank_containing_last_stage = None
+        for pp_rank_, stage_indices in pp_rank_to_stage_indices.items():
+            if last_stage_idx in stage_indices:
+                assert pp_rank_containing_last_stage is None
+                pp_rank_containing_last_stage = pp_rank_
+
+        log_fw_out_rank_coordinate = []
+        for mesh_dim_name in world_mesh.mesh_dim_names:
+            if mesh_dim_name == "pp":
+                log_fw_out_rank_coordinate.append(pp_rank_containing_last_stage)
+            else:
+                log_fw_out_rank_coordinate.append(0)
+        should_log_fw_outs = world_mesh.get_coordinate() == log_fw_out_rank_coordinate
+
+        # 2. for weights, log from coords [:, 0, 0]
+        pp_world_size = world_mesh.shape[world_mesh._get_mesh_dim_by_name("pp")]
+        log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)]
+        should_log_weights = (
+            tuple(world_mesh.get_coordinate()) in log_weights_rank_coordinates
+        )
+
     stage_mods: dict[int, torch.nn.Module] = {}
     stage_graphs: dict[int, GraphCallables] = {}
     stage_graph_metas: dict[int, GraphMeta] = {}
@@ -507,10 +534,14 @@ def last_stage_inp_with_loss_fn():

     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
+
+    numerics_logger = None
     if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, should_log=should_log_weights
         )
+        torch.manual_seed(rng_seed)

     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -533,8 +564,11 @@ def last_stage_inp_with_loss_fn():
                 else shape_inference_fn_intermediate_stage()
             ),
             group=world_mesh.get_group("pp"),
+            numerics_logger=numerics_logger,
+            should_log_fw_outs=should_log_fw_outs,
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -544,12 +578,12 @@ def last_stage_inp_with_loss_fn():
         local_batch_size=local_batch_size,
         pipeline_parallel_degree=pp_degree,
         backward_requires_autograd=False,
+        scale_grads=rng_seed is None,  # In determinism mode, don't scale grads
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
     # Step 6. Override the pipeline runner's action implementations
-    schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
-    )
+    schedule.register_custom_function(FORWARD, stage_forward)
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
     schedule.register_custom_function(RESHARD, stage_reshard)
@@ -584,6 +618,10 @@ def last_stage_inp_with_loss_fn():
             )
             if pp_rank == 0:
                 x = runtime_input_fn_first_stage()
+                if rng_seed:
+                    numerics_logger.log_diff(
+                        x.to(torch.float32), prefix="full batch input"
+                    )
                 graph_pp_runner.step(
                     x, target=target, losses=losses, return_outputs=False
                 )
@@ -598,6 +636,10 @@ def last_stage_inp_with_loss_fn():
         payload_fn=lambda: f"losses: {losses}",
     )

+    if rng_seed is not None:
+        numerics_logger.log_pp_grads(
+            model, stage_mods, num_world_stages, should_log=should_log_weights
+        )
+
     print("All good!")

     if torch.distributed.is_initialized():
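
Reviewer note, not part of the patch: the sketch below shows how the pieces added here are intended to fit together, based only on the call sites visible in this diff. The names `model`, `input_fn`, `mesh`, `logs_dir`, `device`, `rng_seed`, and the sharding/placement steps elided with `...` are assumed to be set up exactly as in examples/example_ds3_local_map.py; the helper function itself is hypothetical.

```python
# Sketch only: wiring a NumericsLogger through AutoParallel for numerics debugging.
import torch
from autoparallel.api import AutoParallel
from autoparallel.utils import NumericsLogger


def debug_numerics_run(model, input_fn, mesh, logs_dir, device, rng_seed, microbatch):
    numerics_logger = NumericsLogger(logs_dir)

    # With compile=False, passing a NumericsLogger selects
    # debug_boxed_nop_preserve_node_meta as the compiler, so each graph execution
    # runs under DebugInterpreter and per-node hash / NaN records are appended to
    # rank_<rank>_fw_intermediates.log via numerics_logger.log_fw_intermediates().
    with AutoParallel(
        model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        # ... choose shardings and build parallel_mod as the example does (elided) ...
        parallel_mod = ...

    parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
    numerics_logger.log_model_weights(parallel_mod)

    out = parallel_mod(microbatch)
    # log_diff hashes a tensor (DTensors are converted to local shards first)
    # into diff.log on rank 0, tagged with the given prefix.
    numerics_logger.log_diff(out, prefix="mb0 fwd out")
    out.backward(torch.ones_like(out))
    for name, param in parallel_mod.named_parameters():
        numerics_logger.log_diff(param.grad, prefix=f"grad {name}")
```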