125 changes: 125 additions & 0 deletions benchmarks/linear/benchmark_linear_cpu_overhead.py
@@ -0,0 +1,125 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import List
import torch
import time
import argparse

import transformer_engine.pytorch as te


def run_once(
module: torch.nn.Module,
args: List[torch.Tensor],
iters=2000,
use_te: bool = True,
is_first_microbatch: bool = True,
):
if use_te:
for _ in range(iters):
module(*args, is_first_microbatch=is_first_microbatch)
else:
for _ in range(iters):
module(*args)


def speedometer(
module: torch.nn.Module,
args: List[torch.Tensor],
timing_iters: int = 2000,
warmup_iters: int = 100,
num_rounds: int = 5,
use_te: bool = True,
) -> float:
"""Measure average run time for a PyTorch module"""
# warm up
run_once(module, args, iters=warmup_iters, use_te=use_te, is_first_microbatch=True)

gpu_times = []
cpu_times = []
for round_idx in range(num_rounds):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
cpu_start = time.time()
# run timing
run_once(module, args, iters=timing_iters, use_te=use_te, is_first_microbatch=False)
cpu_end = time.time()
end.record()
torch.cuda.synchronize()
gpu_elapsed = start.elapsed_time(end)
cpu_elapsed = (cpu_end - cpu_start) * 1000
gpu_times.append(gpu_elapsed)
cpu_times.append(cpu_elapsed)
print(
f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU"
f" {cpu_elapsed/timing_iters*1000:.2f} µs"
)
print(
f"Average GPU time over {num_rounds} rounds:"
f" {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
)
print(
f"Average CPU time over {num_rounds} rounds:"
f" {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
)

return sum(gpu_times) / num_rounds


def main():
parser = argparse.ArgumentParser(
description="Benchmark torch.nn.Linear performance and CPU overhead."
)
parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size")
parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length")
parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations")
parser.add_argument(
"--timing_iters", type=int, default=2000, help="Number of timing iterations per round"
)
parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds")
parser.add_argument(
"--backend",
type=str,
choices=["torch", "te"],
default="te",
help="Linear backend: torch or te",
)
args = parser.parse_args()

x = torch.randn(
(args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True
)
use_te = True
if args.backend == "torch":
model = (
torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False)
.to(torch.bfloat16)
.cuda()
)
use_te = False
else:
model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(
torch.bfloat16
)
with torch.no_grad():
avg_gpu_time_per_round = speedometer(
model,
[x],
timing_iters=args.timing_iters,
warmup_iters=args.warmup,
num_rounds=args.num_rounds,
use_te=use_te,
)

total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters

tflops = total_ops / avg_gpu_time_per_round / 1e9
print(f"Estimated TFLOP/s: {tflops:.2f}")


if __name__ == "__main__":
main()
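
Usage sketch (not part of the file above; the flag values are illustrative and the defaults come from the argparse setup in the script):

python benchmarks/linear/benchmark_linear_cpu_overhead.py --backend te --hidden_size 3072 --seq_length 2048 --num_rounds 5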
18 changes: 11 additions & 7 deletions transformer_engine/common/transformer_engine.cpp
@@ -641,11 +641,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
Collaborator: This doesn't handle the case where we have multiple GPUs with different archs. We could add an arg for the device ID, but that just pushes the CPU overhead problem somewhere else.

Collaborator (author): Yeah, but we didn't really support this case anyway?

Collaborator (author): For a topology like one CPU with 4 or 8 GPUs of homogeneous arch, we can cache the TN layout check.

-  int deviceComputeCapability =
-      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
-
-  // Note: this is temporary restriction and should be lifted in the future.
-  // (remove the note once it's done.)
-  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
-         deviceComputeCapability >= 130;
+  static int cached_result = 0;
+  static std::once_flag flag;
+  std::call_once(flag, []() {
+    int deviceComputeCapability =
+        transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
+    // Note: this is temporary restriction and should be lifted in the future.
+    // (remove the note once it's done.)
+    cached_result = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
+                    deviceComputeCapability >= 130;
+  });
+  return cached_result;
}
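
For reference, the mixed-arch concern raised in the thread above could be handled with a per-device cache rather than a single cached result. The sketch below is illustrative and not part of this change: the function name is hypothetical, and it assumes the same transformer_engine::cuda::sm_arch / current_device helpers already used in the diff.

#include <mutex>
#include <unordered_map>

// Hypothetical per-device variant: cache one result per device ordinal so
// mixed-arch systems get a correct answer on every device.
int nvte_is_non_tn_fp8_gemm_supported_on_current_device() {
  static std::mutex mu;
  static std::unordered_map<int, int> cache;  // device id -> cached result
  const int device_id = transformer_engine::cuda::current_device();
  std::lock_guard<std::mutex> lock(mu);
  auto it = cache.find(device_id);
  if (it == cache.end()) {
    const int sm = transformer_engine::cuda::sm_arch(device_id);
    it = cache.emplace(device_id, (sm >= 100 && sm < 120) || sm >= 130).first;
  }
  return it->second;
}

Note that this variant still queries current_device() and takes a lock on every call, which is exactly the per-call CPU overhead the discussion above is trying to avoid; it only removes the repeated sm_arch lookup.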
12 changes: 9 additions & 3 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -5,6 +5,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import warnings
+import contextlib

import functools
import torch
@@ -745,9 +746,14 @@ def forward(
if skip_fp8_weight_update is not None:
is_first_microbatch = False

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

15 changes: 11 additions & 4 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -8,6 +8,7 @@
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
+import contextlib

import torch
from torch.nn import init
@@ -1517,10 +1518,16 @@ def forward(
).is_fp8_ubuf():
fp8_grad = True

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
-            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(
+            inp,
+            allow_non_contiguous=False,  # removed .contiguous from inside the layer
) as inp:

# Get concatenated weight and bias tensors
12 changes: 9 additions & 3 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -8,6 +8,7 @@
from typing import Callable, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
+import contextlib

import torch
from torch.nn.parameter import Parameter
@@ -1768,9 +1769,14 @@ def forward(
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=2) as inp:
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(inp, num_gemms=2) as inp:

quantizers = (
self._get_quantizers(fp8_output)
12 changes: 9 additions & 3 deletions transformer_engine/pytorch/module/linear.py
@@ -7,6 +7,7 @@
from functools import reduce
from operator import mul as multiply_op
import warnings
+import contextlib

import torch

@@ -1401,9 +1402,14 @@ def forward(
).is_fp8_ubuf():
fp8_grad = True

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
+        if is_first_microbatch is None or is_first_microbatch:
Collaborator: How do we assume we can skip setting the device if is_first_microbatch=False?

Collaborator (author): I assume that the device won't change across microbatches within a global batch.

Collaborator (author): In a CPU-bound, forward-only case, skipping the set-device call on every single forward pass can account for a 10% perf difference.

Collaborator: This approach is really ad hoc. Personally, I think it would be better not to support the multi-device case (basically revert #1974) than to have inconsistent multi-device support.

Collaborator (author): I agree, but I'm not sure whether there would be any impact on customers already using this feature?
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
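
To make the assumption in the thread above concrete, here is a minimal caller-side sketch (not part of this change; the sizes and loop structure are illustrative): within one global batch the module's parameters stay on a single device, so only the microbatch that passes is_first_microbatch=True pays for the torch.cuda.device context.

import torch
import transformer_engine.pytorch as te

model = te.Linear(3072, 3072, bias=False, device="cuda").to(torch.bfloat16)

# Two global batches of four microbatches each, purely for illustration.
global_batches = [
    [torch.randn(2048, 3072, dtype=torch.bfloat16, device="cuda") for _ in range(4)]
    for _ in range(2)
]

with torch.no_grad():
    for microbatches in global_batches:
        for i, x in enumerate(microbatches):
            # True exactly once per global batch: the later microbatches skip
            # the device-context setup (the change above) on the assumption
            # that the module has not moved to a different device in between.
            out = model(x, is_first_microbatch=(i == 0))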