From 1b4341ea7dd937eee0ebeabe933b1be2af198068 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Fri, 17 Oct 2025 13:25:33 +0000 Subject: [PATCH 01/11] Remove one memory allocation --- .github/workflows/third-party-benchmarks.yml | 2 +- benchmarks/third_party/vllm/batched_moe_benchmark.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 982f76dbc6..baeb2002ee 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -112,7 +112,7 @@ jobs: cd benchmarks/third_party/vllm FP8="1" python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-fp8-benchmark + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark - name: Run Liger-Kernel benchmarks if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'liger-kernel')) }} diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py index 7afba1efea..18685edf95 100644 --- a/benchmarks/third_party/vllm/batched_moe_benchmark.py +++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.utils import normalize_batched_scales_shape # Import utility functions from vLLM tests -from tests.kernels.moe.utils import make_quantized_test_activations, make_test_weights +from tests.kernels.moe.utils import make_quantized_test_activations, make_test_weight from tests.kernels.quant_utils import native_batched_masked_quant_matmul @@ -552,9 +552,9 @@ def benchmark(num_experts, max_tokens_per_expert, K, N, fp8, block_quant, provid ) # Create test weights (only need B matrix for batched MM) - (B, B_q, B_scale, _), _ = make_test_weights( + B, B_q, B_scale, _ = make_test_weight( num_experts, - N // 2, + N, K, in_dtype=act_dtype, quant_dtype=quant_dtype, From 784b7e88c4ff3f4ef5ec64877c224817625c6f10 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Fri, 17 Oct 2025 16:10:21 +0000 Subject: [PATCH 02/11] Added benchmark --- .github/workflows/third-party-benchmarks.yml | 18 + .../third_party/sglang/scaled_mm_benchmark.py | 530 ++++++++++++++++++ 2 files changed, 548 insertions(+) create mode 100644 benchmarks/third_party/sglang/scaled_mm_benchmark.py diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index baeb2002ee..2fdf2bc935 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -93,6 +93,24 @@ jobs: cd benchmarks pip install . 
+ - name: Run sglang benchmark int8 + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} + run: | + source ./scripts/capture-hw-details.sh + + cd benchmarks/third_party/sglang + python scaled_mm_benchmark.py --reports $REPORTS + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 + + - name: Run sglang benchmark with fp8 + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} + run: | + source ./scripts/capture-hw-details.sh + + cd benchmarks/third_party/sglang + FP8="1" python scaled_mm_benchmark.py --reports $REPORTS + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 + - name: Run vllm benchmarks bf16 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} run: | diff --git a/benchmarks/third_party/sglang/scaled_mm_benchmark.py b/benchmarks/third_party/sglang/scaled_mm_benchmark.py new file mode 100644 index 0000000000..3b8e38b113 --- /dev/null +++ b/benchmarks/third_party/sglang/scaled_mm_benchmark.py @@ -0,0 +1,530 @@ +# From +# https://github.com/sgl-project/sglang/blob/6d0364681c8b1abc132cc88f1bb0b7a8a352628f/test/srt/quant/test_triton_scaled_mm.py +# https://github.com/sgl-project/sglang/blob/6d0364681c8b1abc132cc88f1bb0b7a8a352628f/python/sglang/srt/layers/quantization/fp8_kernel.py +import os +from typing import Optional, List + +import torch +import triton +import triton.language as tl + +import triton_kernels_benchmark as benchmark_suite + +# Import vLLM MoE functions +# from vllm.model_executor.layers.fused_moe.fused_batched_moe import invoke_moe_batched_triton_kernel +# from vllm.platforms import current_platform +# from vllm.model_executor.layers.fused_moe.utils import normalize_batched_scales_shape + +# # Import utility functions from vLLM tests +# from tests.kernels.moe.utils import make_quantized_test_activations, make_test_weight +# from tests.kernels.quant_utils import native_batched_masked_quant_matmul + +import torch.testing + + +def is_weak_contiguous(x: torch.Tensor): + strides = x.stride() + sizes = x.shape + is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0])) + is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1])) + return is_transpose or is_not_transpose + + +def get_matmul_batched_autotune_configs() -> List[triton.Config]: + configs = [ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2, 3] + ] + [ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w) + for s in [2] + for (m, w) in ([('large', 32), ('small', 64)]) + ] + [ + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2] + ] + [ + triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2] + ] + [ + triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4) + for s in [2] + ] + return configs + + +# @triton.autotune( +# configs=get_matmul_batched_autotune_configs(), +# 
key=['M', 'N', 'K']) +@triton.jit +def scaled_mm_kernel( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + bias_ptr, + M, + N, + K, + stride_am: tl.int64, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr, +): + pid = tl.program_id(axis=0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = ACCUMULATOR_DTYPE + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + + # NOTE: Some tensor inputs are so large, they will cause int32 overflow + # so it is necessary to use tl.int64 for all the offsets, else SEGV will + # eventually occur. + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + masks_am = offsets_am < M + + offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + masks_bn = offsets_bn < N + + offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] + offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] + + # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create + # appropriate offsets and masks for each case. Same goes for + # BLOCK_SIZE_SCALE_B. + offsets_scale_am = tl.arange(0, BLOCK_SIZE_SCALE_A) + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M + masks_scale_am = offsets_scale_am < M + + offsets_scale_bn = tl.arange(0, BLOCK_SIZE_SCALE_B) + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N + masks_scale_bn = offsets_scale_bn < N + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + scale_a_ptrs = scale_a_ptr + offsets_scale_am + scale_b_ptrs = scale_b_ptr + offsets_scale_bn + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scale at end. + masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] + scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) + # Need to broadcast to the appropriate size, if scale_a is already + # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes + # for scale_b below. + scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) + accumulator = scale_a * accumulator.to(tl.float32) + + masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] + scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) + scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) + accumulator = scale_b.T * accumulator.to(tl.float32) + + # Convert to output format. + c = accumulator.to(c_ptr.type.element_ty) + + # Add bias, it's already in output format, so add it after conversion. 
+ if bias_ptr: + offsets_bias = offsets_bn + bias_ptrs = bias_ptr + offsets_bias + bias_mask = offsets_bias < N + bias = tl.load(bias_ptrs, bias_mask) + c += bias + + # Save output + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit +def scaled_mm_kernel_td( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + bias_ptr, + M, + N, + K, + stride_am: tl.int64, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr, +): + pid = tl.program_id(axis=0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = ACCUMULATOR_DTYPE + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + + # NOTE: Some tensor inputs are so large, they will cause int32 overflow + # so it is necessary to use tl.int64 for all the offsets, else SEGV will + # eventually occur. + + # Offsets and masks. + # offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + # masks_am = offsets_am < M + + offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + # masks_bn = offsets_bn < N + + # offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + # offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] + # offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] + + # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create + # appropriate offsets and masks for each case. Same goes for + # BLOCK_SIZE_SCALE_B. + offsets_scale_am = tl.arange(0, BLOCK_SIZE_SCALE_A) + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M + masks_scale_am = offsets_scale_am < M + + offsets_scale_bn = tl.arange(0, BLOCK_SIZE_SCALE_B) + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N + masks_scale_bn = offsets_scale_bn < N + + a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K)) + b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # a_ptrs = a_ptr + offsets_a + # b_ptrs = b_ptr + offsets_b + + scale_a_ptrs = scale_a_ptr + offsets_scale_am + scale_b_ptrs = scale_b_ptr + offsets_scale_bn + + off_k = 0 + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # masks_k = offsets_k < K + # masks_a = masks_am[:, None] & masks_k[None, :] + # a = tl.load(a_ptrs, mask=masks_a) + + # masks_b = masks_k[:, None] & masks_bn[None, :] + # b = tl.load(b_ptrs, mask=masks_b) + + a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N]) + # accumulator += tl.dot(a, b) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + off_k += BLOCK_SIZE_K + + # offsets_k += BLOCK_SIZE_K + # a_ptrs += BLOCK_SIZE_K * stride_ak + # b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scale at end. 
+ masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] + scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) + # Need to broadcast to the appropriate size, if scale_a is already + # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes + # for scale_b below. + scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) + accumulator = scale_a * accumulator.to(tl.float32) + + masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] + scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) + scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) + accumulator = scale_b.T * accumulator.to(tl.float32) + + # Convert to output format. + c = accumulator.to(c_ptr.type.element_ty) + + # Add bias, it's already in output format, so add it after conversion. + if bias_ptr: + offsets_bias = offsets_bn + bias_ptrs = bias_ptr + offsets_bias + bias_mask = offsets_bias < N + bias = tl.load(bias_ptrs, bias_mask) + c += bias + + # Save output + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + # tl.store(c_ptrs, c, mask=c_mask) + c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N)) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c) + + +# input - [M, K] +# weight - [K, N] +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +def triton_scaled_mm( + input: torch.Tensor, + weight: torch.Tensor, + result: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True, + use_td_kernel=False, +) -> torch.Tensor: + M, K = input.shape + N = weight.shape[1] + + assert N > 0 and K > 0 and M > 0 + assert weight.shape[0] == K + assert input.dtype == weight.dtype + + scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a + scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b + + assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() + assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M) + assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N) + assert out_dtype.is_floating_point + assert bias is None or bias.is_floating_point() + assert is_weak_contiguous(input) + assert is_weak_contiguous(weight) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + # result = torch.empty((M, N), dtype=out_dtype, device=input.device) + + has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 + + if use_heuristic: + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256) + elif next_power_of_2_M <= 64: + tile_shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + tile_shape = (64, 128, 128) + else: + tile_shape = (128, 128, 128) + + block_size_m, block_size_n, block_size_k = tile_shape + + block_size_sa = 1 if has_scalar(scale_a) else block_size_m + block_size_sb = 1 if 
has_scalar(scale_b) else block_size_n + + accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32 + + kernel = scaled_mm_kernel if not use_td_kernel else scaled_mm_kernel_td + # A = input, B = weight, C = result + # A = M x K, B = K x N, C = M x N + kernel[grid]( + input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + result.stride(1), + accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb, + ) + + return result + + +torch.set_default_device('xpu') +device = 'xpu' + + +def torch_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + result: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Reference implementation using float32 for stability""" + out = torch.mm(a.to(torch.float32), b.to(torch.float32)) + out = scale_a.to(torch.float32) * out * scale_b.to(torch.float32).T + if bias is not None: + out = out + bias.to(torch.float32) + result[:] = out.to(out_dtype) + return result + # return out.to(out_dtype) + + +def _make_inputs(M, K, N, in_dtype): + if in_dtype == torch.int8: + a = torch.randint(-8, 8, (M, K), dtype=in_dtype, device=device) + b = torch.randint(-8, 8, (K, N), dtype=in_dtype, device=device) + else: # fp8 + # Adding zero help with nan for some reason, without it there will be some accidental nans + a = (0 + torch.clamp(0.5 * torch.randn((M, K), dtype=torch.float16, device=device), -0.25, 0.25)).to(in_dtype) + b = 0.5 * torch.randn((K, N), dtype=torch.float16, device=device) + b = torch.clamp(b, -0.25, 0.25) + # Adding zero help with nan for some reason, without it there will be some accidental nans + b = (0 + b).to(in_dtype) + return a, b + + +X_VALS = sum([[ # + # [M, 128, 128], + [M, 1024, 4096], [M, 4096, 4096], [M, 4096, 4096 * 4] +] for M in [1, 8, 128, 1024, 4096]], []) + + +def get_scaled_mm_benchmark( + providers_filter: Optional[list[str]] = None, + fp8=False, + plot_name: str = 'scaled_mm_benchmark', +): + supported_providers = { + 'triton': 'triton', + 'triton-td': 'triton-td', + 'pytorch': 'pytorch', + } + if fp8: + pass + + providers = benchmark_suite.filter_providers(supported_providers, providers_filter) + + @benchmark_suite.perf_report( + benchmark_suite.Benchmark( + x_names=['M', 'N', 'K'], + x_vals=X_VALS, + line_arg='provider', + line_vals=list(providers.keys()), + line_names=list(providers.values()), + styles=[('green', '-'), ('blue', '--'), ('red', ':')], + ylabel=['GB/s', 'TFlops'], + plot_name=plot_name, + args={}, + )) + def benchmark(M, N, K, provider, with_bias=False): + torch.manual_seed(10) + n_warmup = 600 + + quantiles = [0.5, 0, 1.0] + + if fp8: + in_dtype, out_dtype = torch.float8_e4m3fn, torch.float32 + else: + in_dtype, out_dtype = torch.int8, torch.bfloat16 + + input, weight = _make_inputs(M, K, N, in_dtype) + scale_a = 0.1 + 0.05 * torch.rand((M, 1), dtype=torch.float32, device=device) + scale_b = 0.1 + 0.05 * torch.rand((N, 1), dtype=torch.float32, device=device) + bias = (0.01 * torch.randn((M, N), dtype=out_dtype, device=device) if with_bias else None) + + ref = torch.empty((M, N), dtype=out_dtype, device=input.device) + + def torch_fn(): + return torch_scaled_mm(input, weight, ref, scale_a, scale_b, out_dtype, bias) + + # Use relaxed tolerances + rtol = 0.15 if in_dtype == torch.int8 
else 0.25 + atol = 0.1 if in_dtype == torch.int8 else 0.15 + + if provider == 'pytorch': + # PyTorch reference implementation using native_batched_masked_quant_matmul + _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench( + torch_fn, + n_warmup=n_warmup, + n_repeat=10, + quantiles=quantiles, + ) + + elif provider == 'triton' or provider == 'triton-td': + result = torch.empty((M, N), dtype=out_dtype, device=input.device) + + # invoke_kernel = invoke_moe_batched_triton_kernel if provider == 'triton' else invoke_moe_batched_triton_kernel_td + def triton_fn(): + return triton_scaled_mm(input, weight, result, scale_a, scale_b, out_dtype, bias, + use_td_kernel=provider == 'triton-td') + + benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg='triton to torch') + + _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench( + triton_fn, + n_warmup=n_warmup, + n_repeat=10, + quantiles=quantiles, + ) + + else: + raise NotImplementedError(f'Unsupported provider {provider}') + + def gbps(ms): + total_bytes = in_dtype.itemsize * (M * K + K * N) + out_dtype.itemsize * M * N + return total_bytes * (1e-9) / (ms * 1e-3) + + def tflops(ms): + total_flops = M * N * K * 2 + return total_flops * (1e-12) / (ms * 1e-3) + + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv + + return benchmark + + +if __name__ == '__main__': + _benchmark_mm = get_scaled_mm_benchmark(fp8=(os.getenv('FP8', '0') == '1'), ) + _benchmark_mm.run(show_plots=False, print_data=True) From 76d22ea95303af57e7fdac4267c6f9590c856505 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 20 Oct 2025 11:39:09 +0000 Subject: [PATCH 03/11] Use python from pyenv --- .github/workflows/third-party-benchmarks.yml | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 7fa3f2a202..0b846928bc 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -15,10 +15,6 @@ on: description: JSON list of benchmarks to run. Leave empty to run all benchmarks. 
type: string default: "" - use_pyenv_python: - description: Use Python built with pyenv - type: boolean - default: false schedule: # About midnight PST (UTC-8) - cron: "5 10 * * *" @@ -57,7 +53,6 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install Python (from pyenv) ${{ inputs.python_version }} - if: ${{ inputs.use_pyenv_python }} uses: ./.github/actions/setup-pyenv-python with: python-version: ${{ env.PYTHON_VERSION }} @@ -77,16 +72,16 @@ jobs: - name: Setup Triton uses: ./.github/actions/setup-triton - - name: Install benchmark dependencies - id: install - run: | - pip install transformers pandas pytest - - name: Create reports dir run: | mkdir reports echo "REPORTS=$PWD/reports" >> $GITHUB_ENV + - name: Install benchmark dependencies + id: install + run: | + pip install transformers pandas pytest + - name: Install benchmarks id: install-benchmarks run: | From d325bd6017f901001681647d8c67d0b21894f38b Mon Sep 17 00:00:00 2001 From: Egor Date: Mon, 20 Oct 2025 14:05:44 +0200 Subject: [PATCH 04/11] Update third-party-benchmarks.yml --- .github/workflows/third-party-benchmarks.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 0b846928bc..6612f937f5 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -46,12 +46,6 @@ jobs: - name: Checkout repository uses: actions/checkout@v5 - - name: Install Python - if: ${{ !(inputs.use_pyenv_python || false) }} - uses: actions/setup-python@v6 - with: - python-version: ${{ env.PYTHON_VERSION }} - - name: Install Python (from pyenv) ${{ inputs.python_version }} uses: ./.github/actions/setup-pyenv-python with: From ee0253f1bfc2a5c69b50c043f4e3c27b67d91e28 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 20 Oct 2025 12:37:26 +0000 Subject: [PATCH 05/11] Fixed codestyle --- .github/workflows/third-party-benchmarks.yml | 4 +-- .../third_party/sglang/scaled_mm_benchmark.py | 32 +++++++------------ 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index e5c12d3e7a..e185cdebbf 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -83,7 +83,7 @@ jobs: pip install . 
- name: Run sglang benchmark int8 - if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }} run: | source ./scripts/capture-hw-details.sh @@ -92,7 +92,7 @@ jobs: python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 - name: Run sglang benchmark with fp8 - if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }} run: | source ./scripts/capture-hw-details.sh diff --git a/benchmarks/third_party/sglang/scaled_mm_benchmark.py b/benchmarks/third_party/sglang/scaled_mm_benchmark.py index 3b8e38b113..0327294aad 100644 --- a/benchmarks/third_party/sglang/scaled_mm_benchmark.py +++ b/benchmarks/third_party/sglang/scaled_mm_benchmark.py @@ -10,17 +10,6 @@ import triton_kernels_benchmark as benchmark_suite -# Import vLLM MoE functions -# from vllm.model_executor.layers.fused_moe.fused_batched_moe import invoke_moe_batched_triton_kernel -# from vllm.platforms import current_platform -# from vllm.model_executor.layers.fused_moe.utils import normalize_batched_scales_shape - -# # Import utility functions from vLLM tests -# from tests.kernels.moe.utils import make_quantized_test_activations, make_test_weight -# from tests.kernels.quant_utils import native_batched_masked_quant_matmul - -import torch.testing - def is_weak_contiguous(x: torch.Tensor): strides = x.stride() @@ -118,7 +107,7 @@ def scaled_mm_kernel( scale_a_ptrs = scale_a_ptr + offsets_scale_am scale_b_ptrs = scale_b_ptr + offsets_scale_bn - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): masks_k = offsets_k < K masks_a = masks_am[:, None] & masks_k[None, :] a = tl.load(a_ptrs, mask=masks_a) @@ -239,7 +228,7 @@ def scaled_mm_kernel_td( scale_b_ptrs = scale_b_ptr + offsets_scale_bn off_k = 0 - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # masks_k = offsets_k < K # masks_a = masks_am[:, None] & masks_k[None, :] # a = tl.load(a_ptrs, mask=masks_a) @@ -302,7 +291,7 @@ def scaled_mm_kernel_td( # weight - [K, N] # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py def triton_scaled_mm( - input: torch.Tensor, + input: torch.Tensor, # pylint: disable=redefined-builtin weight: torch.Tensor, result: torch.Tensor, scale_a: torch.Tensor, @@ -350,6 +339,8 @@ def triton_scaled_mm( tile_shape = (64, 128, 128) else: tile_shape = (128, 128, 128) + else: + raise NotImplementedError('Only heuristic-based tile size selection is supported currently.') block_size_m, block_size_n, block_size_k = tile_shape @@ -384,7 +375,6 @@ def triton_scaled_mm( BLOCK_SIZE_SCALE_A=block_size_sa, BLOCK_SIZE_SCALE_B=block_size_sb, ) - return result @@ -469,15 +459,15 @@ def benchmark(M, N, K, provider, with_bias=False): else: in_dtype, out_dtype = torch.int8, torch.bfloat16 - input, weight = _make_inputs(M, K, N, in_dtype) + x, weight = _make_inputs(M, K, N, in_dtype) scale_a = 0.1 + 0.05 * 
torch.rand((M, 1), dtype=torch.float32, device=device) scale_b = 0.1 + 0.05 * torch.rand((N, 1), dtype=torch.float32, device=device) bias = (0.01 * torch.randn((M, N), dtype=out_dtype, device=device) if with_bias else None) - ref = torch.empty((M, N), dtype=out_dtype, device=input.device) + ref = torch.empty((M, N), dtype=out_dtype, device=x.device) def torch_fn(): - return torch_scaled_mm(input, weight, ref, scale_a, scale_b, out_dtype, bias) + return torch_scaled_mm(x, weight, ref, scale_a, scale_b, out_dtype, bias) # Use relaxed tolerances rtol = 0.15 if in_dtype == torch.int8 else 0.25 @@ -492,12 +482,12 @@ def torch_fn(): quantiles=quantiles, ) - elif provider == 'triton' or provider == 'triton-td': - result = torch.empty((M, N), dtype=out_dtype, device=input.device) + elif provider in ('triton', 'triton-td'): + result = torch.empty((M, N), dtype=out_dtype, device=x.device) # invoke_kernel = invoke_moe_batched_triton_kernel if provider == 'triton' else invoke_moe_batched_triton_kernel_td def triton_fn(): - return triton_scaled_mm(input, weight, result, scale_a, scale_b, out_dtype, bias, + return triton_scaled_mm(x, weight, result, scale_a, scale_b, out_dtype, bias, use_td_kernel=provider == 'triton-td') benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg='triton to torch') From a5d23fd228f296a44b4de7371888ac87f0c29dba Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 20 Oct 2025 13:04:14 +0000 Subject: [PATCH 06/11] Fixed perprocessing --- .github/workflows/third-party-benchmarks.yml | 10 ++++++---- .../third_party/vllm/transform_results.py | 20 +++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index e185cdebbf..85c8552624 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -89,7 +89,7 @@ jobs: cd benchmarks/third_party/sglang python scaled_mm_benchmark.py --reports $REPORTS - python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 --param_cols="M,N,K" - name: Run sglang benchmark with fp8 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }} @@ -98,7 +98,7 @@ jobs: cd benchmarks/third_party/sglang FP8="1" python scaled_mm_benchmark.py --reports $REPORTS - python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 --param_cols="M,N,K" - name: Run vllm benchmarks bf16 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} @@ -110,7 +110,8 @@ jobs: cd benchmarks/third_party/vllm python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark 
--param_cols="num_experts,max_tokens_per_expert,K,N" + - name: Run vllm benchmarks fp8 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} @@ -119,7 +120,8 @@ jobs: cd benchmarks/third_party/vllm FP8="1" python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" + - name: Run Liger-Kernel benchmarks if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'liger-kernel')) }} diff --git a/benchmarks/third_party/vllm/transform_results.py b/benchmarks/third_party/vllm/transform_results.py index 66ada1157c..d583e5694c 100644 --- a/benchmarks/third_party/vllm/transform_results.py +++ b/benchmarks/third_party/vllm/transform_results.py @@ -11,14 +11,19 @@ def parse_args(): parser = argparse.ArgumentParser(description='Parse MoE benchmark CSV') parser.add_argument('source', help='Path to the MoE benchmark CSV file') parser.add_argument('target', help='Path to output CSV file') + parser.add_argument( + '--param_cols', + help='Names of parameter columns, separated by commas.', + required=True, + ) parser.add_argument('--tag', help='Tag for the benchmark run', default='') parser.add_argument('--benchmark', help='moe-benchmark', default='') return parser.parse_args() -def parse_moe_csv(csv_file_path, tag, benchmark): - """Parse the MoE benchmark CSV and extract performance metrics.""" +def parse_csv(csv_file_path, tag, benchmark, param_cols): + """Parse the benchmark CSV and extract performance metrics.""" df = pd.read_csv(csv_file_path) @@ -26,13 +31,7 @@ def parse_moe_csv(csv_file_path, tag, benchmark): current_datetime = datetime.now().isoformat() # Create params for all rows vectorized - df['params'] = df.apply( - lambda row: json.dumps({ - 'num_experts': int(row['num_experts']), - 'max_tokens_per_expert': int(row['max_tokens_per_expert']), - 'K': int(row['K']), - 'N': int(row['N']), - }), axis=1) + df['params'] = df.apply(lambda row: json.dumps({p: int(row[p]) for p in param_cols}), axis=1) # Define compiler columns compilers = [('triton', 'triton-TFlops'), ('pytorch', 'pytorch-TFlops'), ('triton-td', 'triton-td-TFlops')] @@ -90,7 +89,8 @@ def main(): if not os.path.exists(args.source): raise ValueError(f'Error: CSV file {args.source} not found') - df_results = parse_moe_csv(args.source, args.tag, args.benchmark) + param_cols = args.param_cols.split(',') + df_results = parse_csv(args.source, args.tag, args.benchmark, param_cols) df_results.to_csv(args.target, index=False) From 48df63f854f25e9ce2483f9aa054afff0f08e753 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 21 Oct 2025 11:11:18 +0000 Subject: [PATCH 07/11] Fixed benchmark groups --- .github/workflows/third-party-benchmarks.yml | 8 ++++---- benchmarks/third_party/vllm/transform_results.py | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 85c8552624..b337fbfa62 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -89,7 +89,7 @@ jobs: cd 
benchmarks/third_party/sglang python scaled_mm_benchmark.py --reports $REPORTS - python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 --param_cols="M,N,K" + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 --param_cols="M,N,K" --bgroup sglang - name: Run sglang benchmark with fp8 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }} @@ -98,7 +98,7 @@ jobs: cd benchmarks/third_party/sglang FP8="1" python scaled_mm_benchmark.py --reports $REPORTS - python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 --param_cols="M,N,K" + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 --param_cols="M,N,K" --bgroup sglang - name: Run vllm benchmarks bf16 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} @@ -110,7 +110,7 @@ jobs: cd benchmarks/third_party/vllm python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" --bgroup vllm - name: Run vllm benchmarks fp8 @@ -120,7 +120,7 @@ jobs: cd benchmarks/third_party/vllm FP8="1" python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" --bgroup vllm - name: Run Liger-Kernel benchmarks diff --git a/benchmarks/third_party/vllm/transform_results.py b/benchmarks/third_party/vllm/transform_results.py index d583e5694c..e5baa9ab42 100644 --- a/benchmarks/third_party/vllm/transform_results.py +++ b/benchmarks/third_party/vllm/transform_results.py @@ -17,12 +17,13 @@ def parse_args(): required=True, ) parser.add_argument('--tag', help='Tag for the benchmark run', default='') - parser.add_argument('--benchmark', help='moe-benchmark', default='') + parser.add_argument('--benchmark', help='moe-benchmark', required=True) + parser.add_argument('--bgroup', help='Benchmark group', required=True) return parser.parse_args() -def parse_csv(csv_file_path, tag, benchmark, param_cols): +def parse_csv(csv_file_path, tag, bench_group, benchmark, param_cols): """Parse the benchmark CSV and extract performance metrics.""" df = pd.read_csv(csv_file_path) @@ -45,7 +46,7 @@ def parse_csv(csv_file_path, tag, benchmark, param_cols): if len(valid_rows) > 0: valid_rows['run_uuid'] = run_uuid valid_rows['ts'] = current_datetime - valid_rows['benchmark_group'] = 'moe-benchmark' + valid_rows['benchmark_group'] = bench_group valid_rows['benchmark'] = benchmark valid_rows['compiler'] = compiler_name 
valid_rows['value_name'] = 'tflops' @@ -90,7 +91,7 @@ def main(): raise ValueError(f'Error: CSV file {args.source} not found') param_cols = args.param_cols.split(',') - df_results = parse_csv(args.source, args.tag, args.benchmark, param_cols) + df_results = parse_csv(args.source, args.tag, args.bgroup, args.benchmark, param_cols) df_results.to_csv(args.target, index=False) From 95781ad0c019f87c9c22d60e0ad42e14bf4277fd Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Thu, 23 Oct 2025 13:32:15 +0000 Subject: [PATCH 08/11] Fixed sglang installation with pin and testing --- .github/workflows/sglang-tests.yml | 100 -------- .github/workflows/third-party-tests.yml | 81 +++++- .../third_party/sglang/sglang-fix.patch | 242 ++++++++++++++++-- scripts/test-triton.sh | 44 +++- 4 files changed, 320 insertions(+), 147 deletions(-) delete mode 100644 .github/workflows/sglang-tests.yml diff --git a/.github/workflows/sglang-tests.yml b/.github/workflows/sglang-tests.yml deleted file mode 100644 index dc5cabc991..0000000000 --- a/.github/workflows/sglang-tests.yml +++ /dev/null @@ -1,100 +0,0 @@ -name: Third party SGLang tests - -on: - workflow_dispatch: - inputs: - runner_label: - description: Runner label, keep empty for default - type: string - default: "" - use_pyenv_python: - description: Use Python built with pyenv - type: boolean - default: false - schedule: - # About midnight PST Sunday (UTC-8) - - cron: "5 10 * * SUN" - - -# Cancels in-progress PR runs when the PR is updated. Manual runs are never cancelled. -concurrency: - group: ${{ github.workflow }}-${{ github.event_name == 'workflow_dispatch' && github.run_id || github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -permissions: read-all - -env: - PYTHON_VERSION: "3.10" - TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }} - -jobs: - build: - name: SGLang tests - runs-on: - - linux - - ${{ inputs.runner_label || 'rolling' }} - timeout-minutes: 720 - defaults: - run: - shell: bash -noprofile --norc -eo pipefail -c "source /opt/intel/oneapi/setvars.sh > /dev/null; source {0}" - steps: - - name: Print inputs - run: | - cat <> $GITHUB_ENV - - - name: Install SGLang - id: install - run: | - git clone https://github.com/sgl-project/sglang.git - cd sglang - git apply ../benchmarks/third_party/sglang/sglang-fix.patch - pip install "./python[dev_xpu]" - - - name: Setup PyTorch - uses: ./.github/actions/setup-pytorch - - - name: Setup Triton - uses: ./.github/actions/setup-triton - - - name: Run SGLANG tests - if: ${{ steps.install.outcome == 'success' && !cancelled() }} - run: | - ./scripts/test-triton.sh --sglang --skip-pip-install --skip-pytorch-install - - - name: Upload test report - if: ${{ steps.install.outcome == 'success' && !cancelled() }} - uses: actions/upload-artifact@v4 - with: - name: test-reports - path: reports diff --git a/.github/workflows/third-party-tests.yml b/.github/workflows/third-party-tests.yml index 41a38b5c3f..8415d15b86 100644 --- a/.github/workflows/third-party-tests.yml +++ b/.github/workflows/third-party-tests.yml @@ -1,4 +1,4 @@ -name: Third party tests [liger-kernels, vllm] +name: Third party tests [liger-kernels, vllm, sglang] on: workflow_dispatch: @@ -28,12 +28,12 @@ env: TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }} jobs: - build: - name: Third party tests [liger-kernels, 
vllm] + small-tests: + name: Third party tests [vllm, sglang] runs-on: - linux - ${{ inputs.runner_label || 'max1550' }} - timeout-minutes: 720 + timeout-minutes: 120 defaults: run: shell: bash -noprofile --norc -eo pipefail -c "source /opt/intel/oneapi/setvars.sh > /dev/null; source {0}" @@ -47,14 +47,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v5 - - name: Install Python - if: ${{ !(inputs.use_pyenv_python || false) }} - uses: actions/setup-python@v6 - with: - python-version: ${{ env.PYTHON_VERSION }} - - name: Install Python (from pyenv) ${{ inputs.python_version }} - if: ${{ inputs.use_pyenv_python }} uses: ./.github/actions/setup-pyenv-python with: python-version: ${{ env.PYTHON_VERSION }} @@ -86,13 +79,75 @@ jobs: mkdir reports echo "REPORTS=$PWD/reports" >> $GITHUB_ENV + - name: Run SGLANG tests + if: ${{ steps.install.outcome == 'success' && !cancelled() }} + run: | + ./scripts/test-triton.sh --sglang --skip-pip-install --skip-pytorch-install + - name: Run VLLM tests if: ${{ steps.install.outcome == 'success' && !cancelled() }} run: | ./scripts/test-triton.sh --vllm --skip-pip-install --skip-pytorch-install - - name: Run Liger-Kernel tests + - name: Upload test report if: ${{ steps.install.outcome == 'success' && !cancelled() }} + uses: actions/upload-artifact@v4 + with: + name: test-main-reports + path: reports + # We run all tests for Liger, so it's slow and we test it separately + liger: + name: Liger testing + runs-on: + - linux + - ${{ inputs.runner_label || 'max1550' }} + timeout-minutes: 120 + defaults: + run: + shell: bash -noprofile --norc -eo pipefail -c "source /opt/intel/oneapi/setvars.sh > /dev/null; source {0}" + steps: + - name: Print inputs + run: | + cat <> $GITHUB_ENV + + - name: Run Liger-Kernel tests run: | ./scripts/test-triton.sh --liger --skip-pip-install --skip-pytorch-install @@ -100,5 +155,5 @@ jobs: if: ${{ steps.install.outcome == 'success' && !cancelled() }} uses: actions/upload-artifact@v4 with: - name: test-reports + name: test-liger-reports path: reports diff --git a/benchmarks/third_party/sglang/sglang-fix.patch b/benchmarks/third_party/sglang/sglang-fix.patch index 9b9d38dc43..b3769b6385 100644 --- a/benchmarks/third_party/sglang/sglang-fix.patch +++ b/benchmarks/third_party/sglang/sglang-fix.patch @@ -1,9 +1,9 @@ -diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py -index bc2affa1..8ef91e66 100644 ---- a/python/sglang/srt/utils.py -+++ b/python/sglang/srt/utils.py -@@ -228,6 +228,22 @@ def is_flashinfer_available(): - return importlib.util.find_spec("flashinfer") is not None and is_cuda() +diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py +index 7c2f573e4..8023cd6be 100644 +--- a/python/sglang/srt/utils/common.py ++++ b/python/sglang/srt/utils/common.py +@@ -155,12 +155,44 @@ def is_cpu() -> bool: + return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86() +def auto_detect_device(): @@ -22,26 +22,48 @@ index bc2affa1..8ef91e66 100644 + return "cpu" + + - _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var( - "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false" - ) + def get_cuda_version(): + if torch.version.cuda: + return tuple(map(int, torch.version.cuda.split("."))) + return (0, 0) + + ++def auto_detect_device(): ++ """ ++ Infer the device type based on the current environment. 
++ """ ++ if is_cuda_alike(): ++ return "cuda" ++ elif is_xpu(): ++ return "xpu" ++ elif is_hpu(): ++ return "hpu" ++ elif is_npu(): ++ return "npu" ++ else: ++ return "cpu" ++ ++ + def _check(cc_major): + if not is_cuda(): + return False diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py -index 47eb16a9..cce70fb9 100644 +index 16c107006..03b9411fa 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py -@@ -16,8 +16,11 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import ( +@@ -18,8 +18,11 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import ( + from sglang.srt.layers.attention.triton_ops.prefill_attention import ( context_attention_fwd, ) - from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import auto_detect_device - + from sglang.test.test_utils import CustomTestCase +device = auto_detect_device() + - class TestTritonAttention(CustomTestCase): - def _set_all_seeds(self, seed): -@@ -37,24 +40,24 @@ class TestTritonAttention(CustomTestCase): + def extend_attention_fwd_torch( + q: torch.Tensor, # [extend_tokens, H_Q, D] +@@ -114,24 +117,24 @@ class TestTritonAttention(CustomTestCase): dtype = torch.bfloat16 b_seq_len_prefix = torch.randint( @@ -73,7 +95,7 @@ index 47eb16a9..cce70fb9 100644 ) for i in range(B): -@@ -65,15 +68,15 @@ class TestTritonAttention(CustomTestCase): +@@ -142,15 +145,15 @@ class TestTritonAttention(CustomTestCase): total_token_num = torch.sum(b_seq_len).item() extend_token_num = torch.sum(b_seq_len_extend).item() k_buffer = torch.empty( @@ -94,7 +116,7 @@ index 47eb16a9..cce70fb9 100644 for i in range(B): extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] -@@ -86,20 +89,20 @@ class TestTritonAttention(CustomTestCase): +@@ -163,20 +166,20 @@ class TestTritonAttention(CustomTestCase): extend_start_in_buffer:extend_end_in_buffer ] q_extend[extend_start:extend_end] = torch.empty( @@ -120,7 +142,7 @@ index 47eb16a9..cce70fb9 100644 qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) custom_mask = None -@@ -123,9 +126,9 @@ class TestTritonAttention(CustomTestCase): +@@ -200,9 +203,9 @@ class TestTritonAttention(CustomTestCase): b_seq_mask_len = b_seq_len_extend * b_seq_len custom_mask = torch.ones( @@ -132,7 +154,81 @@ index 47eb16a9..cce70fb9 100644 mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0) for i in range(B): causal_mask = ( -@@ -187,14 +190,14 @@ class TestTritonAttention(CustomTestCase): +@@ -263,22 +266,22 @@ class TestTritonAttention(CustomTestCase): + dtype = torch.bfloat16 + + b_seq_len_prefix = torch.randint( +- 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" ++ 1, N_CTX // 2, (B,), dtype=torch.int32, device=device + ) + b_seq_len_extend = torch.randint( +- 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" ++ 1, N_CTX // 2, (B,), dtype=torch.int32, device=device + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + +- b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") ++ b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) +- b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") ++ b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + +- kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") ++ kv_indptr = 
torch.zeros((B + 1,), dtype=torch.int32, device=device) + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros( +- (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda" ++ (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=device + ) + + for i in range(B): +@@ -289,15 +292,15 @@ class TestTritonAttention(CustomTestCase): + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( +- (total_token_num, H_KV, D), dtype=dtype, device="cuda" ++ (total_token_num, H_KV, D), dtype=dtype, device=device + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( +- (total_token_num, H_KV, D), dtype=dtype, device="cuda" ++ (total_token_num, H_KV, D), dtype=dtype, device=device + ).normal_(mean=0.1, std=0.2) + +- k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") +- v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") +- q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") ++ k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) ++ v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) ++ q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] +@@ -310,19 +313,19 @@ class TestTritonAttention(CustomTestCase): + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( +- (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" ++ (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device + ).normal_(mean=0.1, std=0.2) + + o_extend_triton = torch.empty( +- (extend_token_num, H_Q, D), dtype=dtype, device="cuda" ++ (extend_token_num, H_Q, D), dtype=dtype, device=device + ) + o_extend_torch = torch.empty( +- (extend_token_num, H_Q, D), dtype=dtype, device="cuda" ++ (extend_token_num, H_Q, D), dtype=dtype, device=device + ) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() +- qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") ++ qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device) + qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + + extend_attention_fwd( +@@ -373,14 +376,14 @@ class TestTritonAttention(CustomTestCase): max_seq_len = max(seq_lens) # Create random input tensors @@ -153,7 +249,7 @@ index 47eb16a9..cce70fb9 100644 context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal -@@ -232,33 +235,33 @@ class TestTritonAttention(CustomTestCase): +@@ -418,33 +421,33 @@ class TestTritonAttention(CustomTestCase): total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) max_kv_splits = 8 @@ -197,7 +293,7 @@ index 47eb16a9..cce70fb9 100644 ) decode_attention_fwd( -@@ -296,34 +299,34 @@ class TestTritonAttention(CustomTestCase): +@@ -482,34 +485,34 @@ class TestTritonAttention(CustomTestCase): total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) max_kv_splits = 8 @@ -243,7 +339,7 @@ index 47eb16a9..cce70fb9 100644 ) decode_attention_fwd_normal( -@@ -343,12 +346,12 @@ class TestTritonAttention(CustomTestCase): +@@ -529,12 +532,12 @@ class TestTritonAttention(CustomTestCase): attn_logits1 = torch.empty( (B, H_Q, max_kv_splits, D_V), dtype=torch.float32, @@ -258,3 +354,103 @@ index 47eb16a9..cce70fb9 100644 ) 
decode_attention_fwd_grouped( +@@ -578,23 +581,23 @@ class TestTritonAttention(CustomTestCase): + dtype = torch.bfloat16 + + b_seq_len_prefix = torch.randint( +- 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" ++ 1, N_CTX // 2, (B,), dtype=torch.int32, device=device + ) + b_seq_len_extend = torch.randint( +- 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" ++ 1, N_CTX // 2, (B,), dtype=torch.int32, device=device + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + +- b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") ++ b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) +- b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") ++ b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + # Setup prefix KV indices +- kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") ++ kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device) + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros( +- (b_seq_len_prefix.sum().item(),), dtype=torch.int64, device="cuda" ++ (b_seq_len_prefix.sum().item(),), dtype=torch.int64, device=device + ) + + for i in range(B): +@@ -605,15 +608,15 @@ class TestTritonAttention(CustomTestCase): + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( +- (total_token_num, H_KV, D), dtype=dtype, device="cuda" ++ (total_token_num, H_KV, D), dtype=dtype, device=device + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( +- (total_token_num, H_KV, D), dtype=dtype, device="cuda" ++ (total_token_num, H_KV, D), dtype=dtype, device=device + ).normal_(mean=0.1, std=0.2) + +- k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") +- v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") +- q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") ++ k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) ++ v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) ++ q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] +@@ -627,16 +630,16 @@ class TestTritonAttention(CustomTestCase): + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( +- (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" ++ (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device + ).normal_(mean=0.1, std=0.2) + + # Setup for extend attention + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() +- qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") ++ qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device) + qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + + # Run 2-stage kernel +- o_regular = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") ++ o_regular = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + extend_attention_fwd( + q_extend, + k_extend, +@@ -658,9 +661,9 @@ class TestTritonAttention(CustomTestCase): + total_token_num - extend_token_num, + total_token_num, + dtype=torch.int64, +- device="cuda", ++ device=device, + ) +- extend_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") ++ 
extend_start_loc = torch.zeros((B,), dtype=torch.int32, device=device) + extend_start_loc[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + unified_kv_indptr, unified_kv_indices, prefix_lens = build_unified_kv_indices( +@@ -673,7 +676,7 @@ class TestTritonAttention(CustomTestCase): + ) + + # Run unified kernel +- o_unified = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") ++ o_unified = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + extend_attention_fwd_unified( + q_extend, + o_unified, +@@ -716,7 +719,6 @@ class TestTritonAttention(CustomTestCase): + """Test build_unified_kv_indices correctness.""" + B = 4 + dtype = torch.int64 +- device = "cuda" + + # Setup test data + prefix_lens = torch.tensor([10, 20, 15, 25], dtype=torch.int32, device=device) diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 0bdc5de7ad..27ef883d99 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -30,6 +30,7 @@ TEST: --liger --vllm --install-vllm + --install-sglang OPTION: --unskip @@ -74,6 +75,7 @@ TEST_SGLANG=false TEST_LIGER=false TEST_VLLM=false INSTALL_VLLM=false +INSTALL_SGLANG=false TEST_TRITON_KERNELS=false VENV=false TRITON_TEST_REPORTS=false @@ -190,6 +192,11 @@ while (( $# != 0 )); do TEST_DEFAULT=false shift ;; + --install-sglang) + INSTALL_SGLANG=true + TEST_DEFAULT=false + shift + ;; --sglang) TEST_SGLANG=true TEST_DEFAULT=false @@ -589,26 +596,38 @@ run_inductor_tests() { grep AlbertForMaskedLM inductor_log.csv | grep -q ,pass, } -run_sglang_tests() { - echo "***************************************************" - echo "****** Running SGLang Triton tests ******" - echo "***************************************************" +run_sglang_install() { + echo "************************************************" + echo "****** Installing SGLang ****" + echo "************************************************" if ! [ -d "./sglang" ]; then git clone https://github.com/sgl-project/sglang.git fi - cd sglang if ! pip list | grep "sglang" ; then - git apply $TRITON_PROJ/benchmarks/third_party/sglang/sglang-fix.patch + cd sglang + git checkout "$(<../benchmarks/third_party/sglang/sglang-pin.txt)" + git apply ../benchmarks/third_party/sglang/sglang-fix.patch + + # That's how sglang assumes we'll pick out platform for now + cp python/pyproject_xpu.toml python/pyproject.toml + # We should remove all torch libraries from requirements to avoid reinstalling triton & torch + # We remove sgl kernel due to a bug in the current environment probably due to using newer torch + sed -i '/pytorch\|torch\|sgl-kernel/d' python/pyproject.toml pip install "./python[dev_xpu]" - - # SGLang installation breaks the default PyTorch and Triton versions, so we need to reinstall them. - $SCRIPTS_DIR/install-pytorch.sh --force-reinstall - $SCRIPTS_DIR/compile-triton.sh --triton + cd .. 
fi - pip install pytest pytest-xdist + pip install pytest pytest-cov pytest-xdist +} + +run_sglang_tests() { + echo "***************************************************" + echo "****** Running SGLang Triton tests ******" + echo "***************************************************" + + run_sglang_install run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-4} test/srt/test_triton_attention_kernels.py } @@ -745,6 +764,9 @@ test_triton() { if [ "$TEST_INDUCTOR" == true ]; then run_inductor_tests fi + if [ "$INSTALL_SGLANG" == true ]; then + run_sglang_install + fi if [ "$TEST_SGLANG" == true ]; then run_sglang_tests fi From a31ea9fcfb4260c8ef7b70a958ca8a65ad72e313 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 28 Oct 2025 11:12:34 +0000 Subject: [PATCH 09/11] Removed duplicate --- scripts/test-triton.sh | 6 ------ 1 file changed, 6 deletions(-) diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index a5484eea8d..69e72e945d 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -78,7 +78,6 @@ TEST_LIGER=false INSTALL_LIGER=false TEST_VLLM=false INSTALL_VLLM=false -INSTALL_SGLANG=false TEST_TRITON_KERNELS=false VENV=false TRITON_TEST_REPORTS=false @@ -195,11 +194,6 @@ while (( $# != 0 )); do TEST_DEFAULT=false shift ;; - --install-sglang) - INSTALL_SGLANG=true - TEST_DEFAULT=false - shift - ;; --sglang) TEST_SGLANG=true TEST_DEFAULT=false From f7e86ccd2fd1254fa4b78dd24603e9d186addffa Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 28 Oct 2025 14:39:39 +0000 Subject: [PATCH 10/11] Update for extrnal dependency --- .../third_party/sglang/scaled_mm_benchmark.py | 149 +------------ .../third_party/sglang/sglang-bench-fix.patch | 208 ++++++++++++++++++ ...sglang-fix.patch => sglang-test-fix.patch} | 0 scripts/test-triton.sh | 3 +- 4 files changed, 220 insertions(+), 140 deletions(-) create mode 100644 benchmarks/third_party/sglang/sglang-bench-fix.patch rename benchmarks/third_party/sglang/{sglang-fix.patch => sglang-test-fix.patch} (100%) diff --git a/benchmarks/third_party/sglang/scaled_mm_benchmark.py b/benchmarks/third_party/sglang/scaled_mm_benchmark.py index 0327294aad..d79850a3b9 100644 --- a/benchmarks/third_party/sglang/scaled_mm_benchmark.py +++ b/benchmarks/third_party/sglang/scaled_mm_benchmark.py @@ -10,6 +10,8 @@ import triton_kernels_benchmark as benchmark_suite +from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm + def is_weak_contiguous(x: torch.Tensor): strides = x.stride() @@ -40,124 +42,6 @@ def get_matmul_batched_autotune_configs() -> List[triton.Config]: return configs -# @triton.autotune( -# configs=get_matmul_batched_autotune_configs(), -# key=['M', 'N', 'K']) -@triton.jit -def scaled_mm_kernel( - a_ptr, - b_ptr, - scale_a_ptr, - scale_b_ptr, - c_ptr, - bias_ptr, - M, - N, - K, - stride_am: tl.int64, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - ACCUMULATOR_DTYPE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_SCALE_A: tl.constexpr, - BLOCK_SIZE_SCALE_B: tl.constexpr, -): - pid = tl.program_id(axis=0) - - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - accumulator_dtype = ACCUMULATOR_DTYPE - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) - - # NOTE: Some tensor inputs are so large, they will cause int32 overflow - # so it is necessary to use tl.int64 for all the offsets, else SEGV will - 
# eventually occur. - - # Offsets and masks. - offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - masks_am = offsets_am < M - - offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) - masks_bn = offsets_bn < N - - offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) - offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] - offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] - - # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create - # appropriate offsets and masks for each case. Same goes for - # BLOCK_SIZE_SCALE_B. - offsets_scale_am = tl.arange(0, BLOCK_SIZE_SCALE_A) + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M - masks_scale_am = offsets_scale_am < M - - offsets_scale_bn = tl.arange(0, BLOCK_SIZE_SCALE_B) + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N - masks_scale_bn = offsets_scale_bn < N - - a_ptrs = a_ptr + offsets_a - b_ptrs = b_ptr + offsets_b - - scale_a_ptrs = scale_a_ptr + offsets_scale_am - scale_b_ptrs = scale_b_ptr + offsets_scale_bn - - for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - masks_k = offsets_k < K - masks_a = masks_am[:, None] & masks_k[None, :] - a = tl.load(a_ptrs, mask=masks_a) - - masks_b = masks_k[:, None] & masks_bn[None, :] - b = tl.load(b_ptrs, mask=masks_b) - - # Accumulate results. - accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) - - offsets_k += BLOCK_SIZE_K - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - # Apply scale at end. - masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] - scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) - # Need to broadcast to the appropriate size, if scale_a is already - # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes - # for scale_b below. - scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) - accumulator = scale_a * accumulator.to(tl.float32) - - masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] - scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) - scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) - accumulator = scale_b.T * accumulator.to(tl.float32) - - # Convert to output format. - c = accumulator.to(c_ptr.type.element_ty) - - # Add bias, it's already in output format, so add it after conversion. 
- if bias_ptr: - offsets_bias = offsets_bn - bias_ptrs = bias_ptr + offsets_bias - bias_mask = offsets_bias < N - bias = tl.load(bias_ptrs, bias_mask) - c += bias - - # Save output - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) - offs_cm = offs_cm.to(tl.int64) - offs_cn = offs_cn.to(tl.int64) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - - tl.store(c_ptrs, c, mask=c_mask) - - @triton.jit def scaled_mm_kernel_td( a_ptr, @@ -290,10 +174,9 @@ def scaled_mm_kernel_td( # input - [M, K] # weight - [K, N] # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py -def triton_scaled_mm( +def triton_scaled_mm_td( input: torch.Tensor, # pylint: disable=redefined-builtin weight: torch.Tensor, - result: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: type[torch.dtype], @@ -302,7 +185,6 @@ def triton_scaled_mm( block_size_n: int = 32, block_size_k: int = 32, use_heuristic=True, - use_td_kernel=False, ) -> torch.Tensor: M, K = input.shape N = weight.shape[1] @@ -324,7 +206,7 @@ def triton_scaled_mm( grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - # result = torch.empty((M, N), dtype=out_dtype, device=input.device) + result = torch.empty((M, N), dtype=out_dtype, device=input.device) has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 @@ -349,10 +231,9 @@ def triton_scaled_mm( accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32 - kernel = scaled_mm_kernel if not use_td_kernel else scaled_mm_kernel_td # A = input, B = weight, C = result # A = M x K, B = K x N, C = M x N - kernel[grid]( + scaled_mm_kernel_td[grid]( input, weight, scale_a, @@ -385,7 +266,6 @@ def triton_scaled_mm( def torch_scaled_mm( a: torch.Tensor, b: torch.Tensor, - result: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, @@ -396,9 +276,7 @@ def torch_scaled_mm( out = scale_a.to(torch.float32) * out * scale_b.to(torch.float32).T if bias is not None: out = out + bias.to(torch.float32) - result[:] = out.to(out_dtype) - return result - # return out.to(out_dtype) + return out.to(out_dtype) def _make_inputs(M, K, N, in_dtype): @@ -429,11 +307,8 @@ def get_scaled_mm_benchmark( supported_providers = { 'triton': 'triton', 'triton-td': 'triton-td', - 'pytorch': 'pytorch', + 'pytorch': 'pytorch-deqmm', } - if fp8: - pass - providers = benchmark_suite.filter_providers(supported_providers, providers_filter) @benchmark_suite.perf_report( @@ -464,10 +339,8 @@ def benchmark(M, N, K, provider, with_bias=False): scale_b = 0.1 + 0.05 * torch.rand((N, 1), dtype=torch.float32, device=device) bias = (0.01 * torch.randn((M, N), dtype=out_dtype, device=device) if with_bias else None) - ref = torch.empty((M, N), dtype=out_dtype, device=x.device) - def torch_fn(): - return torch_scaled_mm(x, weight, ref, scale_a, scale_b, out_dtype, bias) + return torch_scaled_mm(x, weight, scale_a, scale_b, bias) # Use relaxed tolerances rtol = 0.15 if in_dtype == torch.int8 else 0.25 @@ -483,12 +356,10 @@ def torch_fn(): ) elif provider in ('triton', 'triton-td'): - result = torch.empty((M, N), dtype=out_dtype, device=x.device) + invoke_kernel = triton_scaled_mm if provider == 'triton' else triton_scaled_mm_td - # invoke_kernel = invoke_moe_batched_triton_kernel if 
provider == 'triton' else invoke_moe_batched_triton_kernel_td def triton_fn(): - return triton_scaled_mm(x, weight, result, scale_a, scale_b, out_dtype, bias, - use_td_kernel=provider == 'triton-td') + return invoke_kernel(x, weight, scale_a, scale_b, out_dtype, bias) benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg='triton to torch') diff --git a/benchmarks/third_party/sglang/sglang-bench-fix.patch b/benchmarks/third_party/sglang/sglang-bench-fix.patch new file mode 100644 index 0000000000..7c23b93a09 --- /dev/null +++ b/benchmarks/third_party/sglang/sglang-bench-fix.patch @@ -0,0 +1,208 @@ +diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py +index 3aaf301bb..68b6520d4 100644 +--- a/python/sglang/srt/layers/linear.py ++++ b/python/sglang/srt/layers/linear.py +@@ -18,9 +18,9 @@ from sglang.srt.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + ) +-from sglang.srt.distributed.device_communicators.pynccl_allocator import ( +- use_symmetric_memory, +-) ++# from sglang.srt.distributed.device_communicators.pynccl_allocator import ( ++ # use_symmetric_memory, ++# ) + from sglang.srt.layers.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, +diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py +index df0658f86..e69de29bb 100644 +--- a/python/sglang/srt/layers/quantization/__init__.py ++++ b/python/sglang/srt/layers/quantization/__init__.py +@@ -1,173 +0,0 @@ +-# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py +-from __future__ import annotations +- +-import builtins +-import inspect +-from typing import TYPE_CHECKING, Dict, Optional, Type +- +-import torch +- +-try: +- from vllm.model_executor.layers.quantization.aqlm import AQLMConfig +- from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig +- from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig +- from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config +- from vllm.model_executor.layers.quantization.gguf import GGUFConfig +- from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( +- GPTQMarlin24Config, +- ) +- from vllm.model_executor.layers.quantization.marlin import MarlinConfig +- from vllm.model_executor.layers.quantization.qqq import QQQConfig +- from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig +- +- VLLM_AVAILABLE = True +-except ImportError as e: +- VLLM_AVAILABLE = False +- VLLM_IMPORT_ERROR = e +- +- # Define empty classes as placeholders when vllm is not available +- class DummyConfig: +- def override_quantization_method(self, *args, **kwargs): +- return None +- +- AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = ( +- ExpertsInt8Config +- ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = ( +- DummyConfig +- ) +- +- +-from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig +-from sglang.srt.layers.quantization.base_config import QuantizationConfig +-from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config +-from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( +- CompressedTensorsConfig, +-) +-from sglang.srt.layers.quantization.fp8 import Fp8Config +-from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config +-from 
sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig +-from sglang.srt.layers.quantization.modelopt_quant import ( +- ModelOptFp4Config, +- ModelOptFp8Config, +-) +-from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config +-from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config +-from sglang.srt.layers.quantization.petit import PetitNvFp4Config +-from sglang.srt.layers.quantization.qoq import QoQConfig +-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config +-from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config +-from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +-from sglang.srt.utils import is_cuda, is_hip, mxfp_supported +- +-_is_mxfp_supported = mxfp_supported() +- +-if TYPE_CHECKING: +- from sglang.srt.layers.moe.topk import TopKOutput +- +-# Base quantization methods that don't depend on vllm +-BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { +- "fp8": Fp8Config, +- "blockwise_int8": BlockInt8Config, +- "modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8 +- "modelopt_fp8": ModelOptFp8Config, +- "modelopt_fp4": ModelOptFp4Config, +- "w8a8_int8": W8A8Int8Config, +- "w8a8_fp8": W8A8Fp8Config, +- "awq": AWQConfig, +- "awq_marlin": AWQMarlinConfig, +- "gptq": GPTQConfig, +- "gptq_marlin": GPTQMarlinConfig, +- "moe_wna16": MoeWNA16Config, +- "compressed-tensors": CompressedTensorsConfig, +- "qoq": QoQConfig, +- "w4afp8": W4AFp8Config, +- "petit_nvfp4": PetitNvFp4Config, +- "fbgemm_fp8": FBGEMMFp8Config, +-} +- +- +-if is_cuda(): +- BASE_QUANTIZATION_METHODS.update( +- { +- "quark": Mxfp4Config, +- "mxfp4": Mxfp4Config, +- } +- ) +-elif _is_mxfp_supported and is_hip(): +- from sglang.srt.layers.quantization.quark.quark import QuarkConfig +- +- BASE_QUANTIZATION_METHODS.update( +- { +- "quark": QuarkConfig, +- "mxfp4": Mxfp4Config, +- } +- ) +-# VLLM-dependent quantization methods +-VLLM_QUANTIZATION_METHODS = { +- "aqlm": AQLMConfig, +- "deepspeedfp": DeepSpeedFPConfig, +- "tpu_int8": Int8TpuConfig, +- "marlin": MarlinConfig, +- "gguf": GGUFConfig, +- "gptq_marlin_24": GPTQMarlin24Config, +- "bitsandbytes": BitsAndBytesConfig, +- "qqq": QQQConfig, +- "experts_int8": ExpertsInt8Config, +-} +- +-QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS} +- +- +-def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: +- if quantization not in QUANTIZATION_METHODS: +- raise ValueError( +- f"Invalid quantization method: {quantization}. " +- f"Available methods: {list(QUANTIZATION_METHODS.keys())}" +- ) +- if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE: +- raise ValueError( +- f"{quantization} quantization requires some operators from vllm. 
" +- f"Please install vllm by `pip install vllm==0.9.0.1`\n" +- f"Import error: {VLLM_IMPORT_ERROR}" +- ) +- +- return QUANTIZATION_METHODS[quantization] +- +- +-original_isinstance = builtins.isinstance +- +- +-def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False): +- """ +- Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig +- can recognize sglang layers +- """ +- if not VLLM_AVAILABLE: +- return +- +- if reverse: +- builtins.isinstance = original_isinstance +- return +- +- from vllm.model_executor.layers.fused_moe import FusedMoE +- from vllm.model_executor.layers.linear import LinearBase +- from vllm.model_executor.layers.vocab_parallel_embedding import ( +- VocabParallelEmbedding, +- ) +- +- from sglang.srt.layers.linear import LinearBase as PatchedLinearBase +- from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE +- from sglang.srt.layers.vocab_parallel_embedding import ( +- VocabParallelEmbedding as PatchedVocabParallelEmbedding, +- ) +- +- def patched_isinstance(obj, classinfo): +- if classinfo is LinearBase: +- return original_isinstance(obj, PatchedLinearBase) +- if classinfo is FusedMoE: +- return original_isinstance(obj, PatchedFusedMoE) +- if classinfo is VocabParallelEmbedding: +- return original_isinstance(obj, PatchedVocabParallelEmbedding) +- return original_isinstance(obj, classinfo) +- +- builtins.isinstance = patched_isinstance +diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py +index e98b7f6ff..9b4761a30 100644 +--- a/python/sglang/srt/layers/quantization/fp8_kernel.py ++++ b/python/sglang/srt/layers/quantization/fp8_kernel.py +@@ -23,7 +23,7 @@ import torch + import triton + import triton.language as tl + +-from sglang.srt.layers import deep_gemm_wrapper ++# from sglang.srt.layers import deep_gemm_wrapper + from sglang.srt.utils import ( + align, + direct_register_custom_op, diff --git a/benchmarks/third_party/sglang/sglang-fix.patch b/benchmarks/third_party/sglang/sglang-test-fix.patch similarity index 100% rename from benchmarks/third_party/sglang/sglang-fix.patch rename to benchmarks/third_party/sglang/sglang-test-fix.patch diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 69e72e945d..d9ba1f088c 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -619,7 +619,8 @@ run_sglang_install() { if ! 
pip list | grep "sglang" ; then cd sglang git checkout "$(<../benchmarks/third_party/sglang/sglang-pin.txt)" - git apply ../benchmarks/third_party/sglang/sglang-fix.patch + git apply ../benchmarks/third_party/sglang/sglang-test-fix.patch + git apply ../benchmarks/third_party/sglang/sglang-bench-fix.patch # That's how sglang assumes we'll pick out platform for now cp python/pyproject_xpu.toml python/pyproject.toml From 38c95397c479d6c86c88dcfffbaa20432c393927 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 28 Oct 2025 15:09:47 +0000 Subject: [PATCH 11/11] Added sglang install --- .github/workflows/third-party-benchmarks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index d9a0cfbec4..c1687ddba0 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -87,6 +87,7 @@ jobs: run: | source ./scripts/capture-hw-details.sh + ./scripts/test-triton.sh --install-sglang --skip-pip-install --skip-pytorch-install cd benchmarks/third_party/sglang python scaled_mm_benchmark.py --reports $REPORTS python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 --param_cols="M,N,K" --bgroup sglang
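For reference, the PyTorch path (torch_scaled_mm) that the benchmark validates the Triton providers against reduces to a dequantized matmul: accumulate A @ B in fp32, apply the per-row (M, 1) scale and the transposed per-column (N, 1) scale, add the optional bias, and cast to the output dtype. The sketch below is a minimal illustration of that reference, not the benchmark code itself; it assumes fp32 accumulation via torch.mm (the accumulation line is outside the hunks shown above) and the (M, 1)/(N, 1) scale shapes used in the benchmark.

import torch
from typing import Optional

def scaled_mm_reference(a: torch.Tensor, b: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # a: (M, K) int8/fp8 activations, b: (K, N) quantized weight,
    # scale_a: (M, 1) per-row scale, scale_b: (N, 1) per-column scale.
    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    # Broadcast: (M, 1) * (M, N) * (1, N) -> (M, N)
    out = scale_a.to(torch.float32) * out * scale_b.to(torch.float32).T
    if bias is not None:
        # Bias is already in the output dtype; add it in fp32 before the final cast.
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)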