From f9e6ad5c756d02ce0957d7e46c5966b8071e6e97 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 29 Aug 2025 09:43:27 +0000 Subject: [PATCH 01/14] code drop Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 63 ++++++++++++++++++- .../debug/features/log_tensor_stats.py | 26 +++++++- .../debug/features/utils/stats_computation.py | 30 ++++++++- 3 files changed, 114 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index ca8e10ad69..45065cc664 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -14,7 +14,7 @@ import os from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.debug.pytorch.debug_state import TEDebugState - +import math fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -150,7 +150,7 @@ def test_sanity(feature_dirs): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -def test_numerics(fp8_recipe, feature_dirs): +def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): if not fp8_available: pytest.skip(reason_for_no_fp8) if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling(): @@ -208,6 +208,61 @@ def test_numerics(fp8_recipe, feature_dirs): assert overflows == pytest.approx(expected.cpu(), abs=1e-4) +LOG_HIGH_PRECISION_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + stats: [ + {stats} + ] + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" + +def test_log_stats_numerics(feature_dirs): + stats = [ + "dynamic_range", + "max_blockwise_4_dynamic_range" + ] + log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG_BASE.format(stats=", ".join(stats)) + + with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: + + epsilon = 1e-10 + tensor = torch.zeros(1024, 1024).cuda() + epsilon + tensor[0, :] = 1000 + + debug_api.transformer_engine.inspect_tensor( + layer_name="layer_name", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=None, + rowwise_quantized_tensor=None, + columnwise_quantized_tensor=None, + ) + debug_api.step() + + output = read_log(log_dir) + + for line in output.splitlines(): + if "max_blockwise_4_dynamic_range" in line: + max_blockwise_4_dynamic_range = float(line.split("value=")[1]) + expected = 0 + assert max_blockwise_4_dynamic_range == pytest.approx(expected, abs=1e-4) + elif "dynamic_range" in line: + dynamic_range = float(line.split("value=")[1]) + expected = math.log2(1000) - math.log2(epsilon) + assert dynamic_range == pytest.approx(expected, abs=1e-4) + @pytest.mark.parametrize("layer", ["linear", "transformer"]) def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): if not fp8_available: @@ -254,3 +309,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): debug_api.end_debug() TEDebugState._reset() + + +def test_max_blockwise_dynamic_range(feature_dirs): + pass \ No newline at end of file diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 7ba2f9f771..5a92fc069a 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -4,9 +4,10 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, 
Optional +from typing import Dict, Optional, List import torch +import re from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method @@ -19,7 +20,9 @@ from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params +from transformer_engine.debug.features.utils.stats_computation import add_max_blockwise_dynamic_range_stats +max_blockwise_regex = r"max_blockwise_\d+_dynamic_range" @Registry.register_feature(namespace="transformer_engine") class LogTensorStats(BaseLogTensorStats): @@ -45,6 +48,8 @@ class LogTensorStats(BaseLogTensorStats): - l2_norm - cur_amax – maximal absolute value of a tensor, - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` + - max_blockwise_X_dynamic_range: Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size X within the tensor, where X is an integer specifying the block size. + tensors/tensors_struct: List[str] list of tensors to log @@ -88,6 +93,21 @@ class LogTensorStats(BaseLogTensorStats): stats: [dynamic_range] """ + def _is_supported_stat(self, stat: str): + """Returns True if the stat is supported by this feature.""" + + if re.match(max_blockwise_regex, stat): + return True + + return stat in BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"} + + def _add_max_blockwise_dynamic_range_stats(self, stats: List[str]): + """Adds max_blockwise_X_dynamic_range stats for the recipe.""" + for stat in stats: + if re.match(max_blockwise_regex, stat): + block_size = int(stat.split("_")[2]) + add_max_blockwise_dynamic_range_stats(block_size) + def _get_supported_stats_list(self): """Returns stats this feature can log.""" return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"} @@ -142,8 +162,10 @@ def inspect_tensor( for stat in config["stats"]: assert ( - stat in self._get_supported_stats_list() + self._is_supported_stat(stat) ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." + + self._add_max_blockwise_dynamic_range_stats(config["stats"]) STATS_BUFFERS.try_add_buffer( layer_name=layer_name, diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 3842ab1c56..d230b5e44a 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -13,7 +13,7 @@ from transformer_engine.common.recipe import Format -@torch.compile +#@torch.compile def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" tensor_abs = tensor.abs() @@ -37,6 +37,21 @@ def _compute_dynamic_range_bottom(tensor): return torch.log2(amin) +@torch.compile +def compute_max_blockwise_dynamic_range(tensor, block_size): + """Computes the max dynamic range of the tensor.""" + total_numel = tensor.numel() + assert total_numel % block_size == 0, \ + f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." 
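
For intuition, the statistic being added here can be reproduced with a few lines of plain PyTorch. A minimal standalone sketch (the function name is illustrative, not part of the patch; it assumes every block contains at least one non-zero element):

import torch

def blockwise_dynamic_range_sketch(tensor: torch.Tensor, block_size: int) -> torch.Tensor:
    # Group consecutive elements into blocks of `block_size`.
    blocks = tensor.reshape(-1, block_size).abs().float()
    per_block_amax = blocks.amax(dim=1)
    # Smallest non-zero magnitude per block: zeros are masked to +inf first.
    per_block_amin = blocks.masked_fill(blocks == 0, float("inf")).amin(dim=1)
    return (torch.log2(per_block_amax) - torch.log2(per_block_amin)).max()

x = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0, 16.0, 16.0, 16.0])
print(blockwise_dynamic_range_sketch(x, block_size=4))  # tensor(3.): log2(8) - log2(1)
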
+ + abs_vals = tensor.reshape(-1, block_size).abs().float() + per_block_amax = abs_vals.amax(dim=1) + if torch.any(per_block_amax == 0): + return torch.tensor(float("inf"), device=tensor.device, dtype=tensor.dtype) + + per_block_amin_nz = abs_vals.masked_fill(abs_vals == 0, float("inf")).amin(dim=1) + return (torch.log2(per_block_amax) - torch.log2(per_block_amin_nz)).max() + def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) @@ -305,6 +320,19 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): DEPENDENCIES[stat_err] = {stat_err} DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} +def add_max_blockwise_dynamic_range_stats(block_size: int): + """Register max_blockwise_X_dynamic_range stats for the recipe.""" + stat_name = f"max_blockwise_{block_size}_dynamic_range" + if stat_name in stats_to_num: + return # already registered + stats_to_num[stat_name] = len(stats_to_num) + DEPENDENCIES[stat_name] = {stat_name} + + STATS[stat_name] = ( + lambda x, aux_dict: compute_max_blockwise_dynamic_range(x, block_size), + lambda buffers: max(_get(buffers, stat_name)), + ) + for _columnwise in [True, False]: for _recipe_name in [ From dcc43bbd2b7f3dd17c3b3ecf019efa5f516e77a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:53:05 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 11 +++++----- .../debug/features/log_tensor_stats.py | 20 ++++++++++++------- .../debug/features/utils/stats_computation.py | 11 ++++++---- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 45065cc664..ca984fc3ac 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -226,15 +226,13 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): end_step: 10 """ + def test_log_stats_numerics(feature_dirs): - stats = [ - "dynamic_range", - "max_blockwise_4_dynamic_range" - ] + stats = ["dynamic_range", "max_blockwise_4_dynamic_range"] log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG_BASE.format(stats=", ".join(stats)) with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: - + epsilon = 1e-10 tensor = torch.zeros(1024, 1024).cuda() + epsilon tensor[0, :] = 1000 @@ -263,6 +261,7 @@ def test_log_stats_numerics(feature_dirs): expected = math.log2(1000) - math.log2(epsilon) assert dynamic_range == pytest.approx(expected, abs=1e-4) + @pytest.mark.parametrize("layer", ["linear", "transformer"]) def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): if not fp8_available: @@ -312,4 +311,4 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): def test_max_blockwise_dynamic_range(feature_dirs): - pass \ No newline at end of file + pass diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 5a92fc069a..9dca215d7c 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -20,10 +20,13 @@ from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from 
transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params -from transformer_engine.debug.features.utils.stats_computation import add_max_blockwise_dynamic_range_stats +from transformer_engine.debug.features.utils.stats_computation import ( + add_max_blockwise_dynamic_range_stats, +) max_blockwise_regex = r"max_blockwise_\d+_dynamic_range" + @Registry.register_feature(namespace="transformer_engine") class LogTensorStats(BaseLogTensorStats): """ @@ -49,7 +52,7 @@ class LogTensorStats(BaseLogTensorStats): - cur_amax – maximal absolute value of a tensor, - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` - max_blockwise_X_dynamic_range: Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size X within the tensor, where X is an integer specifying the block size. - + tensors/tensors_struct: List[str] list of tensors to log @@ -98,8 +101,11 @@ def _is_supported_stat(self, stat: str): if re.match(max_blockwise_regex, stat): return True - - return stat in BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"} + + return stat in BaseLogTensorStats._get_supported_stats_list(None) | { + "cur_amax", + "dynamic_range", + } def _add_max_blockwise_dynamic_range_stats(self, stats: List[str]): """Adds max_blockwise_X_dynamic_range stats for the recipe.""" @@ -161,10 +167,10 @@ def inspect_tensor( ) for stat in config["stats"]: - assert ( - self._is_supported_stat(stat) + assert self._is_supported_stat( + stat ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." - + self._add_max_blockwise_dynamic_range_stats(config["stats"]) STATS_BUFFERS.try_add_buffer( diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index d230b5e44a..df75b59b7c 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -13,7 +13,7 @@ from transformer_engine.common.recipe import Format -#@torch.compile +# @torch.compile def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" tensor_abs = tensor.abs() @@ -41,8 +41,9 @@ def _compute_dynamic_range_bottom(tensor): def compute_max_blockwise_dynamic_range(tensor, block_size): """Computes the max dynamic range of the tensor.""" total_numel = tensor.numel() - assert total_numel % block_size == 0, \ - f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." + assert ( + total_numel % block_size == 0 + ), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." 
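
The context lines below compute the smallest non-zero magnitude per block with a trick worth spelling out: zeros are overwritten with +inf before the min-reduction, so they can never win. A tiny illustration:

import torch

blocks = torch.tensor([[0.0, 0.5, 2.0, 0.0]])
# Zeros become +inf, so amin returns the smallest non-zero entry (0.5 here).
print(blocks.masked_fill(blocks == 0, float("inf")).amin(dim=1))  # tensor([0.5000])
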
abs_vals = tensor.reshape(-1, block_size).abs().float() per_block_amax = abs_vals.amax(dim=1) @@ -52,6 +53,7 @@ def compute_max_blockwise_dynamic_range(tensor, block_size): per_block_amin_nz = abs_vals.masked_fill(abs_vals == 0, float("inf")).amin(dim=1) return (torch.log2(per_block_amax) - torch.log2(per_block_amin_nz)).max() + def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) @@ -320,11 +322,12 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): DEPENDENCIES[stat_err] = {stat_err} DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} + def add_max_blockwise_dynamic_range_stats(block_size: int): """Register max_blockwise_X_dynamic_range stats for the recipe.""" stat_name = f"max_blockwise_{block_size}_dynamic_range" if stat_name in stats_to_num: - return # already registered + return # already registered stats_to_num[stat_name] = len(stats_to_num) DEPENDENCIES[stat_name] = {stat_name} From 0b84d6a7cdc735fdc283a865c8f9fcaa4e16b09b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 15 Sep 2025 12:58:19 +0000 Subject: [PATCH 03/14] fix Signed-off-by: Pawel Gadzinski --- .../debug/features/log_tensor_stats.py | 2 +- .../debug/features/utils/stats_computation.py | 21 +++++++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 9dca215d7c..e7210e8aac 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -5,9 +5,9 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" from typing import Dict, Optional, List +import re import torch -import re from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 1625761543..5d09bfa2aa 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -39,7 +39,10 @@ def _compute_dynamic_range_bottom(tensor): @torch.compile def compute_max_blockwise_dynamic_range(tensor, block_size): - """Computes the max dynamic range of the tensor.""" + """Max blockwise dynamic range (log2 max/min_nonzero). + Returns 0 if all blocks are zeros. + Otherwise computes dynamic range over non-zero blocks. 
+ """ total_numel = tensor.numel() assert ( total_numel % block_size == 0 @@ -47,11 +50,21 @@ def compute_max_blockwise_dynamic_range(tensor, block_size): abs_vals = tensor.reshape(-1, block_size).abs().float() per_block_amax = abs_vals.amax(dim=1) - if torch.any(per_block_amax == 0): - return torch.tensor(float("inf"), device=tensor.device, dtype=tensor.dtype) + # Identify blocks that contain any non-zero element + nonzero_blocks = per_block_amax != 0 + if not torch.any(nonzero_blocks): + # If all blocks are zero, return 0 + return torch.zeros((), device=tensor.device, dtype=abs_vals.dtype) + + # Compute smallest non-zero magnitude per block per_block_amin_nz = abs_vals.masked_fill(abs_vals == 0, float("inf")).amin(dim=1) - return (torch.log2(per_block_amax) - torch.log2(per_block_amin_nz)).max() + + # Restrict DR computation to blocks that are non-zero + amax_nz = per_block_amax[nonzero_blocks] + amin_nz = per_block_amin_nz[nonzero_blocks] + + return (torch.log2(amax_nz) - torch.log2(amin_nz)).max() def compute_variance(variances, numels, sums): From 7b9d336d6b5b043d8440e173a30bb9444f16b58e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Sep 2025 12:59:14 +0000 Subject: [PATCH 04/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/utils/stats_computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 5d09bfa2aa..d0c5145fe4 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -40,8 +40,8 @@ def _compute_dynamic_range_bottom(tensor): @torch.compile def compute_max_blockwise_dynamic_range(tensor, block_size): """Max blockwise dynamic range (log2 max/min_nonzero). - Returns 0 if all blocks are zeros. - Otherwise computes dynamic range over non-zero blocks. + Returns 0 if all blocks are zeros. + Otherwise computes dynamic range over non-zero blocks. 
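
The contract documented above (an all-zero input yields 0 rather than -inf or NaN) can be checked in isolation. A sketch, assuming Transformer Engine is installed and a CUDA device is available for the torch.compile path; at this point in the series the function takes only (tensor, block_size):

import torch
from transformer_engine.debug.features.utils.stats_computation import (
    compute_max_blockwise_dynamic_range,
)

z = torch.zeros(64, device="cuda")
print(compute_max_blockwise_dynamic_range(z, 32))  # tensor(0.) for all-zero input
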
""" total_numel = tensor.numel() assert ( From 1d09428c45aabb0203b556441fc6d5e680af814d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Sep 2025 13:40:57 +0000 Subject: [PATCH 05/14] fix Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 21 +++++++--- .../debug/features/log_tensor_stats.py | 38 +++++++++++-------- .../debug/features/utils/stats_computation.py | 27 ++++++++----- 3 files changed, 54 insertions(+), 32 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 09cbb6066c..5df1915147 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -215,9 +215,14 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): transformer_engine: LogTensorStats: enabled: True - stats: [ - {stats} - ] + stats: + - dynamic_range + - max_blockwise_dynamic_range: + block_size: 4 + dims: 1 + - max_blockwise_dynamic_range: + block_size: 4 + dims: 2 tensors: [activation, gradient, weight] freq: 2 start_step: 0 @@ -250,10 +255,14 @@ def test_log_stats_numerics(feature_dirs): output = read_log(log_dir) for line in output.splitlines(): - if "max_blockwise_4_dynamic_range" in line: - max_blockwise_4_dynamic_range = float(line.split("value=")[1]) + if "max_blockwise_dynamic_range_block_size_4_dims_1" in line: + max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) + expected = 0 + assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(expected, abs=1e-4) + elif "max_blockwise_dynamic_range_block_size_4_dims_2" in line: + max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) expected = 0 - assert max_blockwise_4_dynamic_range == pytest.approx(expected, abs=1e-4) + assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(expected, abs=1e-4) elif "dynamic_range" in line: dynamic_range = float(line.split("value=")[1]) expected = math.log2(1000) - math.log2(epsilon) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index e7210e8aac..03b5c0404b 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -5,7 +5,6 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" from typing import Dict, Optional, List -import re import torch @@ -24,8 +23,6 @@ add_max_blockwise_dynamic_range_stats, ) -max_blockwise_regex = r"max_blockwise_\d+_dynamic_range" - @Registry.register_feature(namespace="transformer_engine") class LogTensorStats(BaseLogTensorStats): @@ -51,7 +48,12 @@ class LogTensorStats(BaseLogTensorStats): - l2_norm - cur_amax – maximal absolute value of a tensor, - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` - - max_blockwise_X_dynamic_range: Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size X within the tensor, where X is an integer specifying the block size. + - max_blockwise_dynamic_range: + Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor, + where block_size is an integer specifying the block size. + For dim=1 there are block_size consecutive elements in the block, for dim=2 the block is block_size x block_size elements tile. 
+ - block_size: int, default = 32 + - dims: int, default = 1, allowed values are 1 and 2 tensors/tensors_struct: List[str] list of tensors to log @@ -96,23 +98,27 @@ class LogTensorStats(BaseLogTensorStats): stats: [dynamic_range] """ - def _is_supported_stat(self, stat: str): + def _is_supported_stat(self, stat: str | Dict): """Returns True if the stat is supported by this feature.""" - - if re.match(max_blockwise_regex, stat): - return True - + if isinstance(stat, dict): + raise NotImplementedError("Max blockwise dynamic range is not supported for dict stats") return stat in BaseLogTensorStats._get_supported_stats_list(None) | { "cur_amax", - "dynamic_range", + "dynamic_range" } - def _add_max_blockwise_dynamic_range_stats(self, stats: List[str]): + def parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]): """Adds max_blockwise_X_dynamic_range stats for the recipe.""" + parsed_stats = [] for stat in stats: - if re.match(max_blockwise_regex, stat): - block_size = int(stat.split("_")[2]) - add_max_blockwise_dynamic_range_stats(block_size) + if isinstance(stat, dict): + block_size = stat["block_size"] + dims = stat["dims"] + add_max_blockwise_dynamic_range_stats(block_size, dims) + parsed_stats.append(f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}") + else: + parsed_stats.append(stat) + return parsed_stats def _get_supported_stats_list(self): """Returns stats this feature can log.""" @@ -171,12 +177,12 @@ def inspect_tensor( stat ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." - self._add_max_blockwise_dynamic_range_stats(config["stats"]) + stats = self.parse_max_blockwise_dynamic_range_stats(config["stats"]) STATS_BUFFERS.try_add_buffer( layer_name=layer_name, tensor_name=tensor_name, - stats=config["stats"], + stats=stats, options=options, reduction_group=reduction_group, reduce_within_microbatch=reduce_within_microbatch, diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index d0c5145fe4..4c0657a1a4 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -12,8 +12,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Format - -# @torch.compile +@torch.compile def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" tensor_abs = tensor.abs() @@ -26,6 +25,7 @@ def _compute_dynamic_range_top(tensor): return torch.log2(amax) +@torch.compile def _compute_dynamic_range_bottom(tensor): """Computes the log2 of the amin of the tensor""" tensor_abs = tensor.abs() @@ -36,7 +36,6 @@ def _compute_dynamic_range_bottom(tensor): amin = torch.tensor(1, device=tensor.device).to(torch.float) return torch.log2(amin) - @torch.compile def compute_max_blockwise_dynamic_range(tensor, block_size): """Max blockwise dynamic range (log2 max/min_nonzero). @@ -45,17 +44,21 @@ def compute_max_blockwise_dynamic_range(tensor, block_size): """ total_numel = tensor.numel() assert ( - total_numel % block_size == 0 + total_numel % (block_size ** dims) == 0 ), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." 
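
On the config side of this patch, a dict-form stat from the YAML file is flattened into a plain stat name before registration. Roughly, using the defaults the series settles on (a sketch of the mapping only, not the feature's full code path):

# As loaded from YAML: "- max_blockwise_dynamic_range: {block_size: 4, dims: 2}"
stat = {"max_blockwise_dynamic_range": {"block_size": 4, "dims": 2}}

params = stat["max_blockwise_dynamic_range"]
block_size = params.get("block_size", 32)
dims = params.get("dims", 1)
print(f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}")
# -> max_blockwise_dynamic_range_block_size_4_dims_2
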
+ assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" - abs_vals = tensor.reshape(-1, block_size).abs().float() - per_block_amax = abs_vals.amax(dim=1) + tensor = tensor.abs().float() + if dims == 1: + per_block_amax = tensor.reshape(-1, block_size).amax(dim=1) + else: + per_block_amax = tensor.reshape(-1, block_size, block_size).amax(dim=(1, 2)) # Identify blocks that contain any non-zero element nonzero_blocks = per_block_amax != 0 if not torch.any(nonzero_blocks): # If all blocks are zero, return 0 - return torch.zeros((), device=tensor.device, dtype=abs_vals.dtype) + return torch.zeros((), device=tensor.device, dtype=torch.float32) # Compute smallest non-zero magnitude per block per_block_amin_nz = abs_vals.masked_fill(abs_vals == 0, float("inf")).amin(dim=1) @@ -67,6 +70,7 @@ def compute_max_blockwise_dynamic_range(tensor, block_size): return (torch.log2(amax_nz) - torch.log2(amin_nz)).max() +@torch.compile def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) @@ -75,6 +79,7 @@ def compute_variance(variances, numels, sums): return var +@torch.compile def compute_std(variances, numels, sums): """Computates standard deviation.""" return torch.sqrt(compute_variance(variances, numels, sums)) @@ -346,16 +351,18 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} -def add_max_blockwise_dynamic_range_stats(block_size: int): +def add_max_blockwise_dynamic_range_stats(block_size: int, dims: int): """Register max_blockwise_X_dynamic_range stats for the recipe.""" - stat_name = f"max_blockwise_{block_size}_dynamic_range" + stat_name = f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}" if stat_name in stats_to_num: return # already registered + assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" stats_to_num[stat_name] = len(stats_to_num) DEPENDENCIES[stat_name] = {stat_name} STATS[stat_name] = ( - lambda x, aux_dict: compute_max_blockwise_dynamic_range(x, block_size), + lambda x, aux_dict, _block_size=block_size, _dims=dims: + compute_max_blockwise_dynamic_range(x, _block_size, _dims), lambda buffers: max(_get(buffers, stat_name)), ) From be57f93b5b669e701a44d0f65a710b6f9beeb5a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 13:42:19 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 10 +++++++--- transformer_engine/debug/features/log_tensor_stats.py | 10 ++++++---- .../debug/features/utils/stats_computation.py | 9 ++++++--- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 5df1915147..e054e609ac 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -215,7 +215,7 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): transformer_engine: LogTensorStats: enabled: True - stats: + stats: - dynamic_range - max_blockwise_dynamic_range: block_size: 4 @@ -258,11 +258,15 @@ def test_log_stats_numerics(feature_dirs): if "max_blockwise_dynamic_range_block_size_4_dims_1" in line: max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) expected = 0 - assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(expected, abs=1e-4) 
+ assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx( + expected, abs=1e-4 + ) elif "max_blockwise_dynamic_range_block_size_4_dims_2" in line: max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) expected = 0 - assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(expected, abs=1e-4) + assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx( + expected, abs=1e-4 + ) elif "dynamic_range" in line: dynamic_range = float(line.split("value=")[1]) expected = math.log2(1000) - math.log2(epsilon) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 03b5c0404b..0daadc3e61 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -48,8 +48,8 @@ class LogTensorStats(BaseLogTensorStats): - l2_norm - cur_amax – maximal absolute value of a tensor, - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` - - max_blockwise_dynamic_range: - Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor, + - max_blockwise_dynamic_range: + Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor, where block_size is an integer specifying the block size. For dim=1 there are block_size consecutive elements in the block, for dim=2 the block is block_size x block_size elements tile. - block_size: int, default = 32 @@ -104,7 +104,7 @@ def _is_supported_stat(self, stat: str | Dict): raise NotImplementedError("Max blockwise dynamic range is not supported for dict stats") return stat in BaseLogTensorStats._get_supported_stats_list(None) | { "cur_amax", - "dynamic_range" + "dynamic_range", } def parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]): @@ -115,7 +115,9 @@ def parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]): block_size = stat["block_size"] dims = stat["dims"] add_max_blockwise_dynamic_range_stats(block_size, dims) - parsed_stats.append(f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}") + parsed_stats.append( + f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}" + ) else: parsed_stats.append(stat) return parsed_stats diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 4c0657a1a4..143c9bda66 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Format + @torch.compile def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" @@ -36,6 +37,7 @@ def _compute_dynamic_range_bottom(tensor): amin = torch.tensor(1, device=tensor.device).to(torch.float) return torch.log2(amin) + @torch.compile def compute_max_blockwise_dynamic_range(tensor, block_size): """Max blockwise dynamic range (log2 max/min_nonzero). @@ -44,7 +46,7 @@ def compute_max_blockwise_dynamic_range(tensor, block_size): """ total_numel = tensor.numel() assert ( - total_numel % (block_size ** dims) == 0 + total_numel % (block_size**dims) == 0 ), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." 
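
A quick sanity check of the divisibility requirement enforced above, for the 1024 x 1024 tensor the tests use:

total_numel = 1024 * 1024
block_size = 4
print(total_numel % block_size)       # 0 -> valid for dims=1 (blocks of 4 elements)
print(total_numel % block_size ** 2)  # 0 -> valid for dims=2 (4 x 4 tiles)
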
assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" @@ -361,8 +363,9 @@ def add_max_blockwise_dynamic_range_stats(block_size: int, dims: int): DEPENDENCIES[stat_name] = {stat_name} STATS[stat_name] = ( - lambda x, aux_dict, _block_size=block_size, _dims=dims: - compute_max_blockwise_dynamic_range(x, _block_size, _dims), + lambda x, aux_dict, _block_size=block_size, _dims=dims: compute_max_blockwise_dynamic_range( + x, _block_size, _dims + ), lambda buffers: max(_get(buffers, stat_name)), ) From ae8d7bcceeb160ae6f9302351fb72a4464cbde40 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Sep 2025 14:44:46 +0000 Subject: [PATCH 07/14] fixes Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 10 +++++--- .../debug/features/log_tensor_stats.py | 11 ++++++--- .../debug/features/utils/stats_computation.py | 23 ++++++++++--------- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index e054e609ac..302e0beb53 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -238,7 +238,10 @@ def test_log_stats_numerics(feature_dirs): epsilon = 1e-10 tensor = torch.zeros(1024, 1024).cuda() + epsilon - tensor[0, :] = 1000 + A = 1000 + B = 50 + tensor[0, :] = A + tensor[1:4, :] = B debug_api.transformer_engine.inspect_tensor( layer_name="layer_name", @@ -254,6 +257,7 @@ def test_log_stats_numerics(feature_dirs): output = read_log(log_dir) + for line in output.splitlines(): if "max_blockwise_dynamic_range_block_size_4_dims_1" in line: max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) @@ -263,13 +267,13 @@ def test_log_stats_numerics(feature_dirs): ) elif "max_blockwise_dynamic_range_block_size_4_dims_2" in line: max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) - expected = 0 + expected = math.log2(A) - math.log2(B) assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx( expected, abs=1e-4 ) elif "dynamic_range" in line: dynamic_range = float(line.split("value=")[1]) - expected = math.log2(1000) - math.log2(epsilon) + expected = math.log2(A) - math.log2(epsilon) assert dynamic_range == pytest.approx(expected, abs=1e-4) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 0daadc3e61..ba23c977a5 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -101,7 +101,12 @@ class LogTensorStats(BaseLogTensorStats): def _is_supported_stat(self, stat: str | Dict): """Returns True if the stat is supported by this feature.""" if isinstance(stat, dict): - raise NotImplementedError("Max blockwise dynamic range is not supported for dict stats") + stat_name = list(stat.keys())[0] + assert stat_name == "max_blockwise_dynamic_range" + stat_dict = stat[stat_name] + assert set(stat_dict.keys()) == {"block_size", "dims"} + assert stat_dict["block_size"] > 0 and stat_dict["dims"] in [1, 2] + return True return stat in BaseLogTensorStats._get_supported_stats_list(None) | { "cur_amax", "dynamic_range", @@ -112,8 +117,8 @@ def parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]): parsed_stats = [] for stat in stats: if isinstance(stat, dict): - block_size = stat["block_size"] - dims = stat["dims"] + block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32) + dims = stat["max_blockwise_dynamic_range"].get("dims", 1) 
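
The stats_computation hunk below carves a 2-D tensor into non-overlapping block_size x block_size tiles with a reshape/permute/reshape sequence. A small sketch of the same transform, without the extra leading batch dimension the patch keeps:

import torch

H, W, block_size = 4, 4, 2
t = torch.arange(H * W).reshape(H, W)
tiles = (
    t.reshape(H // block_size, block_size, W // block_size, block_size)
    .permute(0, 2, 1, 3)  # bring the two tile-index axes next to each other
    .reshape(-1, block_size, block_size)
)
print(tiles[0])  # top-left tile: tensor([[0, 1], [4, 5]])
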
add_max_blockwise_dynamic_range_stats(block_size, dims) parsed_stats.append( f"max_blockwise_dynamic_range_block_size_{block_size}_dims_{dims}" diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 143c9bda66..87b710581e 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -39,7 +39,7 @@ def _compute_dynamic_range_bottom(tensor): @torch.compile -def compute_max_blockwise_dynamic_range(tensor, block_size): +def compute_max_blockwise_dynamic_range(tensor, block_size, dims): """Max blockwise dynamic range (log2 max/min_nonzero). Returns 0 if all blocks are zeros. Otherwise computes dynamic range over non-zero blocks. @@ -52,24 +52,25 @@ def compute_max_blockwise_dynamic_range(tensor, block_size): tensor = tensor.abs().float() if dims == 1: - per_block_amax = tensor.reshape(-1, block_size).amax(dim=1) + tensor = tensor.reshape(-1, block_size) + per_block_amax = tensor.amax(dim=1) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1) else: - per_block_amax = tensor.reshape(-1, block_size, block_size).amax(dim=(1, 2)) + dim_a = tensor.shape[-2] // block_size + dim_b = tensor.shape[-1] // block_size + tensor = tensor.reshape(-1, dim_a, block_size, dim_b, block_size).\ + permute(0, 1, 3, 2, 4).reshape(-1, block_size, block_size) + per_block_amax = tensor.amax(dim=(1, 2)) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2)) # Identify blocks that contain any non-zero element nonzero_blocks = per_block_amax != 0 if not torch.any(nonzero_blocks): # If all blocks are zero, return 0 return torch.zeros((), device=tensor.device, dtype=torch.float32) + - # Compute smallest non-zero magnitude per block - per_block_amin_nz = abs_vals.masked_fill(abs_vals == 0, float("inf")).amin(dim=1) - - # Restrict DR computation to blocks that are non-zero - amax_nz = per_block_amax[nonzero_blocks] - amin_nz = per_block_amin_nz[nonzero_blocks] - - return (torch.log2(amax_nz) - torch.log2(amin_nz)).max() + return (torch.log2(per_block_amax) - torch.log2(per_block_amin)).max() @torch.compile From d52647366e1021e0fa6f2507141f9e24c73634ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:00:37 +0000 Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 1 - .../debug/features/utils/stats_computation.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 302e0beb53..63466f12ff 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -257,7 +257,6 @@ def test_log_stats_numerics(feature_dirs): output = read_log(log_dir) - for line in output.splitlines(): if "max_blockwise_dynamic_range_block_size_4_dims_1" in line: max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 87b710581e..21525de86d 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -58,8 +58,11 @@ def compute_max_blockwise_dynamic_range(tensor, block_size, dims): else: dim_a 
= tensor.shape[-2] // block_size
         dim_b = tensor.shape[-1] // block_size
-        tensor = tensor.reshape(-1, dim_a, block_size, dim_b, block_size).\
-            permute(0, 1, 3, 2, 4).reshape(-1, block_size, block_size)
+        tensor = (
+            tensor.reshape(-1, dim_a, block_size, dim_b, block_size)
+            .permute(0, 1, 3, 2, 4)
+            .reshape(-1, block_size, block_size)
+        )
         per_block_amax = tensor.amax(dim=(1, 2))
         per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2))
 
@@ -68,7 +71,6 @@ def compute_max_blockwise_dynamic_range(tensor, block_size, dims):
     if not torch.any(nonzero_blocks):
         # If all blocks are zero, return 0
        return torch.zeros((), device=tensor.device, dtype=torch.float32)
-
     return (torch.log2(per_block_amax) - torch.log2(per_block_amin)).max()
 

From 1421021e7361184d4b4bacf4441b930e25c0f76e Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Fri, 19 Sep 2025 15:28:19 +0000
Subject: [PATCH 09/14] fix

Signed-off-by: Pawel Gadzinski
---
 tests/pytorch/debug/test_log.py                   | 11 +++----
 .../debug/features/log_tensor_stats.py            |  6 ++--
 .../debug/features/utils/stats_computation.py     | 29 ++++++++++++-------
 3 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py
index 63466f12ff..e89c2f6017 100644
--- a/tests/pytorch/debug/test_log.py
+++ b/tests/pytorch/debug/test_log.py
@@ -229,17 +229,18 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
     end_step: 10
 """
 
-
 def test_log_stats_numerics(feature_dirs):
+    """ Check correctness of dynamic range and max blockwise dynamic range stats """
     stats = ["dynamic_range", "max_blockwise_4_dynamic_range"]
     log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG_BASE.format(stats=", ".join(stats))
 
     with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
-
+        # There is a 1024 x 1024 tensor with very small epsilon values in almost all elements,
+        # one row of large value A, and three rows of large value B.
         epsilon = 1e-10
-        tensor = torch.zeros(1024, 1024).cuda() + epsilon
         A = 1000
         B = 50
+        tensor = torch.zeros(1024, 1024).cuda() + epsilon
         tensor[0, :] = A
         tensor[1:4, :] = B
 
@@ -322,7 +323,3 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
 
     debug_api.end_debug()
     TEDebugState._reset()
-
-
-def test_max_blockwise_dynamic_range(feature_dirs):
-    pass
diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py
index ba23c977a5..043cb5601c 100644
--- a/transformer_engine/debug/features/log_tensor_stats.py
+++ b/transformer_engine/debug/features/log_tensor_stats.py
@@ -48,10 +48,8 @@ class LogTensorStats(BaseLogTensorStats):
         - l2_norm
         - cur_amax – maximal absolute value of a tensor,
         - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
-        - max_blockwise_dynamic_range:
-            Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor,
-            where block_size is an integer specifying the block size.
-            For dim=1 there are block_size consecutive elements in the block, for dim=2 the block is block_size x block_size elements tile.
+        - max_blockwise_dynamic_range: Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor, where block_size is an integer specifying the block size. For dim=1 there are block_size consecutive elements in the block, for dim=2 the block is block_size x block_size elements tile.
+
           - block_size: int, default = 32
           - dims: int, default = 1, allowed values are 1 and 2
 
diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py
index 21525de86d..5cca78d78f 100644
--- a/transformer_engine/debug/features/utils/stats_computation.py
+++ b/transformer_engine/debug/features/utils/stats_computation.py
@@ -40,9 +40,13 @@ def _compute_dynamic_range_bottom(tensor):
 
 @torch.compile
 def compute_max_blockwise_dynamic_range(tensor, block_size, dims):
-    """Max blockwise dynamic range (log2 max/min_nonzero).
-    Returns 0 if all blocks are zeros.
-    Otherwise computes dynamic range over non-zero blocks.
+    """
+    Max blockwise dynamic range (log2 max/min_nonzero).
+    Returns 0 if all blocks are zeros.
+    Otherwise computes dynamic range over non-zero blocks.
+
+    For dims = 1 blocks contain block_size consecutive elements,
+    for dims = 2 blocks contain block_size x block_size elements.
     """
     total_numel = tensor.numel()
     assert (
@@ -56,10 +60,12 @@ def compute_max_blockwise_dynamic_range(tensor, block_size, dims):
         per_block_amax = tensor.amax(dim=1)
         per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1)
     else:
-        dim_a = tensor.shape[-2] // block_size
-        dim_b = tensor.shape[-1] // block_size
+        # We want to have tensor of shape [nr_blocks, block_size, block_size],
+        # where each block is a block_size x block_size tile of the original tensor.
+        dim_x = tensor.shape[-2] // block_size
+        dim_y = tensor.shape[-1] // block_size
         tensor = (
-            tensor.reshape(-1, dim_a, block_size, dim_b, block_size)
+            tensor.reshape(-1, dim_x, block_size, dim_y, block_size)
             .permute(0, 1, 3, 2, 4)
             .reshape(-1, block_size, block_size)
         )
@@ -68,11 +74,12 @@ def compute_max_blockwise_dynamic_range(tensor, block_size, dims):
 
     # Identify blocks that contain any non-zero element
     nonzero_blocks = per_block_amax != 0
-    if not torch.any(nonzero_blocks):
-        # If all blocks are zero, return 0
-        return torch.zeros((), device=tensor.device, dtype=torch.float32)
-
-    return (torch.log2(per_block_amax) - torch.log2(per_block_amin)).max()
+    dynamic_range_per_block = torch.where(
+        nonzero_blocks,
+        torch.log2(per_block_amax) - torch.log2(per_block_amin),
+        torch.zeros_like(per_block_amax, dtype=torch.float32),
+    )
+    return dynamic_range_per_block.max()
 
 
 @torch.compile

From a6f8546e692e25b5355a9665fdcc8be342018a14 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 19 Sep 2025 15:29:17 +0000
Subject: [PATCH 10/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/debug/test_log.py                        |  3 ++-
 transformer_engine/debug/features/log_tensor_stats.py  |  2 +-
 .../debug/features/utils/stats_computation.py          | 10 +++++-----
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py
index e89c2f6017..b79c94455b 100644
--- a/tests/pytorch/debug/test_log.py
+++ b/tests/pytorch/debug/test_log.py
@@ -229,8 +229,9 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
     end_step: 10
 """
 
+
 def test_log_stats_numerics(feature_dirs):
-    """ Check correctness of dynamic range and max blockwise dynamic range stats """
+    """Check correctness of dynamic range and max blockwise dynamic range stats"""
     stats = ["dynamic_range", "max_blockwise_4_dynamic_range"]
     log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG_BASE.format(stats=", ".join(stats))
 
diff --git 
a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 043cb5601c..d30b85de77 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -49,7 +49,7 @@ class LogTensorStats(BaseLogTensorStats): - cur_amax – maximal absolute value of a tensor, - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` - max_blockwise_dynamic_range: Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor, where block_size is an integer specifying the block size. For dim=1 there are block_size consecutive elements in the block, for dim=2 the block is block_size x block_size elements tile. - + - block_size: int, default = 32 - dims: int, default = 1, allowed values are 1 and 2 diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 5cca78d78f..4962b2a20c 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -41,12 +41,12 @@ def _compute_dynamic_range_bottom(tensor): @torch.compile def compute_max_blockwise_dynamic_range(tensor, block_size, dims): """ - Max blockwise dynamic range (log2 max/min_nonzero). - Returns 0 if all blocks are zeros. - Otherwise computes dynamic range over non-zero blocks. + Max blockwise dynamic range (log2 max/min_nonzero). + Returns 0 if all blocks are zeros. + Otherwise computes dynamic range over non-zero blocks. - For dims = 1 blocks contain block_size consecutive elements, - for dims = 2 blocks contain block_size x block_size elements. + For dims = 1 blocks contain block_size consecutive elements, + for dims = 2 blocks contain block_size x block_size elements. """ total_numel = tensor.numel() assert ( From 70f977fea57d23f5ee09bc80095dfdf4c66803e9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Sep 2025 15:35:06 +0000 Subject: [PATCH 11/14] fix Signed-off-by: Pawel Gadzinski --- .../debug/features/log_tensor_stats.py | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index d30b85de77..9a591e0935 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -97,21 +97,37 @@ class LogTensorStats(BaseLogTensorStats): """ def _is_supported_stat(self, stat: str | Dict): - """Returns True if the stat is supported by this feature.""" + """ Returns True if the stat is supported by this feature, False otherwise. 
""" if isinstance(stat, dict): stat_name = list(stat.keys())[0] - assert stat_name == "max_blockwise_dynamic_range" - stat_dict = stat[stat_name] - assert set(stat_dict.keys()) == {"block_size", "dims"} - assert stat_dict["block_size"] > 0 and stat_dict["dims"] in [1, 2] - return True + if stat_name == "max_blockwise_dynamic_range": + stat_dict = stat[stat_name] + if not isinstance(stat_dict, dict): + return False + # Ensure only supported keys are present + allowed_keys = {"block_size", "dims"} + if any(k not in allowed_keys for k in stat_dict.keys()): + return False + block_size = stat_dict.get("block_size", 32) + dims = stat_dict.get("dims", 1) + # Type and value validation + if not isinstance(block_size, int) or not isinstance(dims, int): + return False + if block_size > 0 and dims in [1, 2]: + return True + return False return stat in BaseLogTensorStats._get_supported_stats_list(None) | { "cur_amax", "dynamic_range", } - def parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]): - """Adds max_blockwise_X_dynamic_range stats for the recipe.""" + def _parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]) -> List[str]: + """ + Adds all max_blockwise_dynamic_range stats to the stat computation logic. + Changes the types of the stats from Dict to str, for other stats nothing is changed. + For example, if the stats is [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}], + it will be changed to ["max_blockwise_dynamic_range_block_size_32_dims_1"]. + """ parsed_stats = [] for stat in stats: if isinstance(stat, dict): @@ -182,7 +198,7 @@ def inspect_tensor( stat ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." - stats = self.parse_max_blockwise_dynamic_range_stats(config["stats"]) + stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"]) STATS_BUFFERS.try_add_buffer( layer_name=layer_name, From b19630fbd6814099073c08aef5dff20099926e1d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Sep 2025 15:36:52 +0000 Subject: [PATCH 12/14] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/log_tensor_stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 9a591e0935..d6ef8036f5 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -47,9 +47,9 @@ class LogTensorStats(BaseLogTensorStats): - l1_norm - l2_norm - cur_amax – maximal absolute value of a tensor, - - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` - - max_blockwise_dynamic_range: Computes the maximum dynamic range (log2(max) - log2(min)) across all blocks of size block_size within the tensor, where block_size is an integer specifying the block size. For dim=1 there are block_size consecutive elements in the block, for dim=2 the block is block_size x block_size elements tile. - + - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)` + - max_blockwise_dynamic_range – Computes the maximum dynamic range `log2(amax) - log2(nonzero_amin)` across all blocks of size block_size within the tensor, where block_size is an integer specifying the block size. For `dim=1` there are block_size consecutive elements in the block, for `dim=2` the block is block_size x block_size elements tile. 
+ - block_size: int, default = 32 - dims: int, default = 1, allowed values are 1 and 2 From 6cc2e2473e5fd917bbed835322d6d0621328d544 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:37:52 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../debug/features/log_tensor_stats.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index d6ef8036f5..62a4dae41b 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -49,7 +49,7 @@ class LogTensorStats(BaseLogTensorStats): - cur_amax – maximal absolute value of a tensor, - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)` - max_blockwise_dynamic_range – Computes the maximum dynamic range `log2(amax) - log2(nonzero_amin)` across all blocks of size block_size within the tensor, where block_size is an integer specifying the block size. For `dim=1` there are block_size consecutive elements in the block, for `dim=2` the block is block_size x block_size elements tile. - + - block_size: int, default = 32 - dims: int, default = 1, allowed values are 1 and 2 @@ -97,7 +97,7 @@ class LogTensorStats(BaseLogTensorStats): """ def _is_supported_stat(self, stat: str | Dict): - """ Returns True if the stat is supported by this feature, False otherwise. """ + """Returns True if the stat is supported by this feature, False otherwise.""" if isinstance(stat, dict): stat_name = list(stat.keys())[0] if stat_name == "max_blockwise_dynamic_range": @@ -123,10 +123,10 @@ def _is_supported_stat(self, stat: str | Dict): def _parse_max_blockwise_dynamic_range_stats(self, stats: List[str | Dict]) -> List[str]: """ - Adds all max_blockwise_dynamic_range stats to the stat computation logic. - Changes the types of the stats from Dict to str, for other stats nothing is changed. - For example, if the stats is [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}], - it will be changed to ["max_blockwise_dynamic_range_block_size_32_dims_1"]. + Adds all max_blockwise_dynamic_range stats to the stat computation logic. + Changes the types of the stats from Dict to str, for other stats nothing is changed. + For example, if the stats is [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}], + it will be changed to ["max_blockwise_dynamic_range_block_size_32_dims_1"]. """ parsed_stats = [] for stat in stats: From 69d78fe047573dc4c8b3d9bb90e52e31e33b3820 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Sep 2025 10:55:18 +0000 Subject: [PATCH 14/14] fix Signed-off-by: Pawel Gadzinski --- .../debug/features/utils/stats_computation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 4962b2a20c..e50cafffb9 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -49,10 +49,13 @@ def compute_max_blockwise_dynamic_range(tensor, block_size, dims): for dims = 2 blocks contain block_size x block_size elements. 
""" total_numel = tensor.numel() + assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" + + # torch.compile friendly code - standard ** power does not work with jit + total_block_size = block_size * block_size if dims == 2 else block_size assert ( - total_numel % (block_size**dims) == 0 + total_numel % total_block_size == 0 ), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." - assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" tensor = tensor.abs().float() if dims == 1: