
Commit 79cdaec

Add conv roofline
1 parent 3bc5d37 commit 79cdaec

File tree

1 file changed (+23, -227 lines)


benchmarks/float8/float8_inference_roofline.py

Lines changed: 23 additions & 227 deletions
@@ -53,10 +53,6 @@
 from torchao.testing.training.roofline_utils import (
     get_inference_float8_mem_sympy,
     get_inference_gemm_time_sympy,
-    get_specs,
-    BYTES_PER_EL_BF16,
-    BYTES_PER_EL_FLOAT8,
-    KERNEL_LAUNCH_OVERHEAD_SEC,
 )
 from torchao.utils import is_MI300
 
@@ -198,7 +194,10 @@ def get_conv_equivalent_gemm_dims(
     padding: int = 0,
 ):
     """
-    Get GEMM dimensions from unfold.
+    Get equivalent GEMM dimensions for a conv operation.
+
+    Uses torch.nn.functional.unfold to derive the correct GEMM dimensions
+    that correspond to the conv operation.
 
     Args:
         op_name: "conv2d" or "conv3d"
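For reference, the unfold-based derivation described in the new docstring boils down to the sketch below. The shapes and names here are illustrative assumptions, not the benchmark's defaults: unfold turns a (B, C, H, W) input into (B, C*k*k, num_patches), and the conv is then a GEMM of size (B*num_patches) x (C*k*k) x out_channels.

```python
# Minimal sketch (not the benchmark's exact code): derive conv2d-equivalent GEMM dims via unfold.
import torch
import torch.nn.functional as F

batch, in_channels, out_channels = 2, 16, 32
H = W = 28
kernel_size, stride, padding = 3, 1, 0

x = torch.randn(batch, in_channels, H, W)
# unfold -> (batch, in_channels * k * k, num_patches)
unfolded = F.unfold(x, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding)

gemm_M = batch * unfolded.shape[2]  # one GEMM row per output spatial position
gemm_K = unfolded.shape[1]          # in_channels * kernel_size^2
gemm_N = out_channels               # one GEMM column per output channel
print(gemm_M, gemm_K, gemm_N)       # 1352 144 32 for these shapes
```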
@@ -235,10 +234,6 @@ def get_conv_equivalent_gemm_dims(
 
     elif op_name == "conv3d":
         x = torch.randn(batch, in_channels, D, H, W, device=device)
-
-        # Note: torch.nn.Unfold only supports 4-D tensors
-        # (https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html)
-        # For 3D conv, reshape (B,C,D,H,W) -> (B*D,C,H,W) and unfold H,W
         B, C, D_in, H_in, W_in = x.shape
         x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in)
         unfolded = torch.nn.functional.unfold(
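The comments deleted above documented the 4-D restriction of `torch.nn.functional.unfold`; the workaround the file keeps using folds the depth dimension into the batch and unfolds only H and W. A standalone sketch with illustrative shapes:

```python
# Sketch of the conv3d case: F.unfold only accepts 4-D input, so (B, C, D, H, W)
# is reshaped to (B*D, C, H, W) before unfolding. Shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

B, C, D, H, W = 2, 8, 4, 16, 16
kernel_size, stride, padding = 3, 1, 0

x = torch.randn(B, C, D, H, W)
x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W)  # (B*D, C, H, W)
unfolded = F.unfold(
    x_reshaped, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
)
print(unfolded.shape)  # (B*D, C*k*k, H_out*W_out) = (8, 72, 196)
```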
@@ -260,179 +255,6 @@
     return gemm_M, gemm_K, gemm_N
 
 
-def benchmark_im2col_unfold(
-    op_name: str,
-    batch: int,
-    in_channels: int,
-    kernel_size: int,
-    D: Optional[int],
-    H: int,
-    W: int,
-    stride: int = 1,
-    padding: int = 0,
-    dtype=torch.bfloat16,
-):
-    """
-    Benchmark unfold operation.
-
-    Args:
-        op_name: "conv2d" or "conv3d"
-        batch: Batch size
-        in_channels: Number of input channels
-        kernel_size: Kernel size
-        D: Depth dimension (required for conv3d)
-        H: Height dimension (required for conv2d/conv3d)
-        W: Width dimension (required for conv2d/conv3d)
-        stride: Stride value
-        padding: Padding value
-        dtype: Data type
-
-    Returns:
-        Measured time in seconds
-    """
-    device = torch.device("cuda")
-
-    _validate_conv_params(op_name, kernel_size, D, H, W)
-
-    # Unfold doesn't support FP8; return -1 for unsupported dtypes
-    if dtype not in (torch.bfloat16, torch.float16, torch.float32):
-        return -1
-
-    # Create input tensor
-    if op_name == "conv2d":
-        x = torch.randn(batch, in_channels, H, W, dtype=dtype, device=device)
-    elif op_name == "conv3d":
-        x = torch.randn(batch, in_channels, D, H, W, dtype=dtype, device=device)
-    else:
-        raise ValueError(f"Unsupported op_name: {op_name}")
-
-    def _run_unfold():
-        if op_name == "conv2d":
-            return torch.nn.functional.unfold(
-                x, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
-            )
-        else:  # conv3d: reshape to 4D since unfold only supports 4D
-            B, C, D_in, H_in, W_in = x.shape
-            x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in)
-            return torch.nn.functional.unfold(
-                x_reshaped, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
-            )
-
-    # Warm up
-    for _ in range(2):
-        _ = _run_unfold()
-    torch.cuda.synchronize()
-
-    # Benchmark
-    n_iter = 10
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-
-    start.record()
-    for _ in range(n_iter):
-        _ = _run_unfold()
-    end.record()
-    torch.cuda.synchronize()
-
-    return start.elapsed_time(end) / 1000.0 / n_iter
-
-
-def get_im2col_memory_overhead_sympy(
-    op_name: str,
-    batch: int,
-    in_channels: int,
-    out_channels: int,
-    kernel_size: int,
-    D: Optional[int],
-    H: int,
-    W: int,
-    stride: int = 1,
-    padding: int = 0,
-    dtype=torch.bfloat16,
-    gpu_name: Optional[str] = None,
-):
-    """
-    Calculate the memory overhead for im2col transformation in conv operations.
-
-    Im2col unfolds the input tensor into a 2D matrix for efficient GEMM computation.
-    This involves:
-    1. Reading the input tensor (batch × in_channels × spatial_dims)
-    2. Writing the im2col matrix (output_spatial_positions × kernel_volume)
-
-    The im2col matrix is typically much larger than the input due to overlapping
-    windows, especially with stride=1 and larger kernels.
-
-    Args:
-        op_name: "conv2d" or "conv3d"
-        batch: Batch size
-        in_channels: Number of input channels
-        out_channels: Number of output channels
-        kernel_size: Kernel size
-        D: Depth dimension (required for conv3d)
-        H: Height dimension (required for conv2d/conv3d)
-        W: Width dimension (required for conv2d/conv3d)
-        stride: Stride value
-        padding: Padding value
-        dtype: Data type
-        gpu_name: GPU name for specs
-
-    Returns:
-        sympy expression for im2col memory overhead in seconds
-    """
-    _validate_conv_params(op_name, kernel_size, D, H, W)
-    specs = get_specs(gpu_name)
-
-    # Determine bytes per element based on dtype
-    if dtype == torch.bfloat16:
-        bytes_per_el = BYTES_PER_EL_BF16
-    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        bytes_per_el = BYTES_PER_EL_FLOAT8
-    else:
-        bytes_per_el = BYTES_PER_EL_BF16  # default
-
-    if op_name == "conv2d":
-
-        # Input size
-        input_numel = batch * in_channels * H * W
-
-        # Output spatial dimensions
-        H_out = (H - kernel_size + 2 * padding) // stride + 1
-        W_out = (W - kernel_size + 2 * padding) // stride + 1
-
-        # Im2col matrix size: (batch * H_out * W_out) × (in_channels * kernel_size^2)
-        im2col_numel = batch * H_out * W_out * in_channels * kernel_size * kernel_size
-
-    elif op_name == "conv3d":
-        # Input size
-        input_numel = batch * in_channels * D * H * W
-
-        # Output spatial dimensions
-        D_out = (D - kernel_size + 2 * padding) // stride + 1
-        H_out = (H - kernel_size + 2 * padding) // stride + 1
-        W_out = (W - kernel_size + 2 * padding) // stride + 1
-
-        # Im2col matrix size: (batch * D_out * H_out * W_out) × (in_channels * kernel_size^3)
-        im2col_numel = batch * D_out * H_out * W_out * in_channels * kernel_size * kernel_size * kernel_size
-
-    else:
-        raise ValueError(f"Unsupported op_name: {op_name}")
-
-    # Memory traffic: read input + write im2col matrix
-    # Note: In practice, some implementations may avoid materializing the full im2col
-    # matrix, but we model the worst case here
-    bytes_read = input_numel * bytes_per_el
-    bytes_write = im2col_numel * bytes_per_el
-    total_bytes = bytes_read + bytes_write
-
-    # Convert to time using memory bandwidth
-    im2col_time_s = total_bytes / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
-
-    # Account for kernel launch overhead
-    im2col_time_s = sympy.Max(im2col_time_s, KERNEL_LAUNCH_OVERHEAD_SEC)
-
-    return im2col_time_s
-
-
 def run(
     outfile: str,
     recipe_name: str,
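The removed `get_im2col_memory_overhead_sympy` modeled im2col traffic as one read of the input plus one write of the fully materialized unfolded matrix, converted to time via memory bandwidth. A plain-Python version of that arithmetic, with an assumed effective bandwidth figure, looks roughly like this:

```python
# Worked example of the removed traffic model: read input once, write the full
# im2col matrix once. The bandwidth value below is an illustrative assumption,
# not a value taken from the torchao GPU specs tables.
batch, in_channels, H, W = 32, 64, 56, 56
kernel_size, stride, padding = 3, 1, 0
bytes_per_el = 2  # bf16

H_out = (H - kernel_size + 2 * padding) // stride + 1
W_out = (W - kernel_size + 2 * padding) // stride + 1

input_numel = batch * in_channels * H * W
im2col_numel = batch * H_out * W_out * in_channels * kernel_size * kernel_size

total_bytes = (input_numel + im2col_numel) * bytes_per_el
achievable_bw_bytes_per_s = 2.0e12  # assumed ~2 TB/s effective bandwidth
print(total_bytes / achievable_bw_bytes_per_s)  # rough im2col time in seconds
```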
@@ -449,6 +271,9 @@ def run(
     H: Optional[int] = None,
     W: Optional[int] = None,
     kernel_size: Optional[int] = None,
+    stride: int = 1,
+    padding: int = 0,
+    verbose: bool = False,
 ):
     """
     Args:
@@ -457,11 +282,13 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
     * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
-    # `save_profile_traces (optional)`: if True, saves profiling traces
-    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
-    # `op_name`: linear, conv2d or conv3d, decides which op to benchmark
-    # `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d
-    # `kernel_size`: kernel_size for conv3d / conv2d
+    * `save_profile_traces (optional)`: if True, saves profiling traces
+    * `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
+    * `op_name`: linear, conv2d or conv3d, decides which op to benchmark
+    * `D`, `H`, `W`: spatial dimensions for conv3d / conv2d
+    * `kernel_size`: kernel_size for conv3d / conv2d
+    * `stride`: stride for conv ops (default: 1)
+    * `padding`: padding for conv ops (default: 0)
     """
     _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
     assert op_name in _SUPPORTED_OPS, f"Unsupported op: {op_name}, supported: {_SUPPORTED_OPS}"
@@ -481,6 +308,8 @@ def run(
         ["MKN", f"{M} {K} {N}"],
         ["DHW", f"{D} {H} {W}"],
         ["kernel_size", kernel_size],
+        ["stride", stride],
+        ["padding", padding],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
@@ -513,7 +342,7 @@ def run(
     print()
 
     if op_name in ("conv2d", "conv3d"):
-        print(f"{op_name}: GEMM dimensions from unfold, roofline from symbolic expressions")
+        print(f"{op_name}: GEMM dimensions derived from conv params")
         print()
     elif op_name != "linear":
         raise ValueError(f"Unsupported op_name: {op_name}")
@@ -525,17 +354,11 @@ def run(
         # Roofline: GEMM time
         "r_bf16_gemm_s", "r_fp8_gemm_s",
 
-        # Roofline: im2col overhead
-        "r_im2col_bf16_s", "r_im2col_fp8_s",
-
         # Roofline: FP8 quantization overhead
         "r_fp8_ovhd_s",
 
-        # Roofline: GEMM-only metrics
-        "r_fp8_gemm_and_ovhd_s", "r_fp8_gemm_and_ovhd_spdp",
-
-        # Roofline: Total (im2col + GEMM + quantization)
-        "r_bf16_total_s", "r_fp8_total_s", "r_fp8_total_spdp",
+        # Roofline: Total (GEMM + quantization)
+        "r_fp8_gemm_and_ovhd_s", "r_fp8_speedup",
 
         # Benchmarks: Direct GEMM
         "b_bf16_gemm_s", "b_fp8_gemm_s",
@@ -572,11 +395,6 @@ def run(
             r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
             r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
 
-            # Linear ops don't have im2col overhead
-            r_im2col_bf16_s, r_im2col_fp8_s = 0, 0
-            r_bf16_total_s = r_bf16_gemm_time_s
-            r_fp8_total_s = r_fp8_gemm_and_ovhd_s
-            r_total_spdp = r_speedup
             b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
             rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1
 
@@ -599,8 +417,6 @@ def run(
             rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
 
         elif op_name in ("conv2d", "conv3d"):
-            # Get GEMM dimensions from unfold
-            stride, padding = 1, 0
             gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims(
                 op_name=op_name,
                 batch=M_val,
625441
fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
626442
)
627443

628-
# Compute combined metrics
629444
r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
630445
r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
631446

632-
# Roofline im2col overhead (theoretical)
633-
r_im2col_bf16_s = float(get_im2col_memory_overhead_sympy(
634-
op_name, M_val, K_val, N_val, kernel_size,
635-
D, H, W, stride=1, padding=0, dtype=torch.bfloat16
636-
))
637-
r_im2col_fp8_s = r_im2col_bf16_s * 0.5
638-
639-
# Roofline total: im2col + GEMM + quantization
640-
r_bf16_total_s = r_bf16_gemm_time_s + r_im2col_bf16_s
641-
r_fp8_total_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_im2col_fp8_s
642-
r_total_spdp = r_bf16_total_s / r_fp8_total_s
643-
644-
print(f" -> Im2col: BF16={r_im2col_bf16_s*1e6:.2f} µs, FP8={r_im2col_fp8_s*1e6:.2f} µs")
645-
print(f" -> Speedup: GEMM only={r_speedup:.3f}x | Total={r_total_spdp:.3f}x")
447+
print(f" -> GEMM dims: M={gemm_M}, K={gemm_K}, N={gemm_N}")
448+
print(f" -> Speedup: {r_speedup:.3f}x")
646449

647450
# GEMM benchmarks not yet implemented for conv ops
648451
b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
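The roofline numbers in this branch come from substituting the conv-equivalent GEMM dims into symbolic expressions in M, K, N. A toy sketch of that `.subs` pattern follows; the expression and peak-throughput value below are stand-ins, not what `get_inference_gemm_time_sympy` actually returns:

```python
# Sketch of the sympy substitution pattern: build a symbolic roofline expression
# in M, K, N, then plug in the conv-equivalent GEMM dims. Toy expression only.
import sympy

M, K, N = sympy.symbols("M K N")
peak_flops = 1.0e15  # assumed peak throughput, FLOP/s
gemm_time_sympy = 2 * M * K * N / peak_flops

gemm_M, gemm_K, gemm_N = 1352, 144, 32
gemm_time_s = float(gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N))
print(gemm_time_s)
```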
@@ -768,18 +571,11 @@ def run(
                 # Roofline: GEMM
                 r_bf16_gemm_time_s,
                 r_fp8_gemm_time_s,
-                # Roofline: im2col
-                r_im2col_bf16_s,
-                r_im2col_fp8_s,
-                # Roofline: FP8 quantization
+                # Roofline: FP8 quantization overhead
                 r_fp8_ovhd_time_s,
-                # Roofline: GEMM-only
+                # Roofline: Total (GEMM + quantization)
                 r_fp8_gemm_and_ovhd_s,
                 r_speedup,
-                # Roofline: Total
-                r_bf16_total_s,
-                r_fp8_total_s,
-                r_total_spdp,
                 # Benchmarks: GEMM
                 b_bf16_gemm_time_s,
                 b_fp8_gemm_time_s,
