Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/debug.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ Precision debug tools
debug/1_getting_started.rst
debug/2_config_file_structure.rst
debug/api
debug/4_distributed.rst
debug/4_distributed.rst
debug/5_custom_feature.ipynb
9 changes: 9 additions & 0 deletions docs/debug/custom_feature_example/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
section_name_1:
enabled: true
layers:
layer_name_regex_pattern: ".*"
transformer_engine:
CustomPrecisionExampleFeature:
enabled: true
gemms: [fprop, wgrad, dgrad]
tensors: [activation, gradient]
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Custom precision example feature"""

from typing import Dict, Optional, Tuple
from nvdlfw_inspect.logging import MetricLogger
from nvdlfw_inspect.registry import Registry, api_method

import torch
from transformer_engine.debug.features.api import TEConfigAPIMapper
from transformer_engine.pytorch.tensor import Quantizer


def custom_precision_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
amax = torch.amax(tensor)
scale = 1.0 / amax
q_tensor = tensor * scale

# tensor to -1/0/1 range (-1, -0.5) -> -1, (-0.5, , 0.5) -> 0, (0.5, 1) -> 1
out_tensor = torch.where(q_tensor < -0.5, -1, torch.where(q_tensor > 0.5, 1, 0))
return out_tensor, scale


def custom_precision_dequantize(tensor: torch.Tensor, scale: float) -> torch.Tensor:
return tensor * scale


@Registry.register_feature(namespace="transformer_engine")
class CustomPrecisionExampleFeature(TEConfigAPIMapper):

@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
):
"""API call used to determine whether to run process_tensor() in the forward."""
return True, iteration + 1

@api_method
def modify_tensor(
self,
config,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
default_quantizer: Quantizer,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
): # pylint: disable=unused-argument
"""API call used to process the tensor."""

q_tensor, scale = custom_precision_quantize(tensor)

MetricLogger.log_scalar(
f"custom_precision_scale {layer_name}_{gemm}_{tensor_name}",
scale,
iteration=iteration,
)

dq_tensor = custom_precision_dequantize(q_tensor, scale)
return dq_tensor
40 changes: 40 additions & 0 deletions docs/debug/custom_feature_example/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import time
import torch
import transformer_engine.pytorch as te
import nvdlfw_inspect.api as debug_api
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/{}".format(time.time()))


def init_model() -> torch.nn.Module:
return te.TransformerLayer(
hidden_size=1024,
ffn_hidden_size=1024,
num_attention_heads=16,
)


def run_example_fit(model: torch.nn.Module):
output_tensor_ref = torch.randn(1, 1, 1024).cuda()
input_tensor = torch.randn(1, 1, 1024).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

for i in range(1000):
output = model(input_tensor)
loss = torch.nn.functional.mse_loss(output, output_tensor_ref)
loss.backward()
optimizer.step()
optimizer.zero_grad()
get_tb_writer().add_scalar("Loss", loss.item(), i)

debug_api.step()


def get_tb_writer():
return writer
Binary file added docs/debug/img/log_act.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/debug/img/loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.