diff --git a/benchmarks/linear/benchmark_linear_cpu_overhead.py b/benchmarks/linear/benchmark_linear_cpu_overhead.py
new file mode 100644
index 0000000000..59ae371e67
--- /dev/null
+++ b/benchmarks/linear/benchmark_linear_cpu_overhead.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+from typing import List
+import torch
+import time
+import argparse
+
+import transformer_engine.pytorch as te
+
+
+def run_once(
+    module: torch.nn.Module,
+    args: List[torch.Tensor],
+    iters=2000,
+    use_te: bool = True,
+    is_first_microbatch: bool = True,
+):
+    if use_te:
+        for _ in range(iters):
+            module(*args, is_first_microbatch=is_first_microbatch)
+    else:
+        for _ in range(iters):
+            module(*args)
+
+
+def speedometer(
+    module: torch.nn.Module,
+    args: List[torch.Tensor],
+    timing_iters: int = 2000,
+    warmup_iters: int = 100,
+    num_rounds: int = 5,
+    use_te: bool = True,
+) -> float:
+    """Measure average run time for a PyTorch module"""
+    # warm up
+    run_once(module, args, iters=warmup_iters, use_te=use_te, is_first_microbatch=True)
+
+    gpu_times = []
+    cpu_times = []
+    for round_idx in range(num_rounds):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start.record()
+        cpu_start = time.time()
+        # run timing
+        run_once(module, args, iters=timing_iters, use_te=use_te, is_first_microbatch=False)
+        cpu_end = time.time()
+        end.record()
+        torch.cuda.synchronize()
+        gpu_elapsed = start.elapsed_time(end)
+        cpu_elapsed = (cpu_end - cpu_start) * 1000
+        gpu_times.append(gpu_elapsed)
+        cpu_times.append(cpu_elapsed)
+        print(
+            f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU"
+            f" {cpu_elapsed/timing_iters*1000:.2f} µs"
+        )
+    print(
+        f"Average GPU time over {num_rounds} rounds:"
+        f" {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
+    )
+    print(
+        f"Average CPU time over {num_rounds} rounds:"
+        f" {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
+    )
+
+    return sum(gpu_times) / num_rounds
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Benchmark Linear performance and CPU overhead (torch or Transformer Engine backend)."
+    )
+    parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size")
+    parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length")
+    parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations")
+    parser.add_argument(
+        "--timing_iters", type=int, default=2000, help="Number of timing iterations per round"
+    )
+    parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds")
+    parser.add_argument(
+        "--backend",
+        type=str,
+        choices=["torch", "te"],
+        default="te",
+        help="Linear backend: torch or te",
+    )
+    args = parser.parse_args()
+
+    x = torch.randn(
+        (args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True
+    )
+    use_te = True
+    if args.backend == "torch":
+        model = (
+            torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False)
+            .to(torch.bfloat16)
+            .cuda()
+        )
+        use_te = False
+    else:
+        model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(
+            torch.bfloat16
+        )
+    with torch.no_grad():
+        avg_gpu_time_per_round = speedometer(
+            model,
+            [x],
+            timing_iters=args.timing_iters,
+            warmup_iters=args.warmup,
+            num_rounds=args.num_rounds,
+            use_te=use_te,
+        )
+
+    total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters
+
+    tflops = total_ops / avg_gpu_time_per_round / 1e9
+    print(f"Estimated TFLOP/s: {tflops:.2f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 55654989a7..8fda938342 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -641,11 +641,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
 }
 
 int nvte_is_non_tn_fp8_gemm_supported() {
-  int deviceComputeCapability =
-      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
-
-  // Note: this is temporary restriction and should be lifted in the future.
-  // (remove the note once it's done.)
-  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
-         deviceComputeCapability >= 130;
+  static int cached_result = 0;
+  static std::once_flag flag;
+  std::call_once(flag, []() {
+    int deviceComputeCapability =
+        transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
+    // Note: this is temporary restriction and should be lifted in the future.
+    // (remove the note once it's done.)
+    cached_result = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
+                    deviceComputeCapability >= 130;
+  });
+  return cached_result;
 }
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 5749d96c9f..20ba2af592 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -5,6 +5,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
 import warnings
+import contextlib
 import functools
 
 import torch
@@ -745,9 +746,14 @@ def forward(
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
 
             weight_tensors = self._get_weight_tensors()
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 4d30be414e..f719c14325 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -8,6 +8,7 @@
 from typing import Callable, Dict, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op
+import contextlib
 
 import torch
 from torch.nn import init
@@ -1517,10 +1518,16 @@ def forward(
             ).is_fp8_ubuf():
                 fp8_grad = True
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
-            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(
+            inp,
+            allow_non_contiguous=False,  # removed .contiguous from inside the layer
         ) as inp:
 
             # Get concatenated weight and bias tensors
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 9f799c5538..e8f4b4b2e2 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -8,6 +8,7 @@
 from typing import Callable, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op
+import contextlib
 
 import torch
 from torch.nn.parameter import Parameter
@@ -1768,9 +1769,14 @@ def forward(
             if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=2) as inp:
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(inp, num_gemms=2) as inp:
 
             quantizers = (
                 self._get_quantizers(fp8_output)
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 7e526245c1..a0efdd8d05 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -7,6 +7,7 @@
 from functools import reduce
 from operator import mul as multiply_op
 import warnings
+import contextlib
 
 import torch
 
@@ -1401,9 +1402,14 @@ def forward(
             ).is_fp8_ubuf():
                 fp8_grad = True
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(
             inp,
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
         ) as inp:
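
For reference, a minimal usage sketch (not part of the patch) of the is_first_microbatch path that the benchmark exercises and the module changes above optimize; the layer sizes and dtype below are illustrative:

    import torch
    import transformer_engine.pytorch as te

    # Illustrative shapes; any hidden size behaves the same way.
    model = te.Linear(3072, 3072, bias=False, device="cuda").to(torch.bfloat16)
    x = torch.randn((2048, 3072), dtype=torch.bfloat16, device="cuda")

    with torch.no_grad():
        model(x, is_first_microbatch=True)   # first microbatch: enters torch.cuda.device(...)
        model(x, is_first_microbatch=False)  # later microbatches: device context is skipped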