
Commit 10d8208

Compare microbatch forward outputs and gradients (#246)
stack-info: PR: #246, branch: xmfan/stack/20
1 parent 06d5f49 commit 10d8208

File tree

5 files changed: +171 -33 lines changed


autoparallel/api.py

Lines changed: 15 additions & 2 deletions

```diff
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import functools
 import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
 )
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .utils import _get_device_from_mesh
+from .utils import (
+    NumericsLogger,
+    _get_device_from_mesh,
+    debug_boxed_nop_preserve_node_meta,
+)
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
@@ -212,6 +217,7 @@ def __init__(
         reshard_after_forward: bool = True,
         dynamic: bool = False,
         loss_fn: Optional[Callable] = None,
+        numerics_logger: NumericsLogger | None = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -239,7 +245,14 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
-        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
+        if compile:
+            self.compiler_fn = compile_fx_inner
+        elif numerics_logger:
+            self.compiler_fn = functools.partial(
+                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
+            )
+        else:
+            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
```
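For context on how this hook is meant to be used: passing a NumericsLogger (with compilation left off) routes graph execution through debug_boxed_nop_preserve_node_meta, so every node's output hash is written to the per-rank intermediates log. A minimal, hypothetical sketch of opting in; model, input_fn, mesh, and the logs directory are placeholders supplied by the surrounding script, as in examples/example_ds3_local_map.py:

```python
from autoparallel.api import AutoParallel
from autoparallel.utils import NumericsLogger

# Placeholder inputs: model, input_fn, and mesh are built elsewhere
# (see examples/example_ds3_local_map.py); "./numerics_logs" is an assumed path.
numerics_logger = NumericsLogger("./numerics_logs")

with AutoParallel(
    model,
    input_fn,
    mesh,
    dynamic=True,
    numerics_logger=numerics_logger,  # selects the DebugInterpreter-backed compiler_fn
) as autop:
    autop.add_parameter_memory_constraint(low=None, high=None)
    ...  # placement optimization and apply_placement continue as in the example
```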

autoparallel/graph_pp_runner.py

Lines changed: 16 additions & 1 deletion

```diff
@@ -22,7 +22,7 @@
 )
 from torch.distributed.tensor import DTensor
 
-from autoparallel.utils import DebugInterpreter
+from autoparallel.utils import DebugInterpreter, NumericsLogger
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -89,6 +89,8 @@ def __init__(
         output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
         group: Optional[torch.distributed.ProcessGroup] = None,
         dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
+        numerics_logger: Optional[NumericsLogger] = None,
+        should_log_fw_outs: bool = False,
     ):
         super().__init__(
             submodule=submodule,
@@ -100,6 +102,8 @@ def __init__(
             group=group,
             dw_builder=dw_builder,
         )
+        self.numerics_logger = numerics_logger
+        self.should_log_fw_outs = should_log_fw_outs
         self.graph_callables = graph_callables
         self.graph_meta = graph_meta
         self.state: dict[str, list[Any]] = {
@@ -376,6 +380,7 @@ def _prepare_fwd_args(
 
 
 def _post_fwd_common(
+    action: _Action,
     stage: GraphPipelineStage,
     mb_index: int,
     output: Any,
@@ -411,6 +416,14 @@ def _post_fwd_common(
     if ctx.target_mbs is not None:
         ctx.schedule_ref._internal_losses.append(output)
 
+    if stage.should_log_fw_outs:
+        assert stage.numerics_logger is not None
+        stage.numerics_logger.log_diff(
+            output,
+            rank=torch.distributed.get_rank(),
+            prefix=f"mb{action.microbatch_index} fwd out",
+        )
+
    stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates)  # type: ignore[assignment]
 
    stage._validate_fwd_outputs(output_tuple)
@@ -450,6 +463,7 @@ def stage_forward(
     )
 
     _post_fwd_common(
+        action,
         stage,
         mb_index,
         output,
@@ -793,6 +807,7 @@ def overlap_fw_bw(
     bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)
 
     _post_fwd_common(
+        action,
         fw_stage,
         fw_mb_index,
         output,
```
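The two new keyword arguments are what produce the per-microbatch "mb{i} fwd out" entries in diff.log from _post_fwd_common. A rough wiring sketch, not runnable as-is: everything except numerics_logger and should_log_fw_outs is a placeholder, and the constructor's remaining arguments (elided below) are unchanged by this commit:

```python
import torch
from autoparallel.graph_pp_runner import GraphPipelineStage
from autoparallel.utils import NumericsLogger

numerics_logger = NumericsLogger(logs_dir)  # logs_dir: assumed output directory

stage = GraphPipelineStage(
    submodule=stage_mod,              # placeholder: this rank's stage module
    graph_callables=graph_callables,  # placeholder: per-stage compiled graphs
    graph_meta=graph_meta,            # placeholder
    # ... remaining stage arguments elided ...
    numerics_logger=numerics_logger,
    # Assumption: log forward outputs on a single rank so the resulting diff.log
    # can be compared line-for-line against a non-pipelined reference run.
    should_log_fw_outs=(torch.distributed.get_rank() == 0),
)
```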

autoparallel/utils.py

Lines changed: 54 additions & 4 deletions

```diff
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
 
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )
 
     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,13 +429,27 @@ def log_model_weights(self, parallel_mod):
 
         print(f"Weight hashes written to {path}")
 
-    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_diff(self, t, rank=0, prefix="?"):
+        if self.rank == rank:
+            path = self.dir / "diff.log"
+            if isinstance(t, torch.distributed.tensor.DTensor):
+                t = t.to_local()
+            with open(path, "a") as f:
+                f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")
+
+    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_log):
         path = self.dir / "pp_weights.log"
 
         torch.distributed.barrier()
         # First print the params of every stage
         for i in range(num_world_stages):
-            if self.rank in ranks and i in stage_mods:
+            if should_log and i in stage_mods:
                 param_logs = []
                 real_params = dict(stage_mods[i].named_parameters())
                 for name, _ in orig_mod.named_parameters():
@@ -449,7 +463,7 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         # Then print the buffers of every stage
         for i in range(num_world_stages):
-            if self.rank in ranks and i in stage_mods:
+            if should_log and i in stage_mods:
                 buffer_logs = []
                 real_buffers = dict(stage_mods[i].named_buffers())
                 for name, _ in orig_mod.named_buffers():
@@ -463,3 +477,39 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         if self.rank == 0:
             print(f"Weight hashes written to {path}")
+
+    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log):
+        path = self.dir / "diff.log"
+
+        for i in range(num_world_stages):
+            if should_log and i in stage_mods:
+                grad_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    grad = real_params[name].grad
+                    if grad is None:
+                        grad_logs.append(f"[grad {name}] None")
+                    else:
+                        grad = grad.to_local()
+                        grad_logs.append(
+                            f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
+                        )
+                with open(path, "a") as f:
+                    f.write("\n".join(grad_logs) + "\n")
+            torch.distributed.barrier()
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
+
+    run._boxed_call = True
+    return run
```
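NumericsLogger.log_diff appends one "[prefix] hash=..., norm=..." line per call to diff.log, so two runs that should be numerically identical (for example the single-device example run versus a pipeline-parallel run with the same rng seed) can be checked by diffing those files. A small standalone helper sketch, not part of this commit; the run_a/run_b paths are assumptions:

```python
from pathlib import Path


def compare_diff_logs(ref_path: str, test_path: str) -> list[str]:
    """Return mismatching lines between two diff.log files written by log_diff."""
    ref_lines = Path(ref_path).read_text().splitlines()
    test_lines = Path(test_path).read_text().splitlines()
    mismatches = [
        f"ref:  {a}\ntest: {b}"
        for a, b in zip(ref_lines, test_lines)
        if a != b
    ]
    if len(ref_lines) != len(test_lines):
        mismatches.append(f"line count differs: {len(ref_lines)} vs {len(test_lines)}")
    return mismatches


if __name__ == "__main__":
    for mismatch in compare_diff_logs("run_a/diff.log", "run_b/diff.log"):
        print(mismatch)
```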

examples/example_ds3_local_map.py

Lines changed: 37 additions & 19 deletions

```diff
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
         mscale=0.70,
     )
 
-    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    local_batch_size = 2
+    global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
     device = torch.device(f"cuda:{local_rank}")
 
     # parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
         return torch.randint(
             0,
             config.vocab_size,
-            (bs, seq_len),
+            (global_batch_size, seq_len),
             device=device,
         )
 
-    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+    numerics_logger = None
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+    with AutoParallel(
+        model, input_fn, mesh, dynamic=True, numerics_logger=None
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
     # ) # maybe not correct value
     parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
     if rng_seed is not None:
-        numerics_logger = NumericsLogger(logs_dir)
         numerics_logger.log_model_weights(parallel_mod)
-
-    x = (
-        torch.randint(
-            0,
-            config.vocab_size,
-            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
-            device=device,
-        ),
+        torch.manual_seed(rng_seed)
+
+    n_microbatches = 16
+    full_batch = torch.randint(
+        0,
+        config.vocab_size,
+        (local_batch_size * n_microbatches, seq_len),
+        device=device,
     )
+    microbatches = torch.split(full_batch, local_batch_size, dim=0)
+    assert len(microbatches) == n_microbatches
+    if rng_seed:
+        numerics_logger.log_diff(
+            full_batch.to(torch.float32), prefix="full batch input"
+        )
 
     # Symbolically evaluate in case you want to test running a graph bigger than your gpu
     if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
             allow_non_fake_inputs=True,
             shape_env=shape_env,
         ):
-            # # now let's run it
-            out = parallel_mod(*x)
-            out.backward(torch.randn_like(out))
+            # now let's run it
+            for x in microbatches:
+                out = parallel_mod(x)
+                out.backward(torch.ones_like(out))
     else:
-        out = parallel_mod(*x)
-        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        for i, x in enumerate(microbatches):
+            assert x.shape[0] == 2
+            out = parallel_mod(x)
+            assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+            out.backward(torch.ones_like(out))
+            if rng_seed is not None:
+                numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")
+
     if rng_seed is not None:
-        numerics_logger.log_forward_output(out)
-        out.backward(torch.randn_like(out))
+        for k, v in parallel_mod.named_parameters():
+            numerics_logger.log_diff(v.grad, prefix=f"grad {k}")
 
     print("All good!")
 
```
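The reworked example boils down to a standard gradient-accumulation loop over fixed-size microbatches. A self-contained toy version of that pattern (the linear model, sizes, and seed here are illustrative only, not the DeepSeek config used by the example):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
local_batch_size, n_microbatches = 2, 4

# Build one local batch and split it into equally sized microbatches.
full_batch = torch.randn(local_batch_size * n_microbatches, 8)
microbatches = torch.split(full_batch, local_batch_size, dim=0)
assert len(microbatches) == n_microbatches

for i, x in enumerate(microbatches):
    out = model(x)
    assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
    # Backpropagate ones (as the example now does) so results are deterministic.
    out.backward(torch.ones_like(out))
    print(f"mb{i} fwd out norm = {out.norm():.4f}")

# .grad has accumulated over all microbatches, mirroring what log_diff records per parameter.
for name, p in model.named_parameters():
    print(f"grad {name}: norm = {p.grad.norm():.4f}")
```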
