76 changes: 74 additions & 2 deletions tests/pytorch/debug/test_log.py
@@ -14,7 +14,7 @@
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState

import math

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -150,7 +150,7 @@ def test_sanity(feature_dirs):


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
@@ -206,6 +206,78 @@ def test_numerics(fp8_recipe, feature_dirs):
assert overflows == pytest.approx(expected.cpu(), abs=1e-4)


LOG_HIGH_PRECISION_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
stats:
- dynamic_range
- max_blockwise_dynamic_range:
block_size: 4
dims: 1
- max_blockwise_dynamic_range:
block_size: 4
dims: 2
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""


def test_log_stats_numerics(feature_dirs):
"""Check corectness of dynamic range and max blockwise dynamic range stats"""
stats = ["dynamic_range", "max_blockwise_4_dynamic_range"]
log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG_BASE.format(stats=", ".join(stats))

with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
# The tensor is 1024 x 1024, filled with a very small epsilon in almost all elements,
# plus one row of a large value A and three rows of a large value B.
epsilon = 1e-10
A = 1000
B = 50
tensor = torch.zeros(1024, 1024).cuda() + epsilon
tensor[0, :] = A
tensor[1:4, :] = B

debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=None,
rowwise_quantized_tensor=None,
columnwise_quantized_tensor=None,
)
debug_api.step()

output = read_log(log_dir)

for line in output.splitlines():
if "max_blockwise_dynamic_range_block_size_4_dims_1" in line:
max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
expected = 0
assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
expected, abs=1e-4
)
elif "max_blockwise_dynamic_range_block_size_4_dims_2" in line:
max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
expected = math.log2(A) - math.log2(B)
assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
expected, abs=1e-4
)
elif "dynamic_range" in line:
dynamic_range = float(line.split("value=")[1])
expected = math.log2(A) - math.log2(epsilon)
assert dynamic_range == pytest.approx(expected, abs=1e-4)


@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
if not fp8_available:
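For reference, the expected values asserted in test_log_stats_numerics above follow directly from the tensor layout; a minimal standalone sketch of that arithmetic, using only the constants defined in the test:

import math
import torch

epsilon, A, B = 1e-10, 1000, 50
tensor = torch.full((1024, 1024), epsilon)
tensor[0, :] = A       # one row of the large value A
tensor[1:4, :] = B     # three rows of the large value B

# dims=1, block_size=4: each block holds 4 consecutive elements of a single row,
# so every block is constant and log2(amax) - log2(nonzero_amin) is 0 for all blocks.
expected_dims_1 = 0.0

# dims=2, block_size=4: the 4 x 4 tiles covering rows 0-3 contain both A and B,
# so the largest blockwise range is log2(A) - log2(B), roughly 4.322.
expected_dims_2 = math.log2(A) - math.log2(B)

# Plain dynamic_range spans the global amax (A) and the smallest nonzero value (epsilon).
expected_dynamic_range = math.log2(A) - math.log2(epsilon)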
65 changes: 60 additions & 5 deletions transformer_engine/debug/features/log_tensor_stats.py
@@ -4,7 +4,7 @@

"""LogTensorStats Feature support for nvidia-dlframework-inspect"""

from typing import Dict, Optional
from typing import Dict, Optional, List

import torch

@@ -19,6 +19,9 @@
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
from transformer_engine.debug.features.utils.stats_computation import (
add_max_blockwise_dynamic_range_stats,
)


@Registry.register_feature(namespace="transformer_engine")
@@ -44,7 +47,12 @@ class LogTensorStats(BaseLogTensorStats):
- l1_norm
- l2_norm
- cur_amax – maximal absolute value of a tensor,
- dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
- dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)`
- max_blockwise_dynamic_range – maximum of `log2(amax) - log2(nonzero_amin)` taken over all blocks of the tensor. For `dims=1` a block is `block_size` consecutive elements; for `dims=2` it is a `block_size` x `block_size` tile.

- block_size: int, default = 32
- dims: int, default = 1, allowed values are 1 and 2

tensors/tensors_struct: List[str]
list of tensors to log

@@ -88,6 +96,51 @@ class LogTensorStats(BaseLogTensorStats):
stats: [dynamic_range]
"""

def _is_supported_stat(self, stat: str | Dict):
"""Returns True if the stat is supported by this feature, False otherwise."""
if isinstance(stat, dict):
stat_name = list(stat.keys())[0]
if stat_name == "max_blockwise_dynamic_range":
stat_dict = stat[stat_name]
if not isinstance(stat_dict, dict):
return False
# Ensure only supported keys are present
allowed_keys = {"block_size", "dims"}
if any(k not in allowed_keys for k in stat_dict.keys()):
return False
block_size = stat_dict.get("block_size", 32)
dims = stat_dict.get("dims", 1)
# Type and value validation
if not isinstance(block_size, int) or not isinstance(dims, int):
return False
if block_size > 0 and dims in [1, 2]:
return True
return False
return stat in BaseLogTensorStats._get_supported_stats_list(None) | {
"cur_amax",
"dynamic_range",
}

def _parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]) -> List[str]:
"""
Registers every max_blockwise_dynamic_range stat with the stat computation logic
and flattens its dict form into a string name; other stats are passed through unchanged.
For example, [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}]
becomes ["max_blockwise_dynamic_range_block_size_32_dims_1"].
"""
parsed_stats = []
for stat in stats:
if isinstance(stat, dict):
block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32)
dims = stat["max_blockwise_dynamic_range"].get("dims", 1)
add_max_blockwise_dynamic_range_stats(block_size, dims)
parsed_stats.append(
f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}"
)
else:
parsed_stats.append(stat)
return parsed_stats

def _get_supported_stats_list(self):
"""Returns stats this feature can log."""
return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"}
@@ -141,14 +194,16 @@ def inspect_tensor(
)

for stat in config["stats"]:
assert (
stat in self._get_supported_stats_list()
assert self._is_supported_stat(
stat
), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."

stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"])

STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=config["stats"],
stats=stats,
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
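The mapping from a dict-form stat in the config to the flat name that appears in the log is the flattening done by _parse_max_blockwise_dynamic_range_stats above; a minimal standalone sketch of that mapping (flatten_stat is a hypothetical helper introduced here only for illustration, using the defaults block_size=32 and dims=1 from the code above):

def flatten_stat(stat):
    # Dict-form stats carry the parameters; string stats are already flat names.
    if isinstance(stat, dict):
        params = stat["max_blockwise_dynamic_range"]
        block_size = params.get("block_size", 32)
        dims = params.get("dims", 1)
        return f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}"
    return stat

# The dict entries from the test config map to the names asserted in the log file:
assert flatten_stat({"max_blockwise_dynamic_range": {"block_size": 4, "dims": 1}}) == (
    "max_blockwise_dynamic_range_block_size_4_dims_1"
)
assert flatten_stat("dynamic_range") == "dynamic_range"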
67 changes: 67 additions & 0 deletions transformer_engine/debug/features/utils/stats_computation.py
@@ -26,6 +26,7 @@ def _compute_dynamic_range_top(tensor):
return torch.log2(amax)


@torch.compile
def _compute_dynamic_range_bottom(tensor):
"""Computes the log2 of the amin of the tensor"""
tensor_abs = tensor.abs()
@@ -37,6 +38,54 @@ def _compute_dynamic_range_bottom(tensor):
return torch.log2(amin)


@torch.compile
def compute_max_blockwise_dynamic_range(tensor, block_size, dims):
"""
Maximum blockwise dynamic range, i.e. log2(amax) - log2(nonzero_amin) per block.
Blocks that are entirely zero contribute 0, so the result is 0 if the whole tensor is zero;
otherwise the maximum is taken over the per-block values.

For dims = 1 blocks contain block_size consecutive elements,
for dims = 2 blocks contain block_size x block_size elements.
"""
total_numel = tensor.numel()
assert dims in [1, 2], f"dims must be 1 or 2, got {dims}"

# torch.compile friendly code - standard ** power does not work with jit
total_block_size = block_size * block_size if dims == 2 else block_size
assert (
total_numel % total_block_size == 0
), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})."

tensor = tensor.abs().float()
if dims == 1:
tensor = tensor.reshape(-1, block_size)
per_block_amax = tensor.amax(dim=1)
per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1)
else:
# We want to have tensor of shape [nr_blocks, block_size, block_size],
# where each block is a block_size x block_size tile of the original tensor.
dim_x = tensor.shape[-2] // block_size
dim_y = tensor.shape[-1] // block_size
tensor = (
tensor.reshape(-1, dim_x, block_size, dim_y, block_size)
.permute(0, 1, 3, 2, 4)
.reshape(-1, block_size, block_size)
)
per_block_amax = tensor.amax(dim=(1, 2))
per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2))

# Identify blocks that contain any non-zero element
nonzero_blocks = per_block_amax != 0
dynamic_range_per_block = torch.where(
nonzero_blocks,
torch.log2(per_block_amax) - torch.log2(per_block_amin),
torch.zeros_like(per_block_amax, dtype=torch.float32),
)
return dynamic_range_per_block.max()


@torch.compile
def compute_variance(variances, numels, sums):
"""Welford algorithm is used for numerically stable distributed variance computation."""
mean = torch.sum(sums) / torch.sum(numels)
@@ -45,6 +94,7 @@ def compute_variance(variances, numels, sums):
return var


@torch.compile
def compute_std(variances, numels, sums):
"""Computates standard deviation."""
return torch.sqrt(compute_variance(variances, numels, sums))
@@ -316,6 +366,23 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False):
DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"}


def add_max_blockwise_dynamic_range_stats(block_size: int, dims: int):
"""Register max_blockwise_X_dynamic_range stats for the recipe."""
stat_name = f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}"
if stat_name in stats_to_num:
return # already registered
assert dims in [1, 2], f"dims must be 1 or 2, got {dims}"
stats_to_num[stat_name] = len(stats_to_num)
DEPENDENCIES[stat_name] = {stat_name}

STATS[stat_name] = (
lambda x, aux_dict, _block_size=block_size, _dims=dims: compute_max_blockwise_dynamic_range(
x, _block_size, _dims
),
lambda buffers: max(_get(buffers, stat_name)),
)


for _columnwise in [True, False]:
for _recipe_name in [
"", # default recipe
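A minimal usage sketch of the new compute_max_blockwise_dynamic_range helper on a tiny 4 x 4 tensor (the values are chosen here only for illustration; the import path matches the file edited above):

import torch
from transformer_engine.debug.features.utils.stats_computation import (
    compute_max_blockwise_dynamic_range,
    add_max_blockwise_dynamic_range_stats,
)

x = torch.tensor([[1.0, 2.0, 4.0, 8.0],
                  [0.0, 0.0, 0.0, 0.0],
                  [16.0, 16.0, 16.0, 16.0],
                  [1.0, 1.0, 1.0, 1.0]])

# dims=1: rows are split into length-4 blocks; the first row spans 1..8, so the
# maximum blockwise range is log2(8) - log2(1) = 3 (the all-zero row contributes 0).
print(compute_max_blockwise_dynamic_range(x, block_size=4, dims=1))  # tensor(3.)

# dims=2: the single 4 x 4 tile mixes 16 and 1 (zeros are masked out of the min),
# so the range is log2(16) - log2(1) = 4.
print(compute_max_blockwise_dynamic_range(x, block_size=4, dims=2))  # tensor(4.)

# Registering the stat (as _parse_max_blockwise_dynamic_range_stats does) exposes
# "max_blockwise_dynamic_range_block_size_4_dims_2" to LogTensorStats.
add_max_blockwise_dynamic_range_stats(block_size=4, dims=2)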