17 changes: 15 additions & 2 deletions autoparallel/api.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import copy
import functools
import itertools
import warnings
from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
)
from .init_weights import hook_params_setters
from .optimize_sharding import ShardingOptimizer
from .utils import _get_device_from_mesh
from .utils import (
NumericsLogger,
_get_device_from_mesh,
debug_boxed_nop_preserve_node_meta,
)

_APPLY_VIEW_MM_VIEW_PATTERN = False

@@ -212,6 +217,7 @@ def __init__(
reshard_after_forward: bool = True,
dynamic: bool = False,
loss_fn: Optional[Callable] = None,
numerics_logger: NumericsLogger | None = None,
**kwargs,
):
self.stack = ExitStack()
@@ -239,7 +245,14 @@ def __init__(
self.model = move_to_fake(model, self.fake_mode, device)
self.input_fn = input_fn
self.mesh = mesh
self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
if compile:
self.compiler_fn = compile_fx_inner
elif numerics_logger:
self.compiler_fn = functools.partial(
debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
)
else:
self.compiler_fn = boxed_nop_preserve_node_meta # type: ignore[assignment]
self.enable_ac = enable_ac
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
self.reshard_after_forward = reshard_after_forward
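
A minimal sketch of how the new numerics_logger keyword is meant to be wired up. It assumes model, input_fn, and mesh are built as in examples/example_ds3_local_map.py below, and the log directory is a placeholder; note the example itself deliberately passes numerics_logger=None because the per-op logs are verbose (see the review thread further down). With compile=False, passing a logger swaps the no-op compiler for the debug interpreter added in autoparallel/utils.py:

    from autoparallel.api import AutoParallel
    from autoparallel.utils import NumericsLogger

    # Sketch only: model, input_fn, and mesh come from the example file below;
    # "/tmp/autop_numerics" is a placeholder log directory.
    numerics_logger = NumericsLogger("/tmp/autop_numerics")
    with AutoParallel(
        model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        # ... rest of the AutoParallel flow as in the example file below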
17 changes: 16 additions & 1 deletion autoparallel/graph_pp_runner.py
@@ -22,7 +22,7 @@
)
from torch.distributed.tensor import DTensor

from autoparallel.utils import DebugInterpreter
from autoparallel.utils import DebugInterpreter, NumericsLogger

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -89,6 +89,8 @@ def __init__(
output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
group: Optional[torch.distributed.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
numerics_logger: Optional[NumericsLogger] = None,
should_log_fw_outs: bool = False,
):
super().__init__(
submodule=submodule,
@@ -100,6 +102,8 @@
group=group,
dw_builder=dw_builder,
)
self.numerics_logger = numerics_logger
self.should_log_fw_outs = should_log_fw_outs
self.graph_callables = graph_callables
self.graph_meta = graph_meta
self.state: dict[str, list[Any]] = {
@@ -376,6 +380,7 @@ def _prepare_fwd_args(


def _post_fwd_common(
action: _Action,
stage: GraphPipelineStage,
mb_index: int,
output: Any,
@@ -411,6 +416,14 @@ def _post_fwd_common(
if ctx.target_mbs is not None:
ctx.schedule_ref._internal_losses.append(output)

if stage.should_log_fw_outs:
assert stage.numerics_logger is not None
stage.numerics_logger.log_diff(
output,
rank=torch.distributed.get_rank(),
prefix=f"mb{action.microbatch_index} fwd out",
)

stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates) # type: ignore[assignment]

stage._validate_fwd_outputs(output_tuple)
@@ -450,6 +463,7 @@ def stage_forward(
)

_post_fwd_common(
action,
stage,
mb_index,
output,
@@ -793,6 +807,7 @@ def overlap_fw_bw(
bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)

_post_fwd_common(
action,
fw_stage,
fw_mb_index,
output,
58 changes: 54 additions & 4 deletions autoparallel/utils.py
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
continue

self._logs.append(
f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
)

def run_node(self, n: torch.fx.Node) -> Any:
Expand Down Expand Up @@ -429,13 +429,27 @@ def log_model_weights(self, parallel_mod):

print(f"Weight hashes written to {path}")

def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
def log_fw_intermediates(self, logs):
rank = torch.distributed.get_rank()
path = self.dir / f"rank_{rank}_fw_intermediates.log"
with open(path, "a") as f:
f.write("\n".join(logs) + "\n")

def log_diff(self, t, rank=0, prefix="?"):
if self.rank == rank:
path = self.dir / "diff.log"
if isinstance(t, torch.distributed.tensor.DTensor):
t = t.to_local()
with open(path, "a") as f:
f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")

def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_log):
path = self.dir / "pp_weights.log"

torch.distributed.barrier()
# First print the params of every stage
for i in range(num_world_stages):
if self.rank in ranks and i in stage_mods:
if should_log and i in stage_mods:
param_logs = []
real_params = dict(stage_mods[i].named_parameters())
for name, _ in orig_mod.named_parameters():
@@ -449,7 +463,7 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):

# Then print the buffers of every stage
for i in range(num_world_stages):
if self.rank in ranks and i in stage_mods:
if should_log and i in stage_mods:
buffer_logs = []
real_buffers = dict(stage_mods[i].named_buffers())
for name, _ in orig_mod.named_buffers():
@@ -463,3 +477,39 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):

if self.rank == 0:
print(f"Weight hashes written to {path}")

def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log):
path = self.dir / "diff.log"

for i in range(num_world_stages):
if should_log and i in stage_mods:
grad_logs = []
real_params = dict(stage_mods[i].named_parameters())
for name, _ in orig_mod.named_parameters():
if name not in real_params:
continue
grad = real_params[name].grad
if grad is None:
grad_logs.append(f"[grad {name}] None")
else:
grad = grad.to_local()
grad_logs.append(
f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
)
with open(path, "a") as f:
f.write("\n".join(grad_logs) + "\n")
torch.distributed.barrier()


def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
def run(args):
with torch.fx.traceback.preserve_node_meta():
interp = DebugInterpreter(fx_g)
out = interp.boxed_run(args)
mylogs = interp.get_logs()
if numerics_logger:
numerics_logger.log_fw_intermediates(mylogs)
return out

run._boxed_call = True
return run
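
A minimal sketch of how this debug compiler is bound to a logger, mirroring the AutoParallel.__init__ change in autoparallel/api.py above (the log directory is a placeholder):

    import functools

    from autoparallel.utils import NumericsLogger, debug_boxed_nop_preserve_node_meta

    # Placeholder directory; the logger appends rank_<rank>_fw_intermediates.log under it.
    numerics_logger = NumericsLogger("/tmp/autop_numerics")

    # functools.partial binds the logger so the result has the (fx_g, example_inputs)
    # signature expected of a compiler_fn.
    compiler_fn = functools.partial(
        debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
    )
    # compiler_fn(fx_g, example_inputs) returns a boxed callable that runs fx_g through
    # DebugInterpreter and writes per-node hash/NaN lines via log_fw_intermediates.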
56 changes: 37 additions & 19 deletions examples/example_ds3_local_map.py
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
mscale=0.70,
)

bs = 4 * mesh.shape[0] * mesh.shape[1]
local_batch_size = 2
global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
device = torch.device(f"cuda:{local_rank}")

# parallelize the model
Expand All @@ -129,11 +130,16 @@ def input_fn():
return torch.randint(
0,
config.vocab_size,
(bs, seq_len),
(global_batch_size, seq_len),
device=device,
)

with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
numerics_logger = None
if rng_seed is not None:
numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(
model, input_fn, mesh, dynamic=True, numerics_logger=None

[Review comment] Contributor: should this be numerics_logger=numerics_logger?

[Reply] Member (Author): It's too noisy; it logs the intermediates for each op in the graph. I haven't thought of how to address it yet.

) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

# x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
# ) # maybe not correct value
parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
if rng_seed is not None:
numerics_logger = NumericsLogger(logs_dir)
numerics_logger.log_model_weights(parallel_mod)

x = (
torch.randint(
0,
config.vocab_size,
(bs // mesh.shape[0] // mesh.shape[1], seq_len),
device=device,
),
torch.manual_seed(rng_seed)

n_microbatches = 16
full_batch = torch.randint(
0,
config.vocab_size,
(local_batch_size * n_microbatches, seq_len),
device=device,
)
microbatches = torch.split(full_batch, local_batch_size, dim=0)
assert len(microbatches) == n_microbatches
if rng_seed:
numerics_logger.log_diff(
full_batch.to(torch.float32), prefix="full batch input"
)

# Symbolically evaluate in case you want to test running a graph bigger than your gpu
if fake_evaluate:
@@ -173,15 +184,22 @@
allow_non_fake_inputs=True,
shape_env=shape_env,
):
# # now let's run it
out = parallel_mod(*x)
out.backward(torch.randn_like(out))
# now let's run it
for x in microbatches:
out = parallel_mod(x)
out.backward(torch.ones_like(out))
else:
out = parallel_mod(*x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
for i, x in enumerate(microbatches):
assert x.shape[0] == 2
out = parallel_mod(x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
out.backward(torch.ones_like(out))
if rng_seed is not None:
numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")

if rng_seed is not None:
numerics_logger.log_forward_output(out)
out.backward(torch.randn_like(out))
for k, v in parallel_mod.named_parameters():
numerics_logger.log_diff(v.grad, prefix=f"grad {k}")

print("All good!")

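
For reference, the microbatching above just splits one deterministic batch along dim 0; a standalone illustration with toy sizes (not the example's real config):

    import torch

    torch.manual_seed(42)  # arbitrary seed for the illustration
    local_batch_size, n_microbatches, seq_len, vocab_size = 2, 16, 8, 100
    full_batch = torch.randint(
        0, vocab_size, (local_batch_size * n_microbatches, seq_len)
    )
    microbatches = torch.split(full_batch, local_batch_size, dim=0)
    assert len(microbatches) == n_microbatches  # 32 rows -> 16 chunks of 2
    assert all(mb.shape == (local_batch_size, seq_len) for mb in microbatches)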