diff --git a/docs/debug.rst b/docs/debug.rst
index d33568ea3b..b48302576d 100644
--- a/docs/debug.rst
+++ b/docs/debug.rst
@@ -11,4 +11,5 @@ Precision debug tools
    debug/1_getting_started.rst
    debug/2_config_file_structure.rst
    debug/api
-   debug/4_distributed.rst
\ No newline at end of file
+   debug/4_distributed.rst
+   debug/5_custom_feature.ipynb
\ No newline at end of file
diff --git a/docs/debug/custom_feature_example/config.yaml b/docs/debug/custom_feature_example/config.yaml
new file mode 100644
index 0000000000..705e02049e
--- /dev/null
+++ b/docs/debug/custom_feature_example/config.yaml
@@ -0,0 +1,9 @@
+section_name_1:
+  enabled: true
+  layers:
+    layer_name_regex_pattern: ".*"
+  transformer_engine:
+    CustomPrecisionExampleFeature:
+      enabled: true
+      gemms: [fprop, wgrad, dgrad]
+      tensors: [activation, gradient]
\ No newline at end of file
diff --git a/docs/debug/custom_feature_example/features/custom_precision_example_feature.py b/docs/debug/custom_feature_example/features/custom_precision_example_feature.py
new file mode 100644
index 0000000000..1143c30fe3
--- /dev/null
+++ b/docs/debug/custom_feature_example/features/custom_precision_example_feature.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Custom precision example feature"""
+
+from typing import Dict, Optional, Tuple
+from nvdlfw_inspect.logging import MetricLogger
+from nvdlfw_inspect.registry import Registry, api_method
+
+import torch
+from transformer_engine.debug.features.api import TEConfigAPIMapper
+from transformer_engine.pytorch.tensor import Quantizer
+
+
+def custom_precision_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
+    amax = torch.amax(tensor)
+    scale = 1.0 / amax
+    q_tensor = tensor * scale
+
+    # Map the scaled tensor to {-1, 0, 1}: (-1, -0.5) -> -1, (-0.5, 0.5) -> 0, (0.5, 1) -> 1
+    out_tensor = torch.where(q_tensor < -0.5, -1, torch.where(q_tensor > 0.5, 1, 0))
+    return out_tensor, scale
+
+
+def custom_precision_dequantize(tensor: torch.Tensor, scale: float) -> torch.Tensor:
+    return tensor * scale
+
+
+@Registry.register_feature(namespace="transformer_engine")
+class CustomPrecisionExampleFeature(TEConfigAPIMapper):
+
+    @api_method
+    def modify_tensor_enabled(
+        self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
+    ):
+        """API call used to determine whether to run modify_tensor() in the forward."""
+        return True, iteration + 1
+
+    @api_method
+    def modify_tensor(
+        self,
+        config,
+        layer_name: str,
+        gemm: str,
+        tensor_name: str,
+        tensor: torch.Tensor,
+        iteration: int,
+        default_quantizer: Quantizer,
+        out: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):  # pylint: disable=unused-argument
+        """API call used to process the tensor."""
+
+        q_tensor, scale = custom_precision_quantize(tensor)
+
+        MetricLogger.log_scalar(
+            f"custom_precision_scale {layer_name}_{gemm}_{tensor_name}",
+            scale,
+            iteration=iteration,
+        )
+
+        dq_tensor = custom_precision_dequantize(q_tensor, scale)
+        return dq_tensor
diff --git a/docs/debug/custom_feature_example/utils.py b/docs/debug/custom_feature_example/utils.py
new file mode 100644
index 0000000000..41a20e6c54
--- /dev/null
+++ b/docs/debug/custom_feature_example/utils.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import time
+import torch
+import transformer_engine.pytorch as te
+import nvdlfw_inspect.api as debug_api
+from torch.utils.tensorboard import SummaryWriter
+
+writer = SummaryWriter(log_dir="runs/{}".format(time.time()))
+
+
+def init_model() -> torch.nn.Module:
+    return te.TransformerLayer(
+        hidden_size=1024,
+        ffn_hidden_size=1024,
+        num_attention_heads=16,
+    )
+
+
+def run_example_fit(model: torch.nn.Module):
+    output_tensor_ref = torch.randn(1, 1, 1024).cuda()
+    input_tensor = torch.randn(1, 1, 1024).cuda()
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
+
+    for i in range(1000):
+        output = model(input_tensor)
+        loss = torch.nn.functional.mse_loss(output, output_tensor_ref)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        get_tb_writer().add_scalar("Loss", loss.item(), i)
+
+        debug_api.step()
+
+
+def get_tb_writer():
+    return writer
diff --git a/docs/debug/img/log_act.png b/docs/debug/img/log_act.png
new file mode 100644
index 0000000000..5208aefc24
Binary files /dev/null and b/docs/debug/img/log_act.png differ
diff --git a/docs/debug/img/loss.png b/docs/debug/img/loss.png
new file mode 100644
index 0000000000..2a0165ba76
Binary files /dev/null and b/docs/debug/img/loss.png differ
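
For reference, a minimal sketch of how the new config file and example feature might be exercised together, assuming the debug_api.initialize() entry point described in debug/1_getting_started.rst; the parameter names shown here (config_file, feature_dirs, tb_writer) and the relative paths are assumptions, and the added 5_custom_feature.ipynb notebook carries the full walkthrough:

    import nvdlfw_inspect.api as debug_api

    from utils import init_model, run_example_fit, get_tb_writer

    # Point the debug API at the example feature directory and config.
    # Assumption: run from docs/debug/custom_feature_example/.
    debug_api.initialize(
        config_file="./config.yaml",
        feature_dirs=["./features"],
        tb_writer=get_tb_writer(),  # assumed keyword; routes MetricLogger scalars to TensorBoard
    )

    model = init_model()
    run_example_fit(model)  # logs "Loss" and "custom_precision_scale ..." scalars

    debug_api.end_debug()  # assumed teardown call

The "custom_precision_scale ..." scalars logged via MetricLogger.log_scalar and the training loss curve should then be visible in TensorBoard, which is presumably what the added log_act.png and loss.png images illustrate.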