Roofline quantized conv3d/2d layer #3419
base: main
@@ -54,7 +54,7 @@
     get_inference_float8_mem_sympy,
     get_inference_gemm_time_sympy,
 )
-from torchao.utils import is_MI300
+from torchao.utils import is_MI300, is_sm_at_least_100


 @torch.no_grad()

@@ -165,6 +165,67 @@ def do_matmul(A, B):
     return bf16_time_s, f8_time_s


+def get_conv_equivalent_gemm_dims(
+    op_name: str,
+    batch: int,
+    in_channels: int,
+    out_channels: int,
+    kernel_size: int,
+    D: Optional[int],
+    H: int,
+    W: int,
+    stride: int = 1,
+    padding: int = 0,
+):
+    """
+    Get equivalent GEMM dimensions for a conv operation using analytical calculation.
+
+    Conv operations can be expressed as implicit GEMM. This function computes
+    the equivalent GEMM dimensions without creating any tensors.
+
+    Args:
+        op_name: "conv2d" or "conv3d"
+        batch: Batch size
+        in_channels: Number of input channels
+        out_channels: Number of output channels
+        kernel_size: Kernel size (assumes square/cubic kernel)
+        D: Depth dimension (required for conv3d)
+        H: Height dimension
+        W: Width dimension
+        stride: Stride value
+        padding: Padding value
+
+    Returns:
+        Tuple[int, int, int]: (gemm_M, gemm_K, gemm_N)
+            gemm_M: Number of output spatial positions (batch * spatial_output_size)
+            gemm_K: Size of each filter (in_channels * kernel_volume)
+            gemm_N: Number of filters (out_channels)
+    """
+    if op_name == "conv2d":
+        # Output spatial dimensions
+        H_out = (H + 2 * padding - kernel_size) // stride + 1
+        W_out = (W + 2 * padding - kernel_size) // stride + 1
+
+        gemm_M = batch * H_out * W_out
+        gemm_K = in_channels * kernel_size * kernel_size
+        gemm_N = out_channels
+
+    elif op_name == "conv3d":
+        # Output spatial dimensions
+        D_out = (D + 2 * padding - kernel_size) // stride + 1
+        H_out = (H + 2 * padding - kernel_size) // stride + 1
+        W_out = (W + 2 * padding - kernel_size) // stride + 1
+
+        gemm_M = batch * D_out * H_out * W_out
+        gemm_K = in_channels * kernel_size * kernel_size * kernel_size
+        gemm_N = out_channels
+
+    else:
+        raise ValueError(f"Unsupported op_name: {op_name}")
+
+    return gemm_M, gemm_K, gemm_N
+
+
 def run(
     outfile: str,
     recipe_name: str,

@@ -181,6 +242,8 @@ def run(
     H: Optional[int] = None,
     W: Optional[int] = None,
     kernel_size: Optional[int] = None,
+    stride: int = 1,
+    padding: int = 0,
 ):
     """
     Args:

@@ -189,16 +252,19 @@
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
     * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
-    # `save_profile_traces (optional)`: if True, saves profiling traces
-    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
-    # `op_name`: linear, conv2d or conv3d, decides which op to benchmark
-    # `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d
-    # `kernel_size`: kernel_size for conv3d / conv2d
+    * `save_profile_traces (optional)`: if True, saves profiling traces
+    * `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
+    * `op_name`: linear, conv2d or conv3d, decides which op to benchmark
+    * `D`, `H`, `W`: spatial dimensions for conv3d / conv2d
+    * `kernel_size`: kernel_size for conv3d / conv2d
+    * `stride`: stride for conv ops (default: 1)
+    * `padding`: padding for conv ops (default: 0)
     """
     _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
     assert op_name in _SUPPORTED_OPS, (
         f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}"
     )

     if op_name == "conv2d":
         assert H is not None and W is not None, (
             "Expected D, H, W to be specified for conv2d"

@@ -226,6 +292,8 @@ def run(
         ["MKN", f"{M} {K} {N}"],
         ["DHW", f"{D} {H} {W}"],
         ["kernel_size", kernel_size],
+        ["stride", stride],
+        ["padding", padding],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

@@ -234,7 +302,9 @@

     M, K, N = sympy.symbols("M K N")

-    if op_name == "linear":
+    # Roofline model setup: linear uses M/K/N directly, conv uses equivalent
+    # implicit GEMM dimensions (computed per-iteration in the loop below)
+    if op_name in ("linear", "conv2d", "conv3d"):
         fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
             M,
             K,

@@ -260,20 +330,17 @@
         print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
         print()
-    else:
-        # TODO: enable roofline analysis for conv
-        pass

-    # Note: roofline for conv2d/conv3d is not added yet, so most of the
-    # things for conv2d/conv3d we'll left out for now
     headers = [
-        "fwd_M",
-        "fwd_K",
-        "fwd_N",
+        "fwd_M",  # for conv: batch size
+        "fwd_K",  # for conv: in_channels
+        "fwd_N",  # for conv: out_channels
         "D",
         "H",
         "W",
         "kernel_size",
-        # roofline - gemm time (fwd + bwd, 3 gemms)
+        # roofline - gemm time (fwd + bwd, 3 gemms; for conv: using equivalent implicit gemm dims)
         "r_bf16_gemm_s",
         "r_fp8_gemm_s",
         # roofline - fp8 overhead time (by counting reads/writes in the ideal case)

@@ -327,7 +394,6 @@ def run(
             b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
             rb_bf16_gemm_ratio = -1
             rb_fp8_gemm_ratio = -1
-
             if do_benchmarks:
                 # TODO(future): make the bf16 gemm times exactly match the e2e
                 # benchmarks, there is a slight deviation, probably related to gemm

@@ -347,44 +413,105 @@ def run(
                 rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

         else:
-            # roofline analysis for conv2d/conv3d are not added yet
-            r_bf16_gemm_time_s = None
-            r_fp8_gemm_time_s = None
+            # For conv ops, compute equivalent GEMM dimensions
+            # M_val=batch, K_val=in_channels, N_val=out_channels
+            gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims(
+                op_name=op_name,
+                batch=M_val,
+                in_channels=K_val,
+                out_channels=N_val,
+                kernel_size=kernel_size,
+                D=D,
+                H=H,
+                W=W,
+                stride=stride,
+                padding=padding,
+            )

-            r_fp8_ovhd_time_s = None
-            r_fp8_gemm_and_ovhd_s = None
-            r_speedup = None
+            # use roofline model to estimate gemm time using equivalent GEMM dims
+            r_bf16_gemm_time_s = float(
+                bf16_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)

Contributor: Are the memory operations of conv the same as for linear as well? (see ao/torchao/testing/training/roofline_utils.py, line 332 in 0975a40)

Contributor (author): As conv is an implicit gemm, I'm assuming the memory operations for gemm and conv should be the same.
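To make the question above concrete, here is a rough, hypothetical element-count comparison (not from the PR, arbitrary shape): the weight operand has exactly as many elements in the conv view as in the implicit-GEMM view, while the im2col activation operand duplicates overlapping patches, so the GEMM view can count more activation elements than the raw conv input holds.

```python
# hypothetical element-count comparison for one conv2d shape (illustrative only)
batch, in_channels, out_channels, k, H, W = 2, 16, 32, 3, 8, 8
H_out = W_out = (H - k) + 1              # stride=1, padding=0

gemm_M = batch * H_out * W_out           # 72
gemm_K = in_channels * k * k             # 144
gemm_N = out_channels                    # 32

# implicit-GEMM view of the operands
gemm_activation_elems = gemm_M * gemm_K  # 10368
gemm_weight_elems = gemm_K * gemm_N      # 4608

# actual conv tensors
conv_activation_elems = batch * in_channels * H * W     # 2048
conv_weight_elems = out_channels * in_channels * k * k  # 4608

print(gemm_weight_elems == conv_weight_elems)           # True
print(gemm_activation_elems, conv_activation_elems)     # 10368 2048
```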
+            )
+            r_fp8_gemm_time_s = float(
+                fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
+            )
+            r_fp8_ovhd_time_s = float(
+                fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
+            )
+            r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+            r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)

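The `.subs` calls above simply evaluate a symbolic roofline expression at the equivalent GEMM dims. A self-contained sketch of the same mechanics with a stand-in expression (the real expressions come from `get_inference_gemm_time_sympy` / `get_inference_float8_mem_sympy`; the peak-throughput number below is an arbitrary assumption):

```python
import sympy

M, K, N = sympy.symbols("M K N")

# stand-in roofline expression: time = flops / assumed peak throughput
# (NOT torchao's model, just the same substitution pattern)
assumed_peak_bf16_flops = 1.0e15
gemm_time_sympy = 2 * M * K * N / assumed_peak_bf16_flops

# e.g. dims returned by get_conv_equivalent_gemm_dims for some conv shape
gemm_M, gemm_K, gemm_N = 72, 144, 32
t_s = float(gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N))
print(f"estimated bf16 gemm time: {t_s:.3e} s")
```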
-            # real gemm benchmark time, also not added yet
+            # if enabled, also measured observed gemm time
+            # gemm benchmarks for conv not implemented, as conv uses implicit GEMM

Contributor: We should run the conv ops here, I think?
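For reference, one way the conv op itself could be timed directly (a sketch, not what this PR does; the shapes and the use of `torch.utils.benchmark` are assumptions):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def bench_conv2d_bf16_s(batch, in_channels, out_channels, kernel_size, H, W,
                        stride=1, padding=0):
    """Median runtime in seconds of a bf16 conv2d (requires a CUDA GPU)."""
    x = torch.randn(batch, in_channels, H, W, device="cuda", dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)
    w = torch.randn(out_channels, in_channels, kernel_size, kernel_size,
                    device="cuda", dtype=torch.bfloat16)
    t = benchmark.Timer(
        stmt="F.conv2d(x, w, stride=stride, padding=padding)",
        globals={"F": F, "x": x, "w": w, "stride": stride, "padding": padding},
    )
    return t.blocked_autorange().median

# usage sketch: bench_conv2d_bf16_s(32, 64, 128, 3, 56, 56)
```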
             b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+            # gemm roofline ratio achieved in real benchmark
             rb_bf16_gemm_ratio = -1
             rb_fp8_gemm_ratio = -1

             b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
-            if do_benchmarks:
+            # Check hardware requirements for conv operations
+            skip_conv_benchmarks = (
+                do_benchmarks
+                and op_name in ("conv2d", "conv3d")
+                and not is_sm_at_least_100()
+            )
+
+            if skip_conv_benchmarks:

Contributor: I feel the conditions here seem a bit convoluted, maybe.
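One hypothetical way to flatten the conditions the comment refers to (a sketch, not a change in this PR): compute a single "can we benchmark this op here" predicate up front and branch on it once.

```python
def should_run_benchmarks(op_name: str, do_benchmarks: bool,
                          has_sm_100: bool) -> bool:
    """Single predicate for whether real benchmarks can run for this op/GPU."""
    if not do_benchmarks:
        return False
    if op_name in ("conv2d", "conv3d") and not has_sm_100:
        # float8 conv kernels require SM 10.0+ (see the warning below)
        return False
    return True

# usage sketch:
# run_benchmarks = should_run_benchmarks(op_name, do_benchmarks, is_sm_at_least_100())
```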
+                print(
+                    f"WARNING: Skipping {op_name} benchmarks for shape ({M_val}, {K_val}, {N_val}). "
+                    f"Float8 convolution requires SM 10.0+ (Blackwell/B100 GPUs). "
+                    f"Current GPU: {torch.cuda.get_device_name(0)} with SM {torch.cuda.get_device_capability()}. "
+                    f"Roofline model estimates are still valid."
+                )
+
+            if do_benchmarks and not skip_conv_benchmarks:
                 # create the model
                 if op_name == "conv2d":
                     if not enable_fusion_modeling:
                         m_orig = nn.Sequential(
-                            nn.Conv2d(K_val, N_val, kernel_size, bias=False)
+                            nn.Conv2d(
+                                K_val,
+                                N_val,
+                                kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                bias=False,
+                            )
                         ).to(memory_format=torch.channels_last)
                     else:
                         m_orig = nn.Sequential(
                             nn.ReLU(),
-                            nn.Conv2d(K_val, N_val, kernel_size, bias=False),
+                            nn.Conv2d(
+                                K_val,
+                                N_val,
+                                kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                bias=False,
+                            ),
                             nn.ReLU(),
                         ).to(memory_format=torch.channels_last)
                 elif op_name == "conv3d":
                     if not enable_fusion_modeling:
                         m_orig = nn.Sequential(
-                            nn.Conv3d(K_val, N_val, kernel_size, bias=False)
+                            nn.Conv3d(
+                                K_val,
+                                N_val,
+                                kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                bias=False,
+                            )
                         ).to(memory_format=torch.channels_last_3d)
                     else:
                         m_orig = nn.Sequential(
                             nn.ReLU(),
-                            nn.Conv3d(K_val, N_val, kernel_size, bias=False),
+                            nn.Conv3d(
+                                K_val,
+                                N_val,
+                                kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                bias=False,
+                            ),
                             nn.ReLU(),
                         ).to(memory_format=torch.channels_last_3d)
                 else:

Contributor: If we share the same thing, we should remove this if/else branch and inline the code from the if here, I think.

Contributor (author): I kept the if/else just in case of an unexpected op_name. We can either make the code verify op_name at the beginning to avoid any errors, or assume the input for op_name will always be correct.

Contributor: We already verify that in L264, I think.
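Closing note: a self-contained sketch (assumed shapes, not from the PR) of exercising the non-fused conv3d model exactly as it is constructed in the diff above, with channels_last_3d activations:

```python
import torch
import torch.nn as nn

K_val, N_val, kernel_size, stride, padding = 16, 32, 3, 1, 0  # assumed values
D = H = W = 16

m_orig = nn.Sequential(
    nn.Conv3d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False)
).to(memory_format=torch.channels_last_3d)

x = torch.randn(2, K_val, D, H, W).to(memory_format=torch.channels_last_3d)
y = m_orig(x)
print(y.shape)  # torch.Size([2, 32, 14, 14, 14])
```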