 from torchao.testing.training.roofline_utils import (
     get_inference_float8_mem_sympy,
     get_inference_gemm_time_sympy,
-    get_specs,
-    BYTES_PER_EL_BF16,
-    BYTES_PER_EL_FLOAT8,
-    KERNEL_LAUNCH_OVERHEAD_SEC,
 )
 from torchao.utils import is_MI300
 
@@ -198,7 +194,10 @@ def get_conv_equivalent_gemm_dims(
     padding: int = 0,
 ):
     """
-    Get GEMM dimensions from unfold.
+    Get equivalent GEMM dimensions for a conv operation.
+
+    Uses torch.nn.functional.unfold to derive the correct GEMM dimensions
+    that correspond to the conv operation.
 
     Args:
         op_name: "conv2d" or "conv3d"
@@ -235,10 +234,6 @@ def get_conv_equivalent_gemm_dims(
 
     elif op_name == "conv3d":
         x = torch.randn(batch, in_channels, D, H, W, device=device)
-
-        # Note: torch.nn.Unfold only supports 4-D tensors
-        # (https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html)
-        # For 3D conv, reshape (B,C,D,H,W) -> (B*D,C,H,W) and unfold H,W
         B, C, D_in, H_in, W_in = x.shape
         x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in)
         unfolded = torch.nn.functional.unfold(
@@ -260,179 +255,6 @@ def get_conv_equivalent_gemm_dims(
     return gemm_M, gemm_K, gemm_N
 
 
-def benchmark_im2col_unfold(
-    op_name: str,
-    batch: int,
-    in_channels: int,
-    kernel_size: int,
-    D: Optional[int],
-    H: int,
-    W: int,
-    stride: int = 1,
-    padding: int = 0,
-    dtype=torch.bfloat16,
-):
-    """
-    Benchmark unfold operation.
-
-    Args:
-        op_name: "conv2d" or "conv3d"
-        batch: Batch size
-        in_channels: Number of input channels
-        kernel_size: Kernel size
-        D: Depth dimension (required for conv3d)
-        H: Height dimension (required for conv2d/conv3d)
-        W: Width dimension (required for conv2d/conv3d)
-        stride: Stride value
-        padding: Padding value
-        dtype: Data type
-
-    Returns:
-        Measured time in seconds
-    """
-    device = torch.device("cuda")
-
-    _validate_conv_params(op_name, kernel_size, D, H, W)
-
-    # Unfold doesn't support FP8; return -1 for unsupported dtypes
-    if dtype not in (torch.bfloat16, torch.float16, torch.float32):
-        return -1
-
-    # Create input tensor
-    if op_name == "conv2d":
-        x = torch.randn(batch, in_channels, H, W, dtype=dtype, device=device)
-    elif op_name == "conv3d":
-        x = torch.randn(batch, in_channels, D, H, W, dtype=dtype, device=device)
-    else:
-        raise ValueError(f"Unsupported op_name: {op_name}")
-
-    def _run_unfold():
-        if op_name == "conv2d":
-            return torch.nn.functional.unfold(
-                x, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
-            )
-        else:  # conv3d: reshape to 4D since unfold only supports 4D
-            B, C, D_in, H_in, W_in = x.shape
-            x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in)
-            return torch.nn.functional.unfold(
-                x_reshaped, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
-            )
-
-    # Warm up
-    for _ in range(2):
-        _ = _run_unfold()
-    torch.cuda.synchronize()
-
-    # Benchmark
-    n_iter = 10
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-
-    start.record()
-    for _ in range(n_iter):
-        _ = _run_unfold()
-    end.record()
-    torch.cuda.synchronize()
-
-    return start.elapsed_time(end) / 1000.0 / n_iter
-
-
-def get_im2col_memory_overhead_sympy(
-    op_name: str,
-    batch: int,
-    in_channels: int,
-    out_channels: int,
-    kernel_size: int,
-    D: Optional[int],
-    H: int,
-    W: int,
-    stride: int = 1,
-    padding: int = 0,
-    dtype=torch.bfloat16,
-    gpu_name: Optional[str] = None,
-):
-    """
-    Calculate the memory overhead for im2col transformation in conv operations.
-
-    Im2col unfolds the input tensor into a 2D matrix for efficient GEMM computation.
-    This involves:
-    1. Reading the input tensor (batch × in_channels × spatial_dims)
-    2. Writing the im2col matrix (output_spatial_positions × kernel_volume)
-
-    The im2col matrix is typically much larger than the input due to overlapping
-    windows, especially with stride=1 and larger kernels.
-
-    Args:
-        op_name: "conv2d" or "conv3d"
-        batch: Batch size
-        in_channels: Number of input channels
-        out_channels: Number of output channels
-        kernel_size: Kernel size
-        D: Depth dimension (required for conv3d)
-        H: Height dimension (required for conv2d/conv3d)
-        W: Width dimension (required for conv2d/conv3d)
-        stride: Stride value
-        padding: Padding value
-        dtype: Data type
-        gpu_name: GPU name for specs
-
-    Returns:
-        sympy expression for im2col memory overhead in seconds
-    """
-    _validate_conv_params(op_name, kernel_size, D, H, W)
-    specs = get_specs(gpu_name)
-
-    # Determine bytes per element based on dtype
-    if dtype == torch.bfloat16:
-        bytes_per_el = BYTES_PER_EL_BF16
-    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        bytes_per_el = BYTES_PER_EL_FLOAT8
-    else:
-        bytes_per_el = BYTES_PER_EL_BF16  # default
-
-    if op_name == "conv2d":
-
-        # Input size
-        input_numel = batch * in_channels * H * W
-
-        # Output spatial dimensions
-        H_out = (H - kernel_size + 2 * padding) // stride + 1
-        W_out = (W - kernel_size + 2 * padding) // stride + 1
-
-        # Im2col matrix size: (batch * H_out * W_out) × (in_channels * kernel_size^2)
-        im2col_numel = batch * H_out * W_out * in_channels * kernel_size * kernel_size
-
-    elif op_name == "conv3d":
-        # Input size
-        input_numel = batch * in_channels * D * H * W
-
-        # Output spatial dimensions
-        D_out = (D - kernel_size + 2 * padding) // stride + 1
-        H_out = (H - kernel_size + 2 * padding) // stride + 1
-        W_out = (W - kernel_size + 2 * padding) // stride + 1
-
-        # Im2col matrix size: (batch * D_out * H_out * W_out) × (in_channels * kernel_size^3)
-        im2col_numel = batch * D_out * H_out * W_out * in_channels * kernel_size * kernel_size * kernel_size
-
-    else:
-        raise ValueError(f"Unsupported op_name: {op_name}")
-
-    # Memory traffic: read input + write im2col matrix
-    # Note: In practice, some implementations may avoid materializing the full im2col
-    # matrix, but we model the worst case here
-    bytes_read = input_numel * bytes_per_el
-    bytes_write = im2col_numel * bytes_per_el
-    total_bytes = bytes_read + bytes_write
-
-    # Convert to time using memory bandwidth
-    im2col_time_s = total_bytes / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
-
-    # Account for kernel launch overhead
-    im2col_time_s = sympy.Max(im2col_time_s, KERNEL_LAUNCH_OVERHEAD_SEC)
-
-    return im2col_time_s
-
-
 def run(
     outfile: str,
     recipe_name: str,
@@ -449,6 +271,9 @@ def run(
     H: Optional[int] = None,
     W: Optional[int] = None,
     kernel_size: Optional[int] = None,
+    stride: int = 1,
+    padding: int = 0,
+    verbose: bool = False,
 ):
     """
     Args:
@@ -457,11 +282,13 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
     * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
-    # `save_profile_traces (optional)`: if True, saves profiling traces
-    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
-    # `op_name`: linear, conv2d or conv3d, decides which op to benchmark
-    # `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d
-    # `kernel_size`: kernel_size for conv3d / conv2d
+    * `save_profile_traces (optional)`: if True, saves profiling traces
+    * `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
+    * `op_name`: linear, conv2d or conv3d, decides which op to benchmark
+    * `D`, `H`, `W`: spatial dimensions for conv3d / conv2d
+    * `kernel_size`: kernel_size for conv3d / conv2d
+    * `stride`: stride for conv ops (default: 1)
+    * `padding`: padding for conv ops (default: 0)
     """
     _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
     assert op_name in _SUPPORTED_OPS, f"Unsupported op: {op_name}, supported: {_SUPPORTED_OPS}"
@@ -481,6 +308,8 @@ def run(
         ["MKN", f"{M} {K} {N}"],
         ["DHW", f"{D} {H} {W}"],
         ["kernel_size", kernel_size],
+        ["stride", stride],
+        ["padding", padding],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
@@ -513,7 +342,7 @@ def run(
     print()
 
     if op_name in ("conv2d", "conv3d"):
-        print(f"{op_name}: GEMM dimensions from unfold, roofline from symbolic expressions")
+        print(f"{op_name}: GEMM dimensions derived from conv params")
         print()
     elif op_name != "linear":
         raise ValueError(f"Unsupported op_name: {op_name}")
@@ -525,17 +354,11 @@ def run(
         # Roofline: GEMM time
         "r_bf16_gemm_s", "r_fp8_gemm_s",
 
-        # Roofline: im2col overhead
-        "r_im2col_bf16_s", "r_im2col_fp8_s",
-
         # Roofline: FP8 quantization overhead
         "r_fp8_ovhd_s",
 
-        # Roofline: GEMM-only metrics
-        "r_fp8_gemm_and_ovhd_s", "r_fp8_gemm_and_ovhd_spdp",
-
-        # Roofline: Total (im2col + GEMM + quantization)
-        "r_bf16_total_s", "r_fp8_total_s", "r_fp8_total_spdp",
+        # Roofline: Total (GEMM + quantization)
+        "r_fp8_gemm_and_ovhd_s", "r_fp8_speedup",
 
         # Benchmarks: Direct GEMM
         "b_bf16_gemm_s", "b_fp8_gemm_s",
@@ -572,11 +395,6 @@ def run(
             r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
             r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
 
-            # Linear ops don't have im2col overhead
-            r_im2col_bf16_s, r_im2col_fp8_s = 0, 0
-            r_bf16_total_s = r_bf16_gemm_time_s
-            r_fp8_total_s = r_fp8_gemm_and_ovhd_s
-            r_total_spdp = r_speedup
             b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
             rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1
 
@@ -599,8 +417,6 @@ def run(
                 rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
 
         elif op_name in ("conv2d", "conv3d"):
-            # Get GEMM dimensions from unfold
-            stride, padding = 1, 0
             gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims(
                 op_name=op_name,
                 batch=M_val,
@@ -625,24 +441,11 @@ def run(
                 fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
             )
 
-            # Compute combined metrics
             r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
             r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
 
-            # Roofline im2col overhead (theoretical)
-            r_im2col_bf16_s = float(get_im2col_memory_overhead_sympy(
-                op_name, M_val, K_val, N_val, kernel_size,
-                D, H, W, stride=1, padding=0, dtype=torch.bfloat16
-            ))
-            r_im2col_fp8_s = r_im2col_bf16_s * 0.5
-
-            # Roofline total: im2col + GEMM + quantization
-            r_bf16_total_s = r_bf16_gemm_time_s + r_im2col_bf16_s
-            r_fp8_total_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_im2col_fp8_s
-            r_total_spdp = r_bf16_total_s / r_fp8_total_s
-
-            print(f" -> Im2col: BF16={r_im2col_bf16_s * 1e6:.2f} µs, FP8={r_im2col_fp8_s * 1e6:.2f} µs")
-            print(f" -> Speedup: GEMM only={r_speedup:.3f}x | Total={r_total_spdp:.3f}x")
+            print(f" -> GEMM dims: M={gemm_M}, K={gemm_K}, N={gemm_N}")
+            print(f" -> Speedup: {r_speedup:.3f}x")
 
             # GEMM benchmarks not yet implemented for conv ops
             b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
@@ -768,18 +571,11 @@ def run(
             # Roofline: GEMM
             r_bf16_gemm_time_s,
             r_fp8_gemm_time_s,
-            # Roofline: im2col
-            r_im2col_bf16_s,
-            r_im2col_fp8_s,
-            # Roofline: FP8 quantization
+            # Roofline: FP8 quantization overhead
             r_fp8_ovhd_time_s,
-            # Roofline: GEMM-only
+            # Roofline: Total (GEMM + quantization)
             r_fp8_gemm_and_ovhd_s,
             r_speedup,
-            # Roofline: Total
-            r_bf16_total_s,
-            r_fp8_total_s,
-            r_total_spdp,
             # Benchmarks: GEMM
             b_bf16_gemm_time_s,
             b_fp8_gemm_time_s,
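
Note on the conv -> GEMM mapping used above: the dimensions that get_conv_equivalent_gemm_dims derives via torch.nn.functional.unfold can also be written in closed form. The sketch below is illustrative only and not part of this diff: the sizes (B, C_in, C_out, H, W, k, s, p) are assumed example values, and the M/K/N assignment follows the usual im2col convention rather than being copied verbatim from the script.

import torch

# Illustrative sketch: closed-form conv2d -> GEMM dimensions, cross-checked
# against torch.nn.functional.unfold. All sizes here are made-up examples.
B, C_in, C_out, H, W, k, s, p = 32, 64, 128, 56, 56, 3, 1, 0

# Standard conv output-size formula
H_out = (H + 2 * p - k) // s + 1
W_out = (W + 2 * p - k) // s + 1

# conv2d viewed as a (M x K) @ (K x N) GEMM under the im2col convention
gemm_M = B * H_out * W_out   # one GEMM row per output spatial position
gemm_K = C_in * k * k        # one GEMM column per element of an input patch
gemm_N = C_out               # one GEMM output column per conv output channel

# unfold materializes the patches with shape (B, C_in*k*k, H_out*W_out),
# which is where the same M and K dimensions come from in the script
x = torch.randn(B, C_in, H, W)
unfolded = torch.nn.functional.unfold(x, kernel_size=(k, k), stride=s, padding=p)
assert unfolded.shape == (B, gemm_K, H_out * W_out)
assert B * unfolded.shape[-1] == gemm_M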