185 changes: 156 additions & 29 deletions benchmarks/float8/float8_inference_roofline.py
@@ -54,7 +54,7 @@
get_inference_float8_mem_sympy,
get_inference_gemm_time_sympy,
)
from torchao.utils import is_MI300
from torchao.utils import is_MI300, is_sm_at_least_100


@torch.no_grad()
@@ -165,6 +165,67 @@ def do_matmul(A, B):
return bf16_time_s, f8_time_s


def get_conv_equivalent_gemm_dims(
op_name: str,
batch: int,
in_channels: int,
out_channels: int,
kernel_size: int,
D: Optional[int],
H: int,
W: int,
stride: int = 1,
padding: int = 0,
):
"""
Get equivalent GEMM dimensions for a conv operation using analytical calculation.

Conv operations can be expressed as implicit GEMM. This function computes
the equivalent GEMM dimensions without creating any tensors.

Args:
op_name: "conv2d" or "conv3d"
batch: Batch size
in_channels: Number of input channels
out_channels: Number of output channels
kernel_size: Kernel size (assumes square/cubic kernel)
D: Depth dimension (required for conv3d)
H: Height dimension
W: Width dimension
stride: Stride value
padding: Padding value

Returns:
Tuple[int, int, int]: (gemm_M, gemm_K, gemm_N)
gemm_M: Number of output spatial positions (batch * spatial_output_size)
gemm_K: Size of each filter (in_channels * kernel_volume)
gemm_N: Number of filters (out_channels)
"""
if op_name == "conv2d":
# Output spatial dimensions
H_out = (H + 2 * padding - kernel_size) // stride + 1
W_out = (W + 2 * padding - kernel_size) // stride + 1

gemm_M = batch * H_out * W_out
gemm_K = in_channels * kernel_size * kernel_size
gemm_N = out_channels

elif op_name == "conv3d":
# Output spatial dimensions
D_out = (D + 2 * padding - kernel_size) // stride + 1
H_out = (H + 2 * padding - kernel_size) // stride + 1
W_out = (W + 2 * padding - kernel_size) // stride + 1

gemm_M = batch * D_out * H_out * W_out
gemm_K = in_channels * kernel_size * kernel_size * kernel_size
gemm_N = out_channels

else:
raise ValueError(f"Unsupported op_name: {op_name}")

return gemm_M, gemm_K, gemm_N
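
To make the implicit-GEMM mapping above concrete, here is a minimal standalone sketch (arbitrary example shapes, not taken from the diff) that checks the analytical output-size formula against an actual nn.Conv2d and spells out the equivalent GEMM dimensions:

import torch
import torch.nn as nn

# arbitrary example shapes for illustration only
batch, in_channels, out_channels = 2, 16, 32
H, W, kernel_size, stride, padding = 28, 28, 3, 1, 0

conv = nn.Conv2d(
    in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False
)
out = conv(torch.randn(batch, in_channels, H, W))

# same analytical formula as get_conv_equivalent_gemm_dims
H_out = (H + 2 * padding - kernel_size) // stride + 1
W_out = (W + 2 * padding - kernel_size) // stride + 1
assert out.shape == (batch, out_channels, H_out, W_out)

gemm_M = batch * H_out * W_out         # number of output spatial positions
gemm_K = in_channels * kernel_size**2  # size of each unrolled filter
gemm_N = out_channels                  # number of filters
print(gemm_M, gemm_K, gemm_N)          # 1352 144 32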


def run(
outfile: str,
recipe_name: str,
@@ -181,6 +242,8 @@
H: Optional[int] = None,
W: Optional[int] = None,
kernel_size: Optional[int] = None,
stride: int = 1,
padding: int = 0,
):
"""
Args:
@@ -189,16 +252,19 @@
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
* `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
# `save_profile_traces (optional)`: if True, saves profiling traces
# `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
# `op_name`: linear, conv2d or conv3d, decides which op to benchmark
# `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d
# `kernel_size`: kernel_size for conv3d / conv2d
* `save_profile_traces (optional)`: if True, saves profiling traces
* `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
* `op_name`: linear, conv2d or conv3d, decides which op to benchmark
* `D`, `H`, `W`: spatial dimensions for conv3d / conv2d
* `kernel_size`: kernel_size for conv3d / conv2d
* `stride`: stride for conv ops (default: 1)
* `padding`: padding for conv ops (default: 0)
"""
_SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
assert op_name in _SUPPORTED_OPS, (
f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}"
)

if op_name == "conv2d":
assert H is not None and W is not None, (
"Expected D, H, W to be specified for conv2d"
@@ -226,6 +292,8 @@
["MKN", f"{M} {K} {N}"],
["DHW", f"{D} {H} {W}"],
["kernel_size", kernel_size],
["stride", stride],
["padding", padding],
]
print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

@@ -234,7 +302,9 @@

M, K, N = sympy.symbols("M K N")

if op_name == "linear":
# Roofline model setup: linear uses M/K/N directly, conv uses equivalent
# implicit GEMM dimensions (computed per-iteration in the loop below)
if op_name in ("linear", "conv2d", "conv3d"):
fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
M,
K,
@@ -260,20 +330,17 @@
print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
print()
else:
# TODO: enable roofline analysis for conv
pass

Contributor: If we share the same thing, we should remove this if/else branch and inline the code from the `if` branch here, I think.

Contributor (Author): I kept the if/else just in case of an unexpected op_name. We can either make the code verify op_name at the beginning to avoid any errors, or assume the input for op_name will always be correct.

Contributor: We already verify that in L264, I think.

# Note: roofline for conv2d/conv3d is not added yet, so most of the
# things for conv2d/conv3d we'll left out for now
headers = [
"fwd_M",
"fwd_K",
"fwd_N",
"fwd_M", # for conv: batch size
"fwd_K", # for conv: in_channels
"fwd_N", # for conv: out_channels
"D",
"H",
"W",
"kernel_size",
# roofline - gemm time (fwd + bwd, 3 gemms)
# roofline - gemm time (fwd + bwd, 3 gemms; for conv: using equivalent implicit gemm dims)
"r_bf16_gemm_s",
"r_fp8_gemm_s",
# roofline - fp8 overhead time (by counting reads/writes in the ideal case)
@@ -327,7 +394,6 @@ def run(
b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
rb_bf16_gemm_ratio = -1
rb_fp8_gemm_ratio = -1

if do_benchmarks:
# TODO(future): make the bf16 gemm times exactly match the e2e
# benchmarks, there is a slight deviation, probably related to gemm
@@ -347,44 +413,105 @@ def run(
rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

else:
# roofline analysis for conv2d/conv3d are not added yet
r_bf16_gemm_time_s = None
r_fp8_gemm_time_s = None
# For conv ops, compute equivalent GEMM dimensions
# M_val=batch, K_val=in_channels, N_val=out_channels
gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims(
op_name=op_name,
batch=M_val,
in_channels=K_val,
out_channels=N_val,
kernel_size=kernel_size,
D=D,
H=H,
W=W,
stride=stride,
padding=padding,
)

r_fp8_ovhd_time_s = None
r_fp8_gemm_and_ovhd_s = None
r_speedup = None
# use roofline model to estimate gemm time using equivalent GEMM dims
r_bf16_gemm_time_s = float(
bf16_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)

Contributor: Are the memory operations of conv the same as for linear as well?

    mem_gemm_time_s = (

Contributor (Author): As conv is an implicit gemm, I'm assuming the memory operations for gemm and conv should be the same.

)
r_fp8_gemm_time_s = float(
fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
)
r_fp8_ovhd_time_s = float(
fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
)
r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
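
To ground the implicit-GEMM assumption discussed in the review thread above, here is a minimal standalone sketch (arbitrary shapes, not part of the diff) that reproduces conv2d as an explicit im2col GEMM; this equivalence is what the M/K/N substitution above relies on:

import torch
import torch.nn.functional as F

# arbitrary example shapes for illustration only
batch, in_channels, out_channels = 2, 8, 16
H = W = 14
kernel_size, stride, padding = 3, 1, 0

x = torch.randn(batch, in_channels, H, W)
weight = torch.randn(out_channels, in_channels, kernel_size, kernel_size)

ref = F.conv2d(x, weight, stride=stride, padding=padding)

# im2col: (batch, in_channels * k * k, H_out * W_out)
cols = F.unfold(x, kernel_size, padding=padding, stride=stride)
# the "implicit" GEMM, done explicitly: (N, K) @ (batch, K, M) -> (batch, N, M)
out = weight.view(out_channels, -1) @ cols
H_out = (H + 2 * padding - kernel_size) // stride + 1
W_out = (W + 2 * padding - kernel_size) // stride + 1
out = out.view(batch, out_channels, H_out, W_out)

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)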

# real gemm benchmark time, also not added yet
# if enabled, also measured observed gemm time
# gemm benchmarks for conv not implemented, as conv uses implicit GEMM

Contributor: We should run the conv ops, I think?

b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
# gemm roofline ratio achieved in real benchmark
rb_bf16_gemm_ratio = -1
rb_fp8_gemm_ratio = -1

b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
if do_benchmarks:
# Check hardware requirements for conv operations
skip_conv_benchmarks = (
do_benchmarks
and op_name in ("conv2d", "conv3d")
and not is_sm_at_least_100()
)

if skip_conv_benchmarks:

Contributor: I feel the conditions seem a bit convoluted here. Maybe:

    if do_benchmarks:
        if op_name in ("conv2d", "conv3d") and not is_sm_at_least_100():
            print warning
        else:
            # can also move this part to a function to make it clearer
            ....

print(
f"WARNING: Skipping {op_name} benchmarks for shape ({M_val}, {K_val}, {N_val}). "
f"Float8 convolution requires SM 10.0+ (Blackwell/B100 GPUs). "
f"Current GPU: {torch.cuda.get_device_name(0)} with SM {torch.cuda.get_device_capability()}. "
f"Roofline model estimates are still valid."
)
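
A self-contained sketch of the nesting the reviewer suggests above; the string returns are placeholders for the real warning and benchmark bodies, and only the control flow is the point:

def bench_or_skip(op_name: str, do_benchmarks: bool, has_sm_100: bool) -> str:
    if do_benchmarks:
        if op_name in ("conv2d", "conv3d") and not has_sm_100:
            # warn and fall back to roofline-only numbers
            return "skipped: float8 conv requires SM 10.0+"
        # model creation / e2e benchmarking (the code below) could move into
        # a helper here to keep the branch readable
        return "benchmarked"
    return "skipped: benchmarks disabled"

print(bench_or_skip("conv2d", do_benchmarks=True, has_sm_100=False))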

if do_benchmarks and not skip_conv_benchmarks:
# create the model
if op_name == "conv2d":
if not enable_fusion_modeling:
m_orig = nn.Sequential(
nn.Conv2d(K_val, N_val, kernel_size, bias=False)
nn.Conv2d(
K_val,
N_val,
kernel_size,
stride=stride,
padding=padding,
bias=False,
)
).to(memory_format=torch.channels_last)
else:
m_orig = nn.Sequential(
nn.ReLU(),
nn.Conv2d(K_val, N_val, kernel_size, bias=False),
nn.Conv2d(
K_val,
N_val,
kernel_size,
stride=stride,
padding=padding,
bias=False,
),
nn.ReLU(),
).to(memory_format=torch.channels_last)
elif op_name == "conv3d":
if not enable_fusion_modeling:
m_orig = nn.Sequential(
nn.Conv3d(K_val, N_val, kernel_size, bias=False)
nn.Conv3d(
K_val,
N_val,
kernel_size,
stride=stride,
padding=padding,
bias=False,
)
).to(memory_format=torch.channels_last_3d)
else:
m_orig = nn.Sequential(
nn.ReLU(),
nn.Conv3d(K_val, N_val, kernel_size, bias=False),
nn.Conv3d(
K_val,
N_val,
kernel_size,
stride=stride,
padding=padding,
bias=False,
),
nn.ReLU(),
).to(memory_format=torch.channels_last_3d)
else: