Commit afbd479

Added an initial implementation of Q and KV cache in fp8, using the cuDNN implementation
1 parent 2d68a6b commit afbd479

File tree

3 files changed: +399 −41 lines


flashinfer/cudnn/prefill.py

Lines changed: 185 additions & 32 deletions
@@ -15,10 +15,19 @@
 
 # Global cudnn handle. need to make it per device in future
 _cudnn_handle = None
+_dummy_scale_tensor = None
+
+
+def _get_dummy_scale_tensor(device: torch.device):
+    global _dummy_scale_tensor
+
+    _dummy_scale_tensor = torch.tensor([1.0], device=device, dtype=torch.float32)
+    return _dummy_scale_tensor
 
 
 def _create_cudnn_handle(stream: torch.cuda.Stream):
     global _cudnn_handle
+
     if _cudnn_handle is None:
         _cudnn_handle = cudnn.create_handle()
     cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
@@ -49,6 +58,16 @@ class UIDs(Enum):
     O_UID = 1000  # Output tensor
     STATS_UID = 1001  # Stats tensor
 
+    Q_SCALE_UID = 150  # Query scale tensor
+    K_SCALE_UID = 151  # Key scale tensor
+    V_SCALE_UID = 152  # Value scale tensor
+    S_SCALE_UID = 153  # Scale tensor
+    S_DESCALE_UID = 154  # Descale tensor
+    O_SCALE_UID = 155  # Output scale tensor
+
+    S_AMAX_UID = 160  # Scale amax tensor
+    O_AMAX_UID = 161  # Output amax tensor
+
 
 def _sdpa_prefill_key_fn(
     q: torch.Tensor,
@@ -136,6 +155,13 @@ def _build_prefill_graph(
     graph_s_qo = max_token_seq_q
     graph_s_kv = max_sequence_kv
 
+    if not cudnn.datatypes.is_torch_available():
+        raise RuntimeError("torch is not available")
+
+    cudnn_q_data_type = cudnn.datatypes._torch_to_cudnn_data_type(q.dtype)
+    cudnn_k_data_type = cudnn.datatypes._torch_to_cudnn_data_type(k_cache.dtype)
+    cudnn_v_data_type = cudnn.datatypes._torch_to_cudnn_data_type(v_cache.dtype)
+
     with cudnn.graph(handle) as (g, _):
         # Create tensors from the input tensors
         if q.dim() == 3:
@@ -149,9 +175,62 @@
             name="q",
             dim=(graph_b, h_qo, graph_s_qo, d_qk),
             stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
-            data_type=cudnn.data_type.BFLOAT16,
+            data_type=cudnn_q_data_type,
         )
 
+        if (
+            cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+            or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+        ):
+            cudnn_q_scale = g.tensor(
+                name="q_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_k_scale = g.tensor(
+                name="k_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_v_scale = g.tensor(
+                name="v_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_s_scale = g.tensor(
+                name="s_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_s_descale = g.tensor(
+                name="s_descale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_o_scale = g.tensor(
+                name="o_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_q_scale.set_uid(UIDs.Q_SCALE_UID.value)
+            cudnn_k_scale.set_uid(UIDs.K_SCALE_UID.value)
+            cudnn_v_scale.set_uid(UIDs.V_SCALE_UID.value)
+            cudnn_s_scale.set_uid(UIDs.S_SCALE_UID.value)
+            cudnn_s_descale.set_uid(UIDs.S_DESCALE_UID.value)
+            cudnn_o_scale.set_uid(UIDs.O_SCALE_UID.value)
+
         if batch_offsets_q is not None:
             ragged_q = g.tensor_like(batch_offsets_q)
             ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
@@ -175,7 +254,7 @@
             name="k_cache",
             dim=(graph_b, h_kv, graph_s_kv, d_qk),
             stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1),
-            data_type=cudnn.data_type.BFLOAT16,
+            data_type=cudnn_k_data_type,
         )
 
         if batch_offsets_k is not None:
@@ -187,7 +266,7 @@
             name="v_cache",
             dim=(graph_b, h_kv, graph_s_kv, d_vo),
             stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1),
-            data_type=cudnn.data_type.BFLOAT16,
+            data_type=cudnn_v_data_type,
         )
 
         if batch_offsets_v is not None:
@@ -200,14 +279,14 @@
             name="k_cache",
             dim=k_cache.shape,
             stride=k_cache.stride(),
-            data_type=cudnn.data_type.BFLOAT16,
+            data_type=cudnn_k_data_type,
         )
 
         cudnn_v_cache = g.tensor(
             name="v_cache",
             dim=v_cache.shape,
             stride=v_cache.stride(),
-            data_type=cudnn.data_type.BFLOAT16,
+            data_type=cudnn_v_data_type,
         )
 
         cudnn_q.set_uid(UIDs.Q_UID.value)
@@ -238,32 +317,83 @@
             actual_seq_lens_q is not None and actual_seq_lens_kv is not None
         )
 
-        O, Stats = g.sdpa(
-            name="sdpa",
-            q=cudnn_q,
-            k=cudnn_k_cache,
-            v=cudnn_v_cache,
-            seq_len_q=(
-                cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
-            ),
-            seq_len_kv=(
-                cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
-            ),
-            use_padding_mask=padding_mask,
-            attn_scale=scale,
-            generate_stats=return_lse,
-            use_causal_mask_bottom_right=bottom_right_causal_mask,
-            paged_attention_k_table=(
-                cudnn_k_block_tables if block_tables is not None else None
-            ),
-            paged_attention_v_table=(
-                cudnn_v_block_tables if block_tables is not None else None
-            ),
-            paged_attention_max_seq_len_kv=(
-                graph_s_kv if block_tables is not None else None
-            ),
-            compute_data_type=cudnn.data_type.FLOAT,
-        )
+        if cudnn_q_data_type == cudnn.data_type.BFLOAT16:
+            O, Stats = g.sdpa(
+                name="sdpa",
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q
+                    if actual_seq_lens_q is not None
+                    else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv
+                    if actual_seq_lens_kv is not None
+                    else None
+                ),
+                use_padding_mask=padding_mask,
+                attn_scale=scale,
+                generate_stats=return_lse,
+                use_causal_mask_bottom_right=bottom_right_causal_mask,
+                paged_attention_k_table=(
+                    cudnn_k_block_tables if block_tables is not None else None
+                ),
+                paged_attention_v_table=(
+                    cudnn_v_block_tables if block_tables is not None else None
+                ),
+                paged_attention_max_seq_len_kv=(
+                    graph_s_kv if block_tables is not None else None
+                ),
+                compute_data_type=cudnn.data_type.FLOAT,
+            )
+
+        elif (
+            cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+            or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+        ):
+            O, Stats, amax_s, amax_o = g.sdpa_fp8(
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                descale_q=cudnn_q_scale,
+                descale_k=cudnn_k_scale,
+                descale_v=cudnn_v_scale,
+                scale_s=cudnn_s_scale,
+                descale_s=cudnn_s_descale,
+                scale_o=cudnn_o_scale,
+                generate_stats=True,
+                attn_scale=scale,
+                use_causal_mask_bottom_right=bottom_right_causal_mask,
+                use_padding_mask=padding_mask,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q
+                    if actual_seq_lens_q is not None
+                    else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv
+                    if actual_seq_lens_kv is not None
+                    else None
+                ),
+                paged_attention_k_table=(
+                    cudnn_k_block_tables if block_tables is not None else None
+                ),
+                paged_attention_v_table=(
+                    cudnn_v_block_tables if block_tables is not None else None
+                ),
+                paged_attention_max_seq_len_kv=(
+                    graph_s_kv if block_tables is not None else None
+                ),
+            )
+
+            amax_s.set_uid(UIDs.S_AMAX_UID.value).set_output(False).set_dim(
+                (1, 1, 1, 1)
+            ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
+            amax_o.set_uid(UIDs.O_AMAX_UID.value).set_output(False).set_dim(
+                (1, 1, 1, 1)
+            ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
 
         if batch_offsets_o is not None:
             ragged_o = g.tensor_like(batch_offsets_o)
@@ -279,7 +409,7 @@ def _build_prefill_graph(
             [graph_b, h_qo, graph_s_qo, d_vo]
         ).set_stride(
             [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1]
-        ).set_data_type(cudnn.data_type.BFLOAT16)
+        ).set_data_type(cudnn_q_data_type)
 
         if return_lse:
             Stats.set_uid(UIDs.STATS_UID.value).set_output(
@@ -314,6 +444,9 @@ def _batch_prefill_with_kv_cache(
     block_tables: Optional[torch.Tensor] = None,
     causal: bool,
     return_lse: bool,
+    q_scale: Optional[torch.Tensor] = None,
+    k_scale: Optional[torch.Tensor] = None,
+    v_scale: Optional[torch.Tensor] = None,
     batch_offsets_q: Optional[torch.Tensor] = None,
     batch_offsets_o: Optional[torch.Tensor] = None,
     batch_offsets_k: Optional[torch.Tensor] = None,
@@ -374,6 +507,17 @@ def _batch_prefill_with_kv_cache(
     if batch_offsets_stats is not None:
         var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats
 
+    if q_scale is not None:
+        dummy_scale_tensor = _get_dummy_scale_tensor(q.device)
+        var_map[UIDs.Q_SCALE_UID.value] = q_scale
+        var_map[UIDs.S_SCALE_UID.value] = dummy_scale_tensor
+        var_map[UIDs.S_DESCALE_UID.value] = dummy_scale_tensor
+        var_map[UIDs.O_SCALE_UID.value] = dummy_scale_tensor
+    if k_scale is not None:
+        var_map[UIDs.K_SCALE_UID.value] = k_scale
+    if v_scale is not None:
+        var_map[UIDs.V_SCALE_UID.value] = v_scale
+
     handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
     graph.execute(var_map, workspace=workspace_buffer, handle=handle)
379523

@@ -397,6 +541,9 @@ def cudnn_batch_prefill_with_kv_cache(
     block_tables: Optional[torch.Tensor] = None,
     causal: bool,
     return_lse: bool,
+    q_scale: Optional[torch.Tensor] = None,
+    k_scale: Optional[torch.Tensor] = None,
+    v_scale: Optional[torch.Tensor] = None,
     batch_offsets_q: Optional[torch.Tensor] = None,
     batch_offsets_o: Optional[torch.Tensor] = None,
     batch_offsets_k: Optional[torch.Tensor] = None,
@@ -425,6 +572,9 @@ def cudnn_batch_prefill_with_kv_cache(
         out: Optional pre-allocated output tensor
         lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
        is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph
+        q_scale: Optional scale tensor for query tensor of shape (1, 1, 1, 1) on GPU
+        k_scale: Optional scale tensor for key tensor of shape (1, 1, 1, 1) on GPU
+        v_scale: Optional scale tensor for value tensor of shape (1, 1, 1, 1) on GPU
         batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
         batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
         batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -488,6 +638,9 @@
         block_tables=block_tables,
         causal=causal,
         return_lse=return_lse,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
         batch_offsets_q=batch_offsets_q,
         batch_offsets_o=batch_offsets_o,
         batch_offsets_k=batch_offsets_k,
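
For illustration, a minimal sketch of how the fp8 inputs and the (1, 1, 1, 1) float32 scale tensors described in the docstring above might be produced. This is not part of the commit, and names such as quantize_fp8 and q_bf16 are hypothetical. Note that the graph binds q_scale/k_scale/v_scale to the descale_q/descale_k/descale_v inputs of sdpa_fp8, so the values passed are dequantization (descale) factors:

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite magnitude of torch.float8_e4m3fn

    def quantize_fp8(x: torch.Tensor):
        # Per-tensor quantization: scale the input to fill the fp8 range and
        # return the inverse as the (1, 1, 1, 1) descale tensor that
        # _batch_prefill_with_kv_cache binds to the Q/K/V_SCALE_UID slots.
        amax = x.abs().amax().float()
        scale = FP8_E4M3_MAX / amax
        x_fp8 = (x * scale).to(torch.float8_e4m3fn)
        descale = (1.0 / scale).reshape(1, 1, 1, 1)
        return x_fp8, descale

    q_fp8, q_scale = quantize_fp8(q_bf16)        # q_bf16: bf16 query, assumed given
    k_fp8, k_scale = quantize_fp8(k_cache_bf16)  # likewise for the KV cache tensors
    v_fp8, v_scale = quantize_fp8(v_cache_bf16)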

flashinfer/prefill.py

Lines changed: 15 additions & 9 deletions
@@ -1980,9 +1980,9 @@ def run(
         q: torch.Tensor,
         paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         *args,
-        q_scale: Optional[float] = None,
-        k_scale: Optional[float] = None,
-        v_scale: Optional[float] = None,
+        q_scale: Optional[Union[float, torch.Tensor]] = None,
+        k_scale: Optional[Union[float, torch.Tensor]] = None,
+        v_scale: Optional[Union[float, torch.Tensor]] = None,
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
@@ -2012,9 +2012,11 @@
 
         *args
             Additional arguments for custom kernels.
-        k_scale : Optional[float]
+        q_scale : Optional[Union[float, torch.Tensor]]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
+        k_scale : Optional[Union[float, torch.Tensor]]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
-        v_scale : Optional[float]
+        v_scale : Optional[Union[float, torch.Tensor]]
             The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
         out : Optional[torch.Tensor]
             The output tensor, if not provided, will be allocated internally.
@@ -2061,10 +2063,11 @@ def run(
             logits_soft_cap = 0.0
         if sm_scale is None:
             sm_scale = 1.0 / math.sqrt(q.size(-1))
-        if q_scale is not None:
-            sm_scale *= q_scale
-        if k_scale is not None:
-            sm_scale *= k_scale
+        if self._backend != "cudnn":
+            if q_scale is not None:
+                sm_scale *= q_scale
+            if k_scale is not None:
+                sm_scale *= k_scale
         if rope_scale is None:
             rope_scale = 1.0
         if rope_theta is None:
@@ -2143,6 +2146,9 @@
             block_tables=self._block_tables,
             causal=self._causal,
             return_lse=return_lse,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
             batch_offsets_q=self._qo_indptr_buf,
             batch_offsets_o=self._qo_indptr_buf,
             out=out,
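
The guarded change above means a float q_scale/k_scale is still folded into the softmax scale for non-cudnn backends, while the cudnn backend now forwards q_scale/k_scale/v_scale to the kernel as tensors. A hedged usage sketch, assuming a flashinfer BatchPrefillWithPagedKVCacheWrapper constructed with backend="cudnn" and already planned; the wrapper setup and fp8 inputs lie outside this diff:

    # wrapper: assumed to be a planned prefill wrapper with backend="cudnn";
    # q_fp8, the fp8 KV cache, and the (1, 1, 1, 1) scale tensors come from a
    # calibration step such as the sketch shown after the first file above.
    out, lse = wrapper.run(
        q_fp8,
        (k_cache_fp8, v_cache_fp8),  # paged KV cache as a (K, V) tuple
        q_scale=q_scale,             # float32, shape (1, 1, 1, 1), on GPU
        k_scale=k_scale,
        v_scale=v_scale,
        return_lse=True,
    )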
