
Commit b56bb4f

misc: Label APIs for Logging (#2153)
## 📌 Description

FlashInfer's API Logging feature was enabled in #2108, but only decorated a small number of APIs with `@flashinfer_api`. This PR adds the decorator to a large number of existing APIs for comprehensive coverage.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Expanded public API: many attention, cascade, pod, page, sparse, quantization, rope, sampling, decode and GEMM entry points are now exposed.
  * Added high-level FP4/FP8 quantization utilities, new sampling functions (speculative, top-k/top-p, renorm), and RoPE transforms with precomputed-cache & FP8 paths.
* **Changes**
  * layernorm now accepts additional parameters (gemma, beta, eps).
* **Documentation**
  * Updated logging/API doc link in README.
1 parent 89e1adb commit b56bb4f
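For context, an API-logging decorator of this kind typically wraps each public entry point, reads the logging configuration from the environment, and records the call before delegating to the wrapped function. The sketch below only illustrates that pattern under assumed level semantics (0 = off, any higher level logs each call); it is not FlashInfer's actual `flashinfer_api` implementation, and only the `FLASHINFER_LOGLEVEL`/`FLASHINFER_LOGDEST` variable names come from the README change in this commit.

```python
# Illustrative sketch only -- not FlashInfer's actual decorator.
# Assumption: FLASHINFER_LOGLEVEL=0 disables logging; any higher level logs each call.
import functools
import os
import sys


def flashinfer_api_sketch(fn):
    level = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
    dest = os.environ.get("FLASHINFER_LOGDEST", "stdout")
    stream = sys.stdout if dest == "stdout" else open(dest, "a")

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if level > 0:
            print(f"[flashinfer] {fn.__qualname__} called", file=stream)
        return fn(*args, **kwargs)

    return wrapper
```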

File tree

17 files changed: +103 −1 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ export FLASHINFER_LOGLEVEL=3
 export FLASHINFER_LOGDEST=stdout
 ```
 
-For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md).
+For detailed information about logging levels, configuration, and advanced features, see [Logging](https://docs.flashinfer.ai/logging.html) in our documentation.
 
 ## Custom Attention Variants
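The environment variables above control the logging added by this commit. A minimal usage sketch, assuming FlashInfer is installed with a CUDA device available and that `silu_and_mul` (decorated below in `flashinfer/activation.py`) is exported at the package top level:

```python
import os

# Configure API logging before importing flashinfer (values from the README snippet above).
os.environ["FLASHINFER_LOGLEVEL"] = "3"
os.environ["FLASHINFER_LOGDEST"] = "stdout"

import torch
import flashinfer

# Calls to decorated APIs are now logged to stdout.
x = torch.randn(4, 2 * 128, dtype=torch.float16, device="cuda")
y = flashinfer.silu_and_mul(x)  # result has shape (4, 128)
```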

flashinfer/activation.py

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 import torch
 
+from .api_logging import flashinfer_api
 from .jit import gen_act_and_mul_module
 from .utils import (
     device_support_pdl,
@@ -66,6 +67,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
     )
 
 
+@flashinfer_api
 def silu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -110,6 +112,7 @@ def silu_and_mul(
     return out
 
 
+@flashinfer_api
 def gelu_tanh_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -150,6 +153,7 @@ def gelu_tanh_and_mul(
     return out
 
 
+@flashinfer_api
 def gelu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -190,6 +194,7 @@ def gelu_and_mul(
     return out
 
 
+@flashinfer_api
 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
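Each decorated activation keeps its original signature; the decorator only adds logging around the call. A hedged sketch of calling one of them with a preallocated output buffer (the shape convention, where the last input dimension is split in half, is assumed from the `*_and_mul` pattern):

```python
import torch
import flashinfer

x = torch.randn(16, 2 * 4096, dtype=torch.float16, device="cuda")
out = torch.empty(16, 4096, dtype=torch.float16, device="cuda")

# gelu_tanh_and_mul(input, out=None, enable_pdl=None) per the signature above;
# the call is logged when FLASHINFER_LOGLEVEL is set.
flashinfer.gelu_tanh_and_mul(x, out=out)
```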

flashinfer/attention.py

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 import torch
 
+from .api_logging import flashinfer_api
 from .jit import gen_batch_attention_module
 from .utils import (
     MaskMode,
@@ -40,6 +41,7 @@ def get_holistic_attention_module(*args):
 
 
 class BatchAttention:
+    @flashinfer_api
     def __init__(
         self,
         kv_layout: str = "NHD",
@@ -65,6 +67,7 @@ def __init__(
             pin_memory=True,
         )
 
+    @flashinfer_api
     def plan(
         self,
         qo_indptr: torch.Tensor,
@@ -132,6 +135,7 @@ def plan(
             causal,
         )
 
+    @flashinfer_api
     def run(
         self,
         q: torch.Tensor,

flashinfer/cascade.py

Lines changed: 15 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import torch
 
+from .api_logging import flashinfer_api
 from .decode import BatchDecodeWithPagedKVCacheWrapper
 from .jit.cascade import gen_cascade_module
 from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache
@@ -30,6 +31,7 @@ def get_cascade_module():
     return gen_cascade_module().build_and_load()
 
 
+@flashinfer_api
 @register_custom_op("flashinfer::merge_state", mutates_args=())
 def merge_state(
     v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -96,6 +98,7 @@ def _fake_merge_state(
     return v, s
 
 
+@flashinfer_api
 @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
 def merge_state_in_place(
     v: torch.Tensor,
@@ -156,6 +159,7 @@ def _fake_merge_state_in_place(
     pass
 
 
+@flashinfer_api
 @register_custom_op("flashinfer::merge_states", mutates_args=())
 def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Merge multiple attention states (v, s).
@@ -287,6 +291,7 @@ class MultiLevelCascadeAttentionWrapper:
     BatchPrefillWithPagedKVCacheWrapper
     """
 
+    @flashinfer_api
     def __init__(
         self,
         num_levels,
@@ -386,6 +391,7 @@ def reset_workspace_buffer(
         ):
             wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)
 
+    @flashinfer_api
     def plan(
         self,
         qo_indptr_arr: List[torch.Tensor],
@@ -506,6 +512,7 @@ def plan(
 
     begin_forward = plan
 
+    @flashinfer_api
     def run(
         self,
         q: torch.Tensor,
@@ -629,6 +636,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     manages the lifecycle of these data structures.
     """
 
+    @flashinfer_api
     def __init__(
         self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
     ) -> None:
@@ -656,6 +664,7 @@ def reset_workspace_buffer(
             float_workspace_buffer, int_workspace_buffer
         )
 
+    @flashinfer_api
     def begin_forward(
         self,
         unique_kv_indptr: torch.Tensor,
@@ -717,6 +726,7 @@ def begin_forward(
             data_type=data_type,
         )
 
+    @flashinfer_api
     def forward(
         self,
         q: torch.Tensor,
@@ -780,6 +790,7 @@ def forward(
         merge_state_in_place(V_shared, S_shared, V_unique, S_unique)
         return V_shared
 
+    @flashinfer_api
     def end_forward(self) -> None:
         r"""Warning: this function is deprecated and has no effect"""
         pass
@@ -876,6 +887,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     layers). This wrapper class manages the lifecycle of these data structures.
     """
 
+    @flashinfer_api
     def __init__(
         self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
     ) -> None:
@@ -914,6 +926,7 @@ def reset_workspace_buffer(
             float_workspace_buffer, int_workspace_buffer
         )
 
+    @flashinfer_api
     def begin_forward(
         self,
         qo_indptr: torch.Tensor,
@@ -969,6 +982,7 @@ def begin_forward(
             page_size,
         )
 
+    @flashinfer_api
     def forward(
         self,
         q: torch.Tensor,
@@ -1060,6 +1074,7 @@ def forward(
         merge_state_in_place(V_shared, S_shared, V_unique, S_unique)
         return V_shared
 
+    @flashinfer_api
     def end_forward(self) -> None:
         r"""Warning: this function is deprecated and has no effect"""
         pass
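`merge_state` is one of the cascade primitives labeled above. A hedged example of merging partial attention results from two KV segments; the shape and dtype conventions here (V as `[seq_len, num_heads, head_dim]` in fp16, S as the per-head log-sum-exp in fp32) are assumptions, so consult the function docstring before relying on them:

```python
import torch
import flashinfer

seq_len, num_heads, head_dim = 32, 8, 128
v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_a = torch.rand(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_b = torch.rand(seq_len, num_heads, dtype=torch.float32, device="cuda")

# Merge attention outputs and log-sum-exp values from the two segments.
v_merged, s_merged = flashinfer.merge_state(v_a, s_a, v_b, s_b)
```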

flashinfer/decode.py

Lines changed: 3 additions & 0 deletions
@@ -1501,6 +1501,7 @@ class BatchDecodeMlaWithPagedKVCacheWrapper:
     a more efficient and general MLA implementation that supports decode and incremental prefill.
     """
 
+    @flashinfer_api
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
@@ -1615,6 +1616,7 @@ def reset_workspace_buffer(
             pin_memory=True,
         )
 
+    @flashinfer_api
     def plan(
         self,
         indptr: torch.Tensor,
@@ -1740,6 +1742,7 @@ def plan(
         self._rope_scale = rope_scale
         self._rope_theta = rope_theta
 
+    @flashinfer_api
     def run(
         self,
         q_nope: torch.Tensor,

flashinfer/fp4_quantization.py

Lines changed: 12 additions & 0 deletions
@@ -21,6 +21,7 @@
 
 import torch
 
+from .api_logging import flashinfer_api
 from .jit import JitSpec
 from .jit import env as jit_env
 from .jit import (
@@ -624,6 +625,7 @@ def _fake_e2m1_and_ufp8sf_scale_to_float_sm100(
     )
 
 
+@flashinfer_api
 def fp4_quantize(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
@@ -689,6 +691,7 @@ def fp4_quantize(
     return x_q, sf
 
 
+@flashinfer_api
 def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
     """Swizzle block scale tensor for FP4 format.
 
@@ -721,6 +724,7 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
 nvfp4_block_scale_interleave = block_scale_interleave
 
 
+@flashinfer_api
 def e2m1_and_ufp8sf_scale_to_float(
     e2m1_tensor: torch.Tensor,
     ufp8_scale_tensor: torch.Tensor,
@@ -763,6 +767,7 @@ def e2m1_and_ufp8sf_scale_to_float(
     )
 
 
+@flashinfer_api
 def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
     """
     PyTorch equivalent of trtllm-gen `shuffleMatrixA`
@@ -772,6 +777,7 @@ def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.
     return input_tensor[row_indices.to(input_tensor.device)]
 
 
+@flashinfer_api
 def shuffle_matrix_sf_a(
     input_tensor: torch.Tensor,
     epilogue_tile_m: int,
@@ -806,6 +812,7 @@ class SfLayout(Enum):
     layout_linear = 2
 
 
+@flashinfer_api
 def nvfp4_quantize(
     a,
     a_global_sf,
@@ -866,6 +873,7 @@ def nvfp4_quantize(
     return a_fp4, a_sf
 
 
+@flashinfer_api
 def mxfp4_quantize(a):
     """
     Quantize input tensor to MXFP4 format.
@@ -883,6 +891,7 @@ def mxfp4_quantize(a):
     return a_fp4, a_sf
 
 
+@flashinfer_api
 def mxfp4_dequantize(a_fp4, a_sf):
     """
     Dequantize input tensor from MXFP4 format.
@@ -904,6 +913,7 @@ def mxfp4_dequantize(a_fp4, a_sf):
     )
 
 
+@flashinfer_api
 def mxfp4_dequantize_host(
     weight: torch.Tensor,
     scale: torch.Tensor,
@@ -932,6 +942,7 @@ def mxfp4_dequantize_host(
     )
 
 
+@flashinfer_api
 def nvfp4_batched_quantize(
     a,
     a_global_sf,
@@ -961,6 +972,7 @@ def nvfp4_batched_quantize(
     return a_fp4, a_sf
 
 
+@flashinfer_api
 def scaled_fp4_grouped_quantize(
     a,
     mask,
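The MXFP4 helpers decorated above form a quantize/dequantize pair. A hedged round-trip sketch, assuming these functions are reachable from the top-level package as the release notes suggest (otherwise import them from `flashinfer.fp4_quantization`) and that a 2D bfloat16 input is acceptable:

```python
import torch
import flashinfer

w = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Quantize to packed MXFP4 values plus block scale factors, then reconstruct.
w_fp4, w_sf = flashinfer.mxfp4_quantize(w)
w_approx = flashinfer.mxfp4_dequantize(w_fp4, w_sf)  # lossy reconstruction
```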

flashinfer/fp8_quantization.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 import torch
 
+from .api_logging import flashinfer_api
 from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
 from .utils import (
     device_support_pdl,
@@ -142,6 +143,7 @@ def _fake_mxfp8_dequantize_host_sm100(
     )
 
 
+@flashinfer_api
 def mxfp8_quantize(
     input: torch.Tensor,
     is_sf_swizzled_layout: bool = True,
@@ -178,6 +180,7 @@ def mxfp8_quantize(
     return x_q, sf
 
 
+@flashinfer_api
 def mxfp8_dequantize_host(
     input: torch.Tensor,
     scale_tensor: torch.Tensor,
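Similarly, `mxfp8_quantize` returns the quantized tensor together with its scale factors. A minimal hedged sketch, with the input shape and the top-level export assumed (otherwise use `flashinfer.fp8_quantization.mxfp8_quantize`):

```python
import torch
import flashinfer

x = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")

# Quantize to MXFP8 with the swizzled scale-factor layout (the default shown above).
x_q, sf = flashinfer.mxfp8_quantize(x, is_sf_swizzled_layout=True)
```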

flashinfer/gemm/routergemm_dsv3.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from ..api_logging import flashinfer_api
 from flashinfer.jit import gen_dsv3_router_gemm_module
 import functools
 from types import SimpleNamespace
@@ -85,6 +86,7 @@ def mm_M1_16_K7168_N256(
     )
 
 
+@flashinfer_api
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
