From 488e6c924b512b8401c4d4b3d60334ede3852ffb Mon Sep 17 00:00:00 2001 From: zhongboz Date: Tue, 2 Sep 2025 10:52:57 -0700 Subject: [PATCH 1/5] fix cpu overhead Signed-off-by: zhongboz --- benchmarks/linear/benchmark_linear.py | 88 +++++++++++++++++++ .../linear/benchmark_linear_cpu_overhead.py | 74 ++++++++++++++++ transformer_engine/pytorch/module/linear.py | 10 ++- 3 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 benchmarks/linear/benchmark_linear.py create mode 100644 benchmarks/linear/benchmark_linear_cpu_overhead.py diff --git a/benchmarks/linear/benchmark_linear.py b/benchmarks/linear/benchmark_linear.py new file mode 100644 index 0000000000..ab5a51a920 --- /dev/null +++ b/benchmarks/linear/benchmark_linear.py @@ -0,0 +1,88 @@ +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.common import recipe +import time +import csv + +def run_once(in_features, out_features, batch_size, iters=500): + + model = te.Linear(in_features, out_features, bias=True, device="cuda").half() + inp = torch.randn(batch_size, in_features, device="cuda").half() + torch_model = torch.nn.Linear(in_features, out_features, bias=True).cuda().half() + + fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) + + + # Warm up for te model + for _ in range(50): + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + out = model(inp, is_first_microbatch=True) + # out = model(inp, is_first_microbatch=True) + # Warm up for torch model + for _ in range(50): + with torch.no_grad(): + out = torch_model(inp) + torch.cuda.synchronize() + + + times = [] + with torch.no_grad(): + start = time.time() + for i in range(iters): + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + out = model(inp, is_first_microbatch=False) + # out = model(inp, is_first_microbatch=False) + torch.cuda.synchronize() + end = time.time() + avg_time_te = (end - start) / iters + + start = time.time() + torch_times = [] + for i in range(iters): + with torch.no_grad(): + out = torch_model(inp) + torch.cuda.synchronize() + end = time.time() + avg_time_torch = (end - start) / iters + + del model + del inp + del torch_model + torch.cuda.empty_cache() + + return avg_time_te, avg_time_torch + +if __name__ == "__main__": + in_feature_list = [1280, 5120, 10240] + out_features_list = [1280, 5120, 10240] + batch_size_list = [512, 1024, 2048] + + # in_feature_list = [5120] + # out_features_list = [5120] + # batch_size_list = [1024] + + results = [] + for in_features in in_feature_list: + for out_features in out_features_list: + for batch_size in batch_size_list: + print(f"=============Test in_features X out_features X batch_size {in_features}x{out_features}x{batch_size}") + te_time, torch_time = run_once(in_features, out_features, batch_size) + print(f"te linear average costime: {te_time*1000:.6f} ms") + print(f"torch linear average costime: {torch_time*1000:.6f} ms") + results.append([in_features, out_features, batch_size, te_time*1000, torch_time*1000]) + + + import csv + import sys + + # Print CSV header and rows in a better format using csv.writer + writer = csv.writer(sys.stdout) + writer.writerow(["in_features", "out_features", "batch_size", "te_time(ms)", "torch_time(ms)"]) + for row in results: + writer.writerow(row) + + # print out speedup over torch with problem shape + for row in results: + in_features, out_features, batch_size, te_time, torch_time = row + print(f"Shape: in_features={in_features}, 
out_features={out_features}, batch_size={batch_size} | Speedup over torch: {torch_time/te_time}") \ No newline at end of file diff --git a/benchmarks/linear/benchmark_linear_cpu_overhead.py b/benchmarks/linear/benchmark_linear_cpu_overhead.py new file mode 100644 index 0000000000..97cfc10ea3 --- /dev/null +++ b/benchmarks/linear/benchmark_linear_cpu_overhead.py @@ -0,0 +1,74 @@ +from typing import List +import torch +import time +import argparse + +import transformer_engine.pytorch as te + + +def speedometer( + module: torch.nn.Module, + args: List[torch.Tensor], + timing_iters: int = 500, + warmup_iters: int = 50, + num_rounds: int = 5, +) -> float: + """Measure average run time for a PyTorch module""" + for _ in range(warmup_iters): + module(*args) + + gpu_times = [] + cpu_times = [] + for round_idx in range(num_rounds): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + cpu_start = time.time() + for _ in range(timing_iters): + module(*args) + cpu_end = time.time() + end.record() + torch.cuda.synchronize() + gpu_elapsed = start.elapsed_time(end) + cpu_elapsed = (cpu_end - cpu_start) * 1000 + gpu_times.append(gpu_elapsed) + cpu_times.append(cpu_elapsed) + print( + f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU {cpu_elapsed/timing_iters*1000:.2f} µs" + ) + print(f"Average GPU time over {num_rounds} rounds: {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs") + print(f"Average CPU time over {num_rounds} rounds: {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs") + + return sum(gpu_times) / num_rounds + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark torch.nn.Linear performance and CPU overhead.") + parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size") + parser.add_argument("--seq_length", type=int, default=8192, help="Sequence length") + parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations") + parser.add_argument("--timing_iters", type=int, default=500, help="Number of timing iterations per round") + parser.add_argument("--num_rounds", type=int, default=3, help="Number of timing rounds") + parser.add_argument( + "--backend", type=str, choices=["torch", "te"], default="te", help="Linear backend: torch or te" + ) + args = parser.parse_args() + + x = torch.randn((args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True) + if args.backend == "torch": + model = torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False).to(torch.bfloat16).cuda() + else: + model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(torch.bfloat16) + avg_gpu_time_per_round = speedometer( + model, [x], timing_iters=args.timing_iters, warmup_iters=args.warmup, num_rounds=args.num_rounds + ) + + total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters + + tflops = total_ops / avg_gpu_time_per_round / 1e9 + print(f"Estimated TFLOP/s: {tflops:.2f}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2ce6fb4c1d..432573c27a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,6 +7,7 @@ from functools import reduce from operator import mul as multiply_op import warnings +import contextlib import torch @@ -1387,9 +1388,12 @@ def 
forward( ).is_fp8_ubuf(): fp8_grad = True - with torch.cuda.device( - getattr(self, list(self.named_parameters())[0][0]).device - ), self.prepare_forward( + if is_first_microbatch is None or is_first_microbatch: + device_ctx = torch.cuda.device(getattr(self, list(self.named_parameters())[0][0]).device) + else: + device_ctx = contextlib.nullcontext() + + with device_ctx, self.prepare_forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: From a831adb055bbb7aaf2ee7b36e7aefadfce419558 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Tue, 2 Sep 2025 12:40:45 -0700 Subject: [PATCH 2/5] cache tn layout check Signed-off-by: zhongboz --- benchmarks/linear/benchmark_linear.py | 88 ------------------- .../linear/benchmark_linear_cpu_overhead.py | 39 +++++--- .../common/transformer_engine.cpp | 18 ++-- 3 files changed, 37 insertions(+), 108 deletions(-) delete mode 100644 benchmarks/linear/benchmark_linear.py diff --git a/benchmarks/linear/benchmark_linear.py b/benchmarks/linear/benchmark_linear.py deleted file mode 100644 index ab5a51a920..0000000000 --- a/benchmarks/linear/benchmark_linear.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import transformer_engine.pytorch as te -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.common import recipe -import time -import csv - -def run_once(in_features, out_features, batch_size, iters=500): - - model = te.Linear(in_features, out_features, bias=True, device="cuda").half() - inp = torch.randn(batch_size, in_features, device="cuda").half() - torch_model = torch.nn.Linear(in_features, out_features, bias=True).cuda().half() - - fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) - - - # Warm up for te model - for _ in range(50): - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - out = model(inp, is_first_microbatch=True) - # out = model(inp, is_first_microbatch=True) - # Warm up for torch model - for _ in range(50): - with torch.no_grad(): - out = torch_model(inp) - torch.cuda.synchronize() - - - times = [] - with torch.no_grad(): - start = time.time() - for i in range(iters): - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - out = model(inp, is_first_microbatch=False) - # out = model(inp, is_first_microbatch=False) - torch.cuda.synchronize() - end = time.time() - avg_time_te = (end - start) / iters - - start = time.time() - torch_times = [] - for i in range(iters): - with torch.no_grad(): - out = torch_model(inp) - torch.cuda.synchronize() - end = time.time() - avg_time_torch = (end - start) / iters - - del model - del inp - del torch_model - torch.cuda.empty_cache() - - return avg_time_te, avg_time_torch - -if __name__ == "__main__": - in_feature_list = [1280, 5120, 10240] - out_features_list = [1280, 5120, 10240] - batch_size_list = [512, 1024, 2048] - - # in_feature_list = [5120] - # out_features_list = [5120] - # batch_size_list = [1024] - - results = [] - for in_features in in_feature_list: - for out_features in out_features_list: - for batch_size in batch_size_list: - print(f"=============Test in_features X out_features X batch_size {in_features}x{out_features}x{batch_size}") - te_time, torch_time = run_once(in_features, out_features, batch_size) - print(f"te linear average costime: {te_time*1000:.6f} ms") - print(f"torch linear average costime: {torch_time*1000:.6f} ms") - results.append([in_features, out_features, batch_size, te_time*1000, torch_time*1000]) - - - import csv - import sys - - # Print CSV header and rows in a better format using 
csv.writer - writer = csv.writer(sys.stdout) - writer.writerow(["in_features", "out_features", "batch_size", "te_time(ms)", "torch_time(ms)"]) - for row in results: - writer.writerow(row) - - # print out speedup over torch with problem shape - for row in results: - in_features, out_features, batch_size, te_time, torch_time = row - print(f"Shape: in_features={in_features}, out_features={out_features}, batch_size={batch_size} | Speedup over torch: {torch_time/te_time}") \ No newline at end of file diff --git a/benchmarks/linear/benchmark_linear_cpu_overhead.py b/benchmarks/linear/benchmark_linear_cpu_overhead.py index 97cfc10ea3..bfe23a85b7 100644 --- a/benchmarks/linear/benchmark_linear_cpu_overhead.py +++ b/benchmarks/linear/benchmark_linear_cpu_overhead.py @@ -6,16 +6,26 @@ import transformer_engine.pytorch as te +def run_once(module: torch.nn.Module, args: List[torch.Tensor], iters = 2000, use_te: bool = True, is_first_microbatch: bool = True): + if use_te: + for _ in range(iters): + module(*args, is_first_microbatch=is_first_microbatch) + else: + for _ in range(iters): + module(*args) + + def speedometer( module: torch.nn.Module, args: List[torch.Tensor], - timing_iters: int = 500, - warmup_iters: int = 50, + timing_iters: int = 2000, + warmup_iters: int = 100, num_rounds: int = 5, + use_te: bool = True, ) -> float: """Measure average run time for a PyTorch module""" - for _ in range(warmup_iters): - module(*args) + # warm up + run_once(module, args, iters=warmup_iters, use_te=use_te, is_first_microbatch=True) gpu_times = [] cpu_times = [] @@ -25,8 +35,8 @@ def speedometer( torch.cuda.synchronize() start.record() cpu_start = time.time() - for _ in range(timing_iters): - module(*args) + # run timing + run_once(module, args, iters=timing_iters, use_te=use_te, is_first_microbatch=False) cpu_end = time.time() end.record() torch.cuda.synchronize() @@ -45,24 +55,27 @@ def speedometer( def main(): parser = argparse.ArgumentParser(description="Benchmark torch.nn.Linear performance and CPU overhead.") - parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size") - parser.add_argument("--seq_length", type=int, default=8192, help="Sequence length") + parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size") + parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length") parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations") - parser.add_argument("--timing_iters", type=int, default=500, help="Number of timing iterations per round") - parser.add_argument("--num_rounds", type=int, default=3, help="Number of timing rounds") + parser.add_argument("--timing_iters", type=int, default=2000, help="Number of timing iterations per round") + parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds") parser.add_argument( "--backend", type=str, choices=["torch", "te"], default="te", help="Linear backend: torch or te" ) args = parser.parse_args() x = torch.randn((args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True) + use_te = True if args.backend == "torch": model = torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False).to(torch.bfloat16).cuda() + use_te = False else: model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(torch.bfloat16) - avg_gpu_time_per_round = speedometer( - model, [x], timing_iters=args.timing_iters, warmup_iters=args.warmup, num_rounds=args.num_rounds - ) + with torch.no_grad(): + 
avg_gpu_time_per_round = speedometer( + model, [x], timing_iters=args.timing_iters, warmup_iters=args.warmup, num_rounds=args.num_rounds, use_te=use_te + ) total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 55654989a7..8fda938342 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -641,11 +641,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int deviceComputeCapability = - transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()); - - // Note: this is temporary restriction and should be lifted in the future. - // (remove the note once it's done.) - return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) || - deviceComputeCapability >= 130; + static int cached_result = 0; + static std::once_flag flag; + std::call_once(flag, []() { + int deviceComputeCapability = + transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()); + // Note: this is temporary restriction and should be lifted in the future. + // (remove the note once it's done.) + cached_result = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) || + deviceComputeCapability >= 130; + }); + return cached_result; } From dc2532a3522390889505a38a60a688d93cff1cee Mon Sep 17 00:00:00 2001 From: zhongboz Date: Tue, 2 Sep 2025 12:47:13 -0700 Subject: [PATCH 3/5] apply to more modules Signed-off-by: zhongboz --- .../linear/benchmark_linear_cpu_overhead.py | 58 +++++++++++++++---- .../pytorch/module/grouped_linear.py | 11 +++- .../pytorch/module/layernorm_linear.py | 14 +++-- .../pytorch/module/layernorm_mlp.py | 11 +++- transformer_engine/pytorch/module/linear.py | 4 +- 5 files changed, 75 insertions(+), 23 deletions(-) diff --git a/benchmarks/linear/benchmark_linear_cpu_overhead.py b/benchmarks/linear/benchmark_linear_cpu_overhead.py index bfe23a85b7..8020e5df8f 100644 --- a/benchmarks/linear/benchmark_linear_cpu_overhead.py +++ b/benchmarks/linear/benchmark_linear_cpu_overhead.py @@ -6,7 +6,13 @@ import transformer_engine.pytorch as te -def run_once(module: torch.nn.Module, args: List[torch.Tensor], iters = 2000, use_te: bool = True, is_first_microbatch: bool = True): +def run_once( + module: torch.nn.Module, + args: List[torch.Tensor], + iters=2000, + use_te: bool = True, + is_first_microbatch: bool = True, +): if use_te: for _ in range(iters): module(*args, is_first_microbatch=is_first_microbatch) @@ -45,36 +51,64 @@ def speedometer( gpu_times.append(gpu_elapsed) cpu_times.append(cpu_elapsed) print( - f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU {cpu_elapsed/timing_iters*1000:.2f} µs" + f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU" + f" {cpu_elapsed/timing_iters*1000:.2f} µs" ) - print(f"Average GPU time over {num_rounds} rounds: {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs") - print(f"Average CPU time over {num_rounds} rounds: {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs") + print( + f"Average GPU time over {num_rounds} rounds:" + f" {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs" + ) + print( + f"Average CPU time over {num_rounds} rounds:" + f" {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs" + ) return sum(gpu_times) / num_rounds def main(): - parser = 
argparse.ArgumentParser(description="Benchmark torch.nn.Linear performance and CPU overhead.") + parser = argparse.ArgumentParser( + description="Benchmark torch.nn.Linear performance and CPU overhead." + ) parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size") parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length") parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations") - parser.add_argument("--timing_iters", type=int, default=2000, help="Number of timing iterations per round") + parser.add_argument( + "--timing_iters", type=int, default=2000, help="Number of timing iterations per round" + ) parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds") parser.add_argument( - "--backend", type=str, choices=["torch", "te"], default="te", help="Linear backend: torch or te" + "--backend", + type=str, + choices=["torch", "te"], + default="te", + help="Linear backend: torch or te", ) args = parser.parse_args() - x = torch.randn((args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True) + x = torch.randn( + (args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True + ) use_te = True if args.backend == "torch": - model = torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False).to(torch.bfloat16).cuda() + model = ( + torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False) + .to(torch.bfloat16) + .cuda() + ) use_te = False else: - model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(torch.bfloat16) + model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to( + torch.bfloat16 + ) with torch.no_grad(): avg_gpu_time_per_round = speedometer( - model, [x], timing_iters=args.timing_iters, warmup_iters=args.warmup, num_rounds=args.num_rounds, use_te=use_te + model, + [x], + timing_iters=args.timing_iters, + warmup_iters=args.warmup, + num_rounds=args.num_rounds, + use_te=use_te, ) total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters @@ -84,4 +118,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3d7a5efaca..bfca3a8567 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -748,9 +748,14 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with torch.cuda.device( - getattr(self, list(self.named_parameters())[0][0]).device - ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + if is_first_microbatch is None or is_first_microbatch: + device_ctx = torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ) + else: + device_ctx = contextlib.nullcontext() + + with device_ctx, self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index cd02f31132..9d54e74830 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1502,10 +1502,16 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with torch.cuda.device( - 
getattr(self, list(self.named_parameters())[0][0]).device - ), self.prepare_forward( - inp, allow_non_contiguous=False # removed .contiguous from inside the layer + if is_first_microbatch is None or is_first_microbatch: + device_ctx = torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ) + else: + device_ctx = contextlib.nullcontext() + + with device_ctx, self.prepare_forward( + inp, + allow_non_contiguous=False, # removed .contiguous from inside the layer ) as inp: # Get concatenated weight and bias tensors diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 182bf99f86..ef0ef10d2c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1764,9 +1764,14 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with torch.cuda.device( - getattr(self, list(self.named_parameters())[0][0]).device - ), self.prepare_forward(inp, num_gemms=2) as inp: + if is_first_microbatch is None or is_first_microbatch: + device_ctx = torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ) + else: + device_ctx = contextlib.nullcontext() + + with device_ctx, self.prepare_forward(inp, num_gemms=2) as inp: quantizers = ( self._get_quantizers(fp8_output) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 432573c27a..807a7270f7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1389,7 +1389,9 @@ def forward( fp8_grad = True if is_first_microbatch is None or is_first_microbatch: - device_ctx = torch.cuda.device(getattr(self, list(self.named_parameters())[0][0]).device) + device_ctx = torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ) else: device_ctx = contextlib.nullcontext() From 5fbfb01bd850c36008b9a19239655b6eab4048c3 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Tue, 2 Sep 2025 12:48:29 -0700 Subject: [PATCH 4/5] license Signed-off-by: zhongboz --- benchmarks/linear/benchmark_linear_cpu_overhead.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/linear/benchmark_linear_cpu_overhead.py b/benchmarks/linear/benchmark_linear_cpu_overhead.py index 8020e5df8f..59ae371e67 100644 --- a/benchmarks/linear/benchmark_linear_cpu_overhead.py +++ b/benchmarks/linear/benchmark_linear_cpu_overhead.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ from typing import List import torch import time From 3af9438e0a4a029058cc1a36e3188daa74c1fd8d Mon Sep 17 00:00:00 2001 From: zhongboz Date: Tue, 2 Sep 2025 13:59:44 -0700 Subject: [PATCH 5/5] lint Signed-off-by: zhongboz --- transformer_engine/pytorch/module/grouped_linear.py | 1 + transformer_engine/pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 3 files changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index bfca3a8567..03f59e388b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -5,6 +5,7 @@ """GroupedLinear API""" from typing import Union, Optional, Callable, Tuple, List import warnings +import contextlib import functools import torch diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9d54e74830..8a493a8352 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -8,6 +8,7 @@ from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op +import contextlib import torch from torch.nn import init diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ef0ef10d2c..abe7672a2a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -8,6 +8,7 @@ from typing import Callable, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op +import contextlib import torch from torch.nn.parameter import Parameter
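
Note on the common pattern in patches 1 and 3: the per-call cost being removed is the
torch.cuda.device() context manager, which walks named_parameters() on every forward call.
The modules now enter it only when the weights may actually change (is_first_microbatch is
None or True) and fall back to contextlib.nullcontext() otherwise. Below is a minimal
standalone sketch of that pattern; it uses a hypothetical forward() and a simplified
prepare_forward() signature rather than the real module code:

    import contextlib
    import torch

    def forward(self, inp, is_first_microbatch=None):
        # torch.cuda.device() is comparatively expensive on the CPU side;
        # only pay for it when the weights may change (first microbatch or
        # no microbatching information at all).
        if is_first_microbatch is None or is_first_microbatch:
            first_param_name = list(self.named_parameters())[0][0]
            device_ctx = torch.cuda.device(getattr(self, first_param_name).device)
        else:
            device_ctx = contextlib.nullcontext()
        with device_ctx, self.prepare_forward(inp) as inp:
            ...  # quantization, GEMM, and the rest of the forward pass

Patch 2 applies the same idea on the C++ side: nvte_is_non_tn_fp8_gemm_supported()
now queries the device compute capability once under std::call_once and returns the
cached result on subsequent calls.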