
Commit b8546e1

Compare microbatch forward outputs and gradients
stack-info: PR: #246, branch: xmfan/stack/20
1 parent d17cb4c commit b8546e1

File tree

5 files changed: +174 −32 lines


autoparallel/api.py

Lines changed: 15 additions & 2 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import functools
 import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
 )
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .utils import _get_device_from_mesh
+from .utils import (
+    NumericsLogger,
+    _get_device_from_mesh,
+    debug_boxed_nop_preserve_node_meta,
+)
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
@@ -212,6 +217,7 @@ def __init__(
         reshard_after_forward: bool = True,
         dynamic: bool = False,
         loss_fn: Optional[Callable] = None,
+        numerics_logger: NumericsLogger | None = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -239,7 +245,14 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
-        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
+        if compile:
+            self.compiler_fn = compile_fx_inner
+        elif numerics_logger:
+            self.compiler_fn = functools.partial(
+                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
+            )
+        else:
+            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward

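The new numerics_logger argument only affects the non-compiled path: when compile=False and a logger is supplied, tracing runs through debug_boxed_nop_preserve_node_meta instead of the plain boxed nop. A minimal usage sketch follows; the import paths and the log directory value are assumptions for illustration, not part of this diff.

    # Sketch only; assumes AutoParallel and NumericsLogger are importable as shown.
    from autoparallel.api import AutoParallel
    from autoparallel.utils import NumericsLogger

    numerics_logger = NumericsLogger("/tmp/numerics_logs")  # hypothetical log directory
    with AutoParallel(
        model,
        input_fn,
        mesh,
        dynamic=True,
        numerics_logger=numerics_logger,  # routes compilation through the DebugInterpreter-backed compiler_fn
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
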
autoparallel/graph_pp_runner.py

Lines changed: 3 additions & 0 deletions
@@ -214,6 +214,7 @@ def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
     numerics_logs: Optional[list[str]] = None,
+    forward_hook: Callable | None = None,
 ) -> None:
     schedule, stage_index_to_stage, stage = _get_stage_from_action(action, ctx)
     stage_index = stage.stage_index
@@ -279,6 +280,8 @@ def stage_forward(
     stage.output_chunks.append(output)
     if ctx.target_mbs is not None:
         ctx.schedule_ref._internal_losses.append(output)
+    if forward_hook:
+        forward_hook(stage, action, output)
 
     stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates)  # type: ignore[assignment]

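stage_forward now accepts an optional forward_hook invoked as forward_hook(stage, action, output) after each microbatch's forward output is recorded. A sketch of wiring a custom hook through the schedule's custom-function registry, mirroring the example_ds3_pp.py change below; the hook body here is illustrative and assumes schedule, FORWARD, and stage_forward are in scope as in that example.

    import functools

    def print_last_stage_output(stage, action, output):
        # runs once per microbatch, right after the stage produces its output
        if stage.is_last:
            print(f"mb{action.microbatch_index} fwd out norm: {output.norm()}")

    schedule.register_custom_function(
        FORWARD,
        functools.partial(
            stage_forward, numerics_logs=None, forward_hook=print_last_stage_output
        ),
    )
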
autoparallel/utils.py

Lines changed: 54 additions & 4 deletions
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
 
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )
 
     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,13 +429,27 @@ def log_model_weights(self, parallel_mod):
 
         print(f"Weight hashes written to {path}")
 
-    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_diff(self, t, rank=0, prefix="?"):
+        if self.rank == rank:
+            path = self.dir / "diff.log"
+            if isinstance(t, torch.distributed.tensor.DTensor):
+                t = t.to_local()
+            with open(path, "a") as f:
+                f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")
+
+    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_log):
         path = self.dir / "pp_weights.log"
 
         torch.distributed.barrier()
         # First print the params of every stage
         for i in range(num_world_stages):
-            if self.rank in ranks and i in stage_mods:
+            if should_log and i in stage_mods:
                 param_logs = []
                 real_params = dict(stage_mods[i].named_parameters())
                 for name, _ in orig_mod.named_parameters():
@@ -449,7 +463,7 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         # Then print the buffers of every stage
         for i in range(num_world_stages):
-            if self.rank in ranks and i in stage_mods:
+            if should_log and i in stage_mods:
                 buffer_logs = []
                 real_buffers = dict(stage_mods[i].named_buffers())
                 for name, _ in orig_mod.named_buffers():
@@ -463,3 +477,39 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         if self.rank == 0:
             print(f"Weight hashes written to {path}")
+
+    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log):
+        path = self.dir / "diff.log"
+
+        for i in range(num_world_stages):
+            if should_log and i in stage_mods:
+                grad_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    grad = real_params[name].grad
+                    if grad is None:
+                        grad_logs.append(f"[grad {name}] None")
+                    else:
+                        grad = grad.to_local()
+                        grad_logs.append(
+                            f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
+                        )
+                with open(path, "a") as f:
+                    f.write("\n".join(grad_logs) + "\n")
+            torch.distributed.barrier()
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
+
+    run._boxed_call = True
+    return run

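log_diff appends lines of the form "[<prefix>] hash=..., norm=..." to diff.log on a single designated rank, so the local_map run and the pipeline-parallel run can be compared textually. A hypothetical comparison helper, not part of this commit; it assumes both runs wrote their diff.log entries with identical prefixes in the same order, and the paths are placeholders.

    def compare_diff_logs(path_a: str, path_b: str) -> None:
        # Print any line where the two diff.log files disagree.
        with open(path_a) as fa, open(path_b) as fb:
            for lineno, (a, b) in enumerate(zip(fa, fb), start=1):
                if a != b:
                    print(f"line {lineno} differs:\n  {a.rstrip()}\n  {b.rstrip()}")

    compare_diff_logs("run_local_map/diff.log", "run_pp/diff.log")  # hypothetical paths
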
examples/example_ds3_local_map.py

Lines changed: 37 additions & 19 deletions
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
         mscale=0.70,
     )
 
-    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    local_batch_size = 2
+    global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
     device = torch.device(f"cuda:{local_rank}")
 
     # parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
         return torch.randint(
             0,
             config.vocab_size,
-            (bs, seq_len),
+            (global_batch_size, seq_len),
             device=device,
         )
 
-    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+    numerics_logger = None
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+    with AutoParallel(
+        model, input_fn, mesh, dynamic=True, numerics_logger=None
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
     # ) # maybe not correct value
     parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
     if rng_seed is not None:
-        numerics_logger = NumericsLogger(logs_dir)
         numerics_logger.log_model_weights(parallel_mod)
-
-    x = (
-        torch.randint(
-            0,
-            config.vocab_size,
-            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
-            device=device,
-        ),
+        torch.manual_seed(rng_seed)
+
+    n_microbatches = 16
+    full_batch = torch.randint(
+        0,
+        config.vocab_size,
+        (local_batch_size * n_microbatches, seq_len),
+        device=device,
     )
+    microbatches = torch.split(full_batch, local_batch_size, dim=0)
+    assert len(microbatches) == n_microbatches
+    if rng_seed:
+        numerics_logger.log_diff(
+            full_batch.to(torch.float32), prefix="full batch input"
+        )
 
     # Symbolically evaluate in case you want to test running a graph bigger than your gpu
     if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
         allow_non_fake_inputs=True,
         shape_env=shape_env,
     ):
-        # # now let's run it
-        out = parallel_mod(*x)
-        out.backward(torch.randn_like(out))
+        # now let's run it
+        for x in microbatches:
+            out = parallel_mod(x)
+            out.backward(torch.ones_like(out))
     else:
-        out = parallel_mod(*x)
-        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        for i, x in enumerate(microbatches):
+            assert x.shape[0] == 2
+            out = parallel_mod(x)
+            assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+            out.backward(torch.ones_like(out))
+            if rng_seed is not None:
+                numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")
 
         if rng_seed is not None:
-            numerics_logger.log_forward_output(out)
-            out.backward(torch.randn_like(out))
+            for k, v in parallel_mod.named_parameters():
+                numerics_logger.log_diff(v.grad, prefix=f"grad {k}")
 
     print("All good!")

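The example now feeds the 16 microbatches one at a time and lets autograd accumulate into .grad, using ones_like instead of randn_like so the backward signal is deterministic across runs. The accumulation pattern in isolation, as a self-contained sketch with a toy model; the shapes and the Linear module are placeholders, not from this diff.

    import torch

    local_batch_size, n_microbatches = 2, 16
    model = torch.nn.Linear(8, 8)
    full_batch = torch.randn(local_batch_size * n_microbatches, 8)

    microbatches = torch.split(full_batch, local_batch_size, dim=0)
    assert len(microbatches) == n_microbatches
    for x in microbatches:
        out = model(x)
        out.backward(torch.ones_like(out))  # grads sum into p.grad across microbatches

    for name, p in model.named_parameters():
        print(name, p.grad.norm())
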
examples/example_ds3_pp.py

Lines changed: 65 additions & 7 deletions
@@ -90,6 +90,7 @@ def build_pipeline_schedule(
     local_batch_size: int,
     pipeline_parallel_degree: int,
     backward_requires_autograd: bool = False,
+    scale_grads: bool = True,
 ) -> _PipelineSchedule:
     """Builds a pipeline schedule for the given configuration and stages."""
     schedule_class = get_schedule_class(pipeline_parallel_schedule)
@@ -115,6 +116,7 @@
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=scale_grads,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
@@ -192,9 +194,9 @@ def run_test(
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]
 
-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
@@ -379,6 +381,31 @@ def last_stage_inp_with_loss_fn():
     for stages in pp_rank_to_stage_indices.values():
         assert len(stages) * pp_degree == len(virtual_pp_stages)
     stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank]
+    if rng_seed:
+        # Compute the ranks to log from
+        # 1. for fw_outs, log from coord [pp_rank_containing_last_stage, 0, 0]
+        last_stage_idx = total_pp_stages - 1
+        pp_rank_containing_last_stage = None
+        for pp_rank_, stage_indices in pp_rank_to_stage_indices.items():
+            if last_stage_idx in stage_indices:
+                assert pp_rank_containing_last_stage is None
+                pp_rank_containing_last_stage = pp_rank_
+
+        log_fw_out_rank_coordinate = []
+        for mesh_dim_name in world_mesh.mesh_dim_names:
+            if mesh_dim_name == "pp":
+                log_fw_out_rank_coordinate.append(pp_rank_containing_last_stage)
+            else:
+                log_fw_out_rank_coordinate.append(0)
+        should_log_fw_outs = world_mesh.get_coordinate() == log_fw_out_rank_coordinate
+
+        # 2. for weights, log from coords [:, 0, 0]
+        pp_world_size = world_mesh.shape[world_mesh._get_mesh_dim_by_name("pp")]
+        log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)]
+        should_log_weights = (
+            tuple(world_mesh.get_coordinate()) in log_weights_rank_coordinates
+        )
+
     stage_mods: dict[int, torch.nn.Module] = {}
     stage_graphs: dict[int, GraphCallables] = {}
     stage_graph_metas: dict[int, GraphMeta] = {}
@@ -504,10 +531,6 @@ def last_stage_inp_with_loss_fn():
 
     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )
 
     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -532,6 +555,7 @@ def last_stage_inp_with_loss_fn():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -541,11 +565,37 @@ def last_stage_inp_with_loss_fn():
         local_batch_size=local_batch_size,
         pipeline_parallel_degree=pp_degree,
         backward_requires_autograd=False,
+        scale_grads=rng_seed is None,  # In determinism mode, don't scale grads
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, should_log=should_log_weights
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            # hook is only for numerics mode
+            return
+
+        if should_log_fw_outs:
+            numerics_logger.log_diff(
+                output,
+                rank=torch.distributed.get_rank(),
+                prefix=f"mb{action.microbatch_index} fwd out",
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
@@ -576,6 +626,10 @@ def last_stage_inp_with_loss_fn():
         )
         if pp_rank == 0:
             x = runtime_input_fn_first_stage()
+            if rng_seed:
+                numerics_logger.log_diff(
+                    x.to(torch.float32), prefix="full batch input"
+                )
             graph_pp_runner.step(
                 x, target=target, losses=losses, return_outputs=False
             )
@@ -590,6 +644,10 @@ def last_stage_inp_with_loss_fn():
             payload_fn=lambda: f"losses: {losses}",
         )
 
+    numerics_logger.log_pp_grads(
+        model, stage_mods, num_world_stages, should_log=should_log_weights
+    )
+
     print("All good!")
 
     if torch.distributed.is_initialized():

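In the pipeline-parallel example, forward outputs exist only on the rank that owns the last stage, while weights and grads are logged from one rank per pp rank, so the two cases use different logging coordinates. The selection logic in isolation, as a sketch assuming a 3-D mesh ordered ["pp", "dp_mod_ep", "ep"] and a placeholder stage mapping; none of the literal values below come from this diff. Disabling scale_grads when rng_seed is set is presumably so the accumulated grads can be compared directly against the unscaled per-microbatch accumulation in examples/example_ds3_local_map.py.

    # Placeholder: two pp ranks, interleaved stages 0..3.
    pp_rank_to_stage_indices = {0: [0, 2], 1: [1, 3]}
    total_pp_stages = 4

    last_stage_idx = total_pp_stages - 1
    pp_rank_containing_last_stage = next(
        r for r, idxs in pp_rank_to_stage_indices.items() if last_stage_idx in idxs
    )
    # fwd outputs: only the rank at coordinate [pp_rank_containing_last_stage, 0, 0] logs.
    log_fw_out_rank_coordinate = [pp_rank_containing_last_stage, 0, 0]
    # weights/grads: the ranks at coordinates [i, 0, 0] for every pp rank i log.
    log_weights_rank_coordinates = [(i, 0, 0) for i in range(len(pp_rank_to_stage_indices))]
    print(log_fw_out_rank_coordinate, log_weights_rank_coordinates)
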