diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py
index dcc9861c84..b79c94455b 100644
--- a/tests/pytorch/debug/test_log.py
+++ b/tests/pytorch/debug/test_log.py
@@ -14,7 +14,7 @@ import os
 
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
-
+import math
 
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -150,7 +150,7 @@ def test_sanity(feature_dirs):
 
 
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-def test_numerics(fp8_recipe, feature_dirs):
+def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
@@ -206,6 +206,78 @@ def test_numerics(fp8_recipe, feature_dirs):
         assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
 
 
+LOG_HIGH_PRECISION_CONFIG_BASE = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      stats:
+        - dynamic_range
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 1
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 2
+      tensors: [activation, gradient, weight]
+      freq: 2
+      start_step: 0
+      end_step: 10
+"""
+
+
+def test_log_stats_numerics(feature_dirs):
+    """Check correctness of the dynamic_range and max_blockwise_dynamic_range stats."""
+    stats = ["dynamic_range", "max_blockwise_4_dynamic_range"]
+    log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG_BASE.format(stats=", ".join(stats))
+
+    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
+        # The tensor is 1024 x 1024, filled with a very small value epsilon almost everywhere,
+        # except for one row of large value A and three rows of large value B.
+        epsilon = 1e-10
+        A = 1000
+        B = 50
+        tensor = torch.zeros(1024, 1024).cuda() + epsilon
+        tensor[0, :] = A
+        tensor[1:4, :] = B
+
+        debug_api.transformer_engine.inspect_tensor(
+            layer_name="layer_name",
+            tensor_name="activation",
+            iteration=0,
+            tp_group=None,
+            tensor=tensor,
+            quantizer=None,
+            rowwise_quantized_tensor=None,
+            columnwise_quantized_tensor=None,
+        )
+        debug_api.step()
+
+        output = read_log(log_dir)
+
+        for line in output.splitlines():
+            if "max_blockwise_dynamic_range_block_size_4_dims_1" in line:
+                max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
+                expected = 0
+                assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
+                    expected, abs=1e-4
+                )
+            elif "max_blockwise_dynamic_range_block_size_4_dims_2" in line:
+                max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
+                expected = math.log2(A) - math.log2(B)
+                assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
+                    expected, abs=1e-4
+                )
+            elif "dynamic_range" in line:
+                dynamic_range = float(line.split("value=")[1])
+                expected = math.log2(A) - math.log2(epsilon)
+                assert dynamic_range == pytest.approx(expected, abs=1e-4)
+
+
 @pytest.mark.parametrize("layer", ["linear", "transformer"])
 def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     if not fp8_available:
diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py
index 7ba2f9f771..62a4dae41b 100644
--- a/transformer_engine/debug/features/log_tensor_stats.py
+++ b/transformer_engine/debug/features/log_tensor_stats.py
@@ -4,7 +4,7 @@
 
 """LogTensorStats Feature support for nvidia-dlframework-inspect"""
 
-from typing import Dict, Optional
+from typing import Dict, Optional, List
 
 import torch
 
@@ -19,6 +19,9 @@ from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
 from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
+from transformer_engine.debug.features.utils.stats_computation import (
+    add_max_blockwise_dynamic_range_stats,
+)
 
 
 @Registry.register_feature(namespace="transformer_engine")
@@ -44,7 +47,12 @@ class LogTensorStats(BaseLogTensorStats):
             - l1_norm
             - l2_norm
             - cur_amax – maximal absolute value of a tensor,
-            - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
+            - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)`
+            - max_blockwise_dynamic_range – the maximum dynamic range `log2(amax) - log2(nonzero_amin)` over all blocks of size block_size within the tensor. For `dims=1` a block consists of block_size consecutive elements; for `dims=2` it is a block_size x block_size tile.
+
+                - block_size: int, default = 32
+                - dims: int, default = 1, allowed values are 1 and 2
+
     tensors/tensors_struct: List[str]
         list of tensors to log
@@ -88,6 +96,51 @@ class LogTensorStats(BaseLogTensorStats):
             stats: [dynamic_range]
     """
 
+    def _is_supported_stat(self, stat: str | Dict):
+        """Returns True if the stat is supported by this feature, False otherwise."""
+        if isinstance(stat, dict):
+            stat_name = list(stat.keys())[0]
+            if stat_name == "max_blockwise_dynamic_range":
+                stat_dict = stat[stat_name]
+                if not isinstance(stat_dict, dict):
+                    return False
+                # Ensure only supported keys are present
+                allowed_keys = {"block_size", "dims"}
+                if any(k not in allowed_keys for k in stat_dict.keys()):
+                    return False
+                block_size = stat_dict.get("block_size", 32)
+                dims = stat_dict.get("dims", 1)
+                # Type and value validation
+                if not isinstance(block_size, int) or not isinstance(dims, int):
+                    return False
+                if block_size > 0 and dims in [1, 2]:
+                    return True
+            return False
+        return stat in BaseLogTensorStats._get_supported_stats_list(None) | {
+            "cur_amax",
+            "dynamic_range",
+        }
+
+    def _parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]) -> List[str]:
+        """
+        Registers every requested max_blockwise_dynamic_range stat with the stat computation logic
+        and replaces its Dict entry with the corresponding string name; other stats are left unchanged.
+        For example, [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}]
+        is turned into ["max_blockwise_dynamic_range_block_size_32_dims_1"].
+        """
+        parsed_stats = []
+        for stat in stats:
+            if isinstance(stat, dict):
+                block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32)
+                dims = stat["max_blockwise_dynamic_range"].get("dims", 1)
+                add_max_blockwise_dynamic_range_stats(block_size, dims)
+                parsed_stats.append(
+                    f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}"
+                )
+            else:
+                parsed_stats.append(stat)
+        return parsed_stats
+
     def _get_supported_stats_list(self):
         """Returns stats this feature can log."""
         return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"}
@@ -141,14 +194,16 @@ def inspect_tensor(
         )
 
         for stat in config["stats"]:
-            assert (
-                stat in self._get_supported_stats_list()
+            assert self._is_supported_stat(
+                stat
             ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."
 
+        stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"])
+
         STATS_BUFFERS.try_add_buffer(
             layer_name=layer_name,
             tensor_name=tensor_name,
-            stats=config["stats"],
+            stats=stats,
             options=options,
             reduction_group=reduction_group,
             reduce_within_microbatch=reduce_within_microbatch,
diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py
index 2fa6985acf..e50cafffb9 100644
--- a/transformer_engine/debug/features/utils/stats_computation.py
+++ b/transformer_engine/debug/features/utils/stats_computation.py
@@ -26,6 +26,7 @@ def _compute_dynamic_range_top(tensor):
     return torch.log2(amax)
 
 
+@torch.compile
 def _compute_dynamic_range_bottom(tensor):
     """Computes the log2 of the amin of the tensor"""
     tensor_abs = tensor.abs()
@@ -37,6 +38,54 @@
     return torch.log2(amin)
 
 
+@torch.compile
+def compute_max_blockwise_dynamic_range(tensor, block_size, dims):
+    """
+    Max blockwise dynamic range, i.e. log2(amax) - log2(nonzero amin) per block.
+    Returns 0 if the tensor contains only zeros;
+    otherwise the dynamic range is computed over the non-zero blocks.
+
+    For dims = 1 each block consists of block_size consecutive elements;
+    for dims = 2 each block is a block_size x block_size tile.
+    """
+    total_numel = tensor.numel()
+    assert dims in [1, 2], f"dims must be 1 or 2, got {dims}"
+
+    # torch.compile-friendly code - the standard ** power operator does not work with jit
+    total_block_size = block_size * block_size if dims == 2 else block_size
+    assert (
+        total_numel % total_block_size == 0
+    ), f"Tensor numel ({total_numel}) is not divisible by the total block size ({total_block_size})."
+
+    tensor = tensor.abs().float()
+    if dims == 1:
+        tensor = tensor.reshape(-1, block_size)
+        per_block_amax = tensor.amax(dim=1)
+        per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1)
+    else:
+        # We want a tensor of shape [nr_blocks, block_size, block_size],
+        # where each block is a block_size x block_size tile of the original tensor.
+        dim_x = tensor.shape[-2] // block_size
+        dim_y = tensor.shape[-1] // block_size
+        tensor = (
+            tensor.reshape(-1, dim_x, block_size, dim_y, block_size)
+            .permute(0, 1, 3, 2, 4)
+            .reshape(-1, block_size, block_size)
+        )
+        per_block_amax = tensor.amax(dim=(1, 2))
+        per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2))
+
+    # Identify blocks that contain any non-zero element
+    nonzero_blocks = per_block_amax != 0
+    dynamic_range_per_block = torch.where(
+        nonzero_blocks,
+        torch.log2(per_block_amax) - torch.log2(per_block_amin),
+        torch.zeros_like(per_block_amax, dtype=torch.float32),
+    )
+    return dynamic_range_per_block.max()
+
+
+@torch.compile
 def compute_variance(variances, numels, sums):
     """Welford algorithm is used for numerically stable distributed variance computation."""
     mean = torch.sum(sums) / torch.sum(numels)
@@ -45,6 +94,7 @@
     return var
 
 
+@torch.compile
 def compute_std(variances, numels, sums):
     """Computates standard deviation."""
     return torch.sqrt(compute_variance(variances, numels, sums))
@@ -316,6 +366,23 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False):
         DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"}
 
 
+def add_max_blockwise_dynamic_range_stats(block_size: int, dims: int):
+    """Register the max_blockwise_dynamic_range stat for the given block_size and dims."""
+    stat_name = f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}"
+    if stat_name in stats_to_num:
+        return  # already registered
+    assert dims in [1, 2], f"dims must be 1 or 2, got {dims}"
+    stats_to_num[stat_name] = len(stats_to_num)
+    DEPENDENCIES[stat_name] = {stat_name}
+
+    STATS[stat_name] = (
+        lambda x, aux_dict, _block_size=block_size, _dims=dims: compute_max_blockwise_dynamic_range(
+            x, _block_size, _dims
+        ),
+        lambda buffers: max(_get(buffers, stat_name)),
+    )
+
+
 for _columnwise in [True, False]:
     for _recipe_name in [
         "",  # default recipe
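Reviewer note: the expected values asserted in `test_log_stats_numerics` can be cross-checked with a naive, loop-based reference that tiles the tensor with `Tensor.unfold` instead of the `reshape`/`permute` used in `compute_max_blockwise_dynamic_range`. The sketch below is illustrative only and is not part of the change; the helper name `max_blockwise_dynamic_range_ref` is made up, and the tensor is a scaled-down (16 x 16) version of the one built in the test.

```python
import math
import torch


def max_blockwise_dynamic_range_ref(tensor: torch.Tensor, block_size: int, dims: int) -> float:
    """Naive reference: max over blocks of log2(amax) - log2(nonzero amin)."""
    t = tensor.abs().float()
    if dims == 1:
        # Blocks of block_size consecutive elements (row-major order).
        blocks = t.reshape(-1, block_size)
    else:
        # Blocks are block_size x block_size tiles, one tile per row after the reshape.
        blocks = (
            t.unfold(0, block_size, block_size)
            .unfold(1, block_size, block_size)
            .reshape(-1, block_size * block_size)
        )
    best = 0.0
    for block in blocks:
        nonzero = block[block != 0]
        if nonzero.numel() == 0:
            continue  # all-zero blocks contribute 0 by convention
        best = max(best, math.log2(nonzero.max().item()) - math.log2(nonzero.min().item()))
    return best


# Scaled-down version of the tensor built in test_log_stats_numerics.
epsilon, A, B = 1e-10, 1000.0, 50.0
tensor = torch.full((16, 16), epsilon)
tensor[0, :] = A
tensor[1:4, :] = B

print(max_blockwise_dynamic_range_ref(tensor, 4, 1))  # 0.0 - every 1D block is uniform
print(max_blockwise_dynamic_range_ref(tensor, 4, 2))  # log2(A) - log2(B), about 4.32
```

The whole-tensor `dynamic_range` stat for the same data is `log2(A) - log2(epsilon)`, which matches the third assertion in the test.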
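For completeness, a hypothetical usage sketch of the new registration path follows. It assumes `STATS`, `stats_to_num`, and `add_max_blockwise_dynamic_range_stats` are importable module-level names in `stats_computation`, which the diff suggests but does not show in full; the example tensor and printed value are made up for illustration.

```python
import torch

from transformer_engine.debug.features.utils.stats_computation import (
    STATS,
    stats_to_num,
    add_max_blockwise_dynamic_range_stats,
)

# A config entry {"max_blockwise_dynamic_range": {"block_size": 4, "dims": 2}} is parsed by
# LogTensorStats._parse_max_blockwise_dynamic_range_stats into the string name used below.
block_size, dims = 4, 2
add_max_blockwise_dynamic_range_stats(block_size, dims)
stat_name = f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}"
assert stat_name in stats_to_num

# The first element of STATS[stat_name] computes the per-tensor value; the aux dict is unused here.
compute_fn = STATS[stat_name][0]
x = torch.tensor([[1.0, 2.0, 4.0, 8.0]] * 4)  # a single 4 x 4 tile
print(compute_fn(x, {}))  # log2(8) - log2(1) = 3.0
```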