[PyTorch] CPU Overhead Micro-optimizations #2146
New file added by this PR: a standalone benchmark for measuring Linear-layer forward latency and the associated CPU overhead.

@@ -0,0 +1,125 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import List
import torch
import time
import argparse

import transformer_engine.pytorch as te


def run_once(
    module: torch.nn.Module,
    args: List[torch.Tensor],
    iters: int = 2000,
    use_te: bool = True,
    is_first_microbatch: bool = True,
):
    """Run the module `iters` times; the TE path also passes is_first_microbatch."""
    if use_te:
        for _ in range(iters):
            module(*args, is_first_microbatch=is_first_microbatch)
    else:
        for _ in range(iters):
            module(*args)


def speedometer(
    module: torch.nn.Module,
    args: List[torch.Tensor],
    timing_iters: int = 2000,
    warmup_iters: int = 100,
    num_rounds: int = 5,
    use_te: bool = True,
) -> float:
    """Measure average GPU and CPU run time per iteration for a PyTorch module."""
    # Warm-up pass (is_first_microbatch=True on the TE path)
    run_once(module, args, iters=warmup_iters, use_te=use_te, is_first_microbatch=True)

    gpu_times = []
    cpu_times = []
    for round_idx in range(num_rounds):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        cpu_start = time.time()
        # Timed region (is_first_microbatch=False on the TE path)
        run_once(module, args, iters=timing_iters, use_te=use_te, is_first_microbatch=False)
        cpu_end = time.time()
        end.record()
        torch.cuda.synchronize()
        gpu_elapsed = start.elapsed_time(end)  # CUDA events report milliseconds
        cpu_elapsed = (cpu_end - cpu_start) * 1000
        gpu_times.append(gpu_elapsed)
        cpu_times.append(cpu_elapsed)
        print(
            f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU"
            f" {cpu_elapsed/timing_iters*1000:.2f} µs"
        )
    print(
        f"Average GPU time over {num_rounds} rounds:"
        f" {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
    )
    print(
        f"Average CPU time over {num_rounds} rounds:"
        f" {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
    )

    # Average GPU time per round, in milliseconds.
    return sum(gpu_times) / num_rounds


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark torch.nn.Linear performance and CPU overhead."
    )
    parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size")
    parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations")
    parser.add_argument(
        "--timing_iters", type=int, default=2000, help="Number of timing iterations per round"
    )
    parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["torch", "te"],
        default="te",
        help="Linear backend: torch or te",
    )
    args = parser.parse_args()

    x = torch.randn(
        (args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    use_te = True
    if args.backend == "torch":
        model = (
            torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False)
            .to(torch.bfloat16)
            .cuda()
        )
        use_te = False
    else:
        model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(
            torch.bfloat16
        )
    with torch.no_grad():
        avg_gpu_time_per_round = speedometer(
            model,
            [x],
            timing_iters=args.timing_iters,
            warmup_iters=args.warmup,
            num_rounds=args.num_rounds,
            use_te=use_te,
        )

    # GEMM FLOPs per round: 2 * M * N * K per iteration, times the number of iterations.
    total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters

    # avg_gpu_time_per_round is in ms, so FLOP/ms divided by 1e9 gives TFLOP/s.
    tflops = total_ops / avg_gpu_time_per_round / 1e9
    print(f"Estimated TFLOP/s: {tflops:.2f}")


if __name__ == "__main__":
    main()
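For quick experiments, the helpers above can also be driven directly instead of going through main() and argparse. The following is an illustrative sketch using the script's default shapes; it is not part of the PR:

# Illustrative driver for the benchmark helpers defined above.
import torch
import transformer_engine.pytorch as te

hidden_size, seq_length = 3072, 2048
x = torch.randn(seq_length, hidden_size, dtype=torch.bfloat16, device="cuda")

te_linear = te.Linear(hidden_size, hidden_size, bias=False, device="cuda").to(torch.bfloat16)
torch_linear = torch.nn.Linear(hidden_size, hidden_size, bias=False).to(torch.bfloat16).cuda()

with torch.no_grad():
    speedometer(te_linear, [x], use_te=True)      # TE path, passes is_first_microbatch
    speedometer(torch_linear, [x], use_te=False)  # plain PyTorch path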
The remaining changes modify an existing Transformer Engine module. The first hunk adds the contextlib import:

@@ -7,6 +7,7 @@
 from functools import reduce
 from operator import mul as multiply_op
 import warnings
+import contextlib

 import torch

The second hunk skips re-entering torch.cuda.device(...) on non-first microbatches by substituting a contextlib.nullcontext():

@@ -1401,9 +1402,14 @@ def forward(
         ).is_fp8_ubuf():
             fp8_grad = True

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(
             inp,
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
         ) as inp:

Review thread on this change:

Reviewer: How do we assume we can skip setting the device when it is not the first microbatch?
Author: I assume that the device won't change across microbatches within a global batch. In a CPU-bound, forward-only case, skipping the device setting on every single forward pass can account for roughly a 10% performance difference.
Reviewer: This approach is really ad hoc. Personally, I think it would be better not to support the multi-device case at all (basically revert #1974) than to have inconsistent multi-device support.
Author: I agree, but I'm not sure whether there would be any impact on customers using this feature.
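To make the overhead under discussion concrete, here is a small standalone sketch (not part of this PR) that compares the per-call CPU cost of entering torch.cuda.device(...) with that of contextlib.nullcontext(); the iteration count is arbitrary and absolute numbers will vary by system:

# Standalone overhead sketch: CPU-side cost of the two context managers.
import contextlib
import time
import torch

dev = torch.cuda.current_device()
iters = 100_000

t0 = time.time()
for _ in range(iters):
    with torch.cuda.device(dev):
        pass
device_us = (time.time() - t0) / iters * 1e6

t0 = time.time()
for _ in range(iters):
    with contextlib.nullcontext():
        pass
null_us = (time.time() - t0) / iters * 1e6

print(f"torch.cuda.device: {device_us:.2f} us/enter, nullcontext: {null_us:.2f} us/enter")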
A second review thread discusses heterogeneous multi-GPU setups:

Reviewer: This doesn't handle the case where we have multiple GPUs with different architectures. We could add an argument for the device ID, but that just pushes the CPU overhead problem somewhere else.
Author: Yeah, but we didn't really support this case anyway?
Author: For a topology like one CPU with 4 or 8 GPUs of a homogeneous architecture, we can cache the TN layout check.
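As a rough illustration of the caching idea mentioned above (and only valid when the GPUs on a node share one architecture), a per-device query can be memoized so it runs once rather than on every forward pass. The function name below is hypothetical and not part of the TE codebase:

# Hypothetical sketch: memoize a per-device architecture query.
# Only safe if devices are homogeneous and do not change during the run.
import functools
import torch

@functools.lru_cache(maxsize=None)
def cached_device_capability(device_index: int) -> tuple:
    return torch.cuda.get_device_capability(device_index)

# The first call performs the CUDA query; later calls return the cached result.
sm_major_minor = cached_device_capability(torch.cuda.current_device())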