1515
1616# Global cudnn handle. need to make it per device in future
1717_cudnn_handle = None
18+ _dummy_scale_tensor = None
19+
20+
def _get_dummy_scale_tensor(device: torch.device):
    """Return a cached one-element float32 tensor holding 1.0 on ``device``.

    The tensor serves as a neutral (identity) scale placeholder. It is kept
    in the module-level ``_dummy_scale_tensor`` so that repeated calls reuse
    the same tensor instead of allocating a new one each time.

    Args:
        device: Device on which the placeholder tensor must live.

    Returns:
        A ``torch.float32`` tensor of shape ``(1,)`` containing ``1.0``.
        Callers must treat it as read-only because it is shared across calls.
    """
    global _dummy_scale_tensor

    cached = _dummy_scale_tensor
    # Rebuild only on first use or when a different device is requested; a
    # tensor cached for another device must not be handed out.
    if cached is None or cached.device != torch.device(device):
        cached = torch.tensor([1.0], device=device, dtype=torch.float32)
        _dummy_scale_tensor = cached
    return cached
1826
1927
2028def _create_cudnn_handle (stream : torch .cuda .Stream ):
2129 global _cudnn_handle
30+
2231 if _cudnn_handle is None :
2332 _cudnn_handle = cudnn .create_handle ()
2433 cudnn .set_stream (_cudnn_handle , stream .cuda_stream )
@@ -49,6 +58,16 @@ class UIDs(Enum):
4958 O_UID = 1000 # Output tensor
5059 STATS_UID = 1001 # Stats tensor
5160
61+ Q_SCALE_UID = 150 # Query scale tensor
62+ K_SCALE_UID = 151 # Key scale tensor
63+ V_SCALE_UID = 152 # Value scale tensor
64+ S_SCALE_UID = 153 # Scale tensor
65+ S_DESCALE_UID = 154 # Descale tensor
66+ O_SCALE_UID = 155 # Output scale tensor
67+
68+ S_AMAX_UID = 160 # Scale amax tensor
69+ O_AMAX_UID = 161 # Output amax tensor
70+
5271
5372def _sdpa_prefill_key_fn (
5473 q : torch .Tensor ,
@@ -136,6 +155,13 @@ def _build_prefill_graph(
136155 graph_s_qo = max_token_seq_q
137156 graph_s_kv = max_sequence_kv
138157
158+ if not cudnn .datatypes .is_torch_available ():
159+ raise RuntimeError ("torch is not available" )
160+
161+ cudnn_q_data_type = cudnn .datatypes ._torch_to_cudnn_data_type (q .dtype )
162+ cudnn_k_data_type = cudnn .datatypes ._torch_to_cudnn_data_type (k_cache .dtype )
163+ cudnn_v_data_type = cudnn .datatypes ._torch_to_cudnn_data_type (v_cache .dtype )
164+
139165 with cudnn .graph (handle ) as (g , _ ):
140166 # Create tensors from the input tensors
141167 if q .dim () == 3 :
@@ -149,9 +175,62 @@ def _build_prefill_graph(
149175 name = "q" ,
150176 dim = (graph_b , h_qo , graph_s_qo , d_qk ),
151177 stride = (h_qo * d_qk , d_qk , d_qk * h_qo , 1 ),
152- data_type = cudnn . data_type . BFLOAT16 ,
178+ data_type = cudnn_q_data_type ,
153179 )
154180
181+ if (
182+ cudnn_q_data_type == cudnn .data_type .FP8_E4M3
183+ or cudnn_q_data_type == cudnn .data_type .FP8_E5M2
184+ ):
185+ cudnn_q_scale = g .tensor (
186+ name = "q_scale" ,
187+ dim = (1 , 1 , 1 , 1 ),
188+ stride = (1 , 1 , 1 , 1 ),
189+ data_type = cudnn .data_type .FLOAT ,
190+ )
191+
192+ cudnn_k_scale = g .tensor (
193+ name = "k_scale" ,
194+ dim = (1 , 1 , 1 , 1 ),
195+ stride = (1 , 1 , 1 , 1 ),
196+ data_type = cudnn .data_type .FLOAT ,
197+ )
198+
199+ cudnn_v_scale = g .tensor (
200+ name = "v_scale" ,
201+ dim = (1 , 1 , 1 , 1 ),
202+ stride = (1 , 1 , 1 , 1 ),
203+ data_type = cudnn .data_type .FLOAT ,
204+ )
205+
206+ cudnn_s_scale = g .tensor (
207+ name = "s_scale" ,
208+ dim = (1 , 1 , 1 , 1 ),
209+ stride = (1 , 1 , 1 , 1 ),
210+ data_type = cudnn .data_type .FLOAT ,
211+ )
212+
213+ cudnn_s_descale = g .tensor (
214+ name = "s_descale" ,
215+ dim = (1 , 1 , 1 , 1 ),
216+ stride = (1 , 1 , 1 , 1 ),
217+ data_type = cudnn .data_type .FLOAT ,
218+ )
219+
220+ cudnn_o_scale = g .tensor (
221+ name = "o_scale" ,
222+ dim = (1 , 1 , 1 , 1 ),
223+ stride = (1 , 1 , 1 , 1 ),
224+ data_type = cudnn .data_type .FLOAT ,
225+ )
226+
227+ cudnn_q_scale .set_uid (UIDs .Q_SCALE_UID .value )
228+ cudnn_k_scale .set_uid (UIDs .K_SCALE_UID .value )
229+ cudnn_v_scale .set_uid (UIDs .V_SCALE_UID .value )
230+ cudnn_s_scale .set_uid (UIDs .S_SCALE_UID .value )
231+ cudnn_s_descale .set_uid (UIDs .S_DESCALE_UID .value )
232+ cudnn_o_scale .set_uid (UIDs .O_SCALE_UID .value )
233+
155234 if batch_offsets_q is not None :
156235 ragged_q = g .tensor_like (batch_offsets_q )
157236 ragged_q .set_uid (UIDs .RAGGED_Q_UID .value )
@@ -175,7 +254,7 @@ def _build_prefill_graph(
175254 name = "k_cache" ,
176255 dim = (graph_b , h_kv , graph_s_kv , d_qk ),
177256 stride = (h_kv * d_qk * graph_s_kv , d_qk , d_qk * h_kv , 1 ),
178- data_type = cudnn . data_type . BFLOAT16 ,
257+ data_type = cudnn_k_data_type ,
179258 )
180259
181260 if batch_offsets_k is not None :
@@ -187,7 +266,7 @@ def _build_prefill_graph(
187266 name = "v_cache" ,
188267 dim = (graph_b , h_kv , graph_s_kv , d_vo ),
189268 stride = (h_kv * d_vo * graph_s_kv , d_vo , d_vo * h_kv , 1 ),
190- data_type = cudnn . data_type . BFLOAT16 ,
269+ data_type = cudnn_v_data_type ,
191270 )
192271
193272 if batch_offsets_v is not None :
@@ -200,14 +279,14 @@ def _build_prefill_graph(
200279 name = "k_cache" ,
201280 dim = k_cache .shape ,
202281 stride = k_cache .stride (),
203- data_type = cudnn . data_type . BFLOAT16 ,
282+ data_type = cudnn_k_data_type ,
204283 )
205284
206285 cudnn_v_cache = g .tensor (
207286 name = "v_cache" ,
208287 dim = v_cache .shape ,
209288 stride = v_cache .stride (),
210- data_type = cudnn . data_type . BFLOAT16 ,
289+ data_type = cudnn_v_data_type ,
211290 )
212291
213292 cudnn_q .set_uid (UIDs .Q_UID .value )
@@ -238,32 +317,83 @@ def _build_prefill_graph(
238317 actual_seq_lens_q is not None and actual_seq_lens_kv is not None
239318 )
240319
241- O , Stats = g .sdpa (
242- name = "sdpa" ,
243- q = cudnn_q ,
244- k = cudnn_k_cache ,
245- v = cudnn_v_cache ,
246- seq_len_q = (
247- cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
248- ),
249- seq_len_kv = (
250- cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
251- ),
252- use_padding_mask = padding_mask ,
253- attn_scale = scale ,
254- generate_stats = return_lse ,
255- use_causal_mask_bottom_right = bottom_right_causal_mask ,
256- paged_attention_k_table = (
257- cudnn_k_block_tables if block_tables is not None else None
258- ),
259- paged_attention_v_table = (
260- cudnn_v_block_tables if block_tables is not None else None
261- ),
262- paged_attention_max_seq_len_kv = (
263- graph_s_kv if block_tables is not None else None
264- ),
265- compute_data_type = cudnn .data_type .FLOAT ,
266- )
320+ if cudnn_q_data_type == cudnn .data_type .BFLOAT16 :
321+ O , Stats = g .sdpa (
322+ name = "sdpa" ,
323+ q = cudnn_q ,
324+ k = cudnn_k_cache ,
325+ v = cudnn_v_cache ,
326+ seq_len_q = (
327+ cudnn_actual_seq_lens_q
328+ if actual_seq_lens_q is not None
329+ else None
330+ ),
331+ seq_len_kv = (
332+ cudnn_actual_seq_lens_kv
333+ if actual_seq_lens_kv is not None
334+ else None
335+ ),
336+ use_padding_mask = padding_mask ,
337+ attn_scale = scale ,
338+ generate_stats = return_lse ,
339+ use_causal_mask_bottom_right = bottom_right_causal_mask ,
340+ paged_attention_k_table = (
341+ cudnn_k_block_tables if block_tables is not None else None
342+ ),
343+ paged_attention_v_table = (
344+ cudnn_v_block_tables if block_tables is not None else None
345+ ),
346+ paged_attention_max_seq_len_kv = (
347+ graph_s_kv if block_tables is not None else None
348+ ),
349+ compute_data_type = cudnn .data_type .FLOAT ,
350+ )
351+
352+ elif (
353+ cudnn_q_data_type == cudnn .data_type .FP8_E4M3
354+ or cudnn_q_data_type == cudnn .data_type .FP8_E5M2
355+ ):
356+ O , Stats , amax_s , amax_o = g .sdpa_fp8 (
357+ q = cudnn_q ,
358+ k = cudnn_k_cache ,
359+ v = cudnn_v_cache ,
360+ descale_q = cudnn_q_scale ,
361+ descale_k = cudnn_k_scale ,
362+ descale_v = cudnn_v_scale ,
363+ scale_s = cudnn_s_scale ,
364+ descale_s = cudnn_s_descale ,
365+ scale_o = cudnn_o_scale ,
366+ generate_stats = True ,
367+ attn_scale = scale ,
368+ use_causal_mask_bottom_right = bottom_right_causal_mask ,
369+ use_padding_mask = padding_mask ,
370+ seq_len_q = (
371+ cudnn_actual_seq_lens_q
372+ if actual_seq_lens_q is not None
373+ else None
374+ ),
375+ seq_len_kv = (
376+ cudnn_actual_seq_lens_kv
377+ if actual_seq_lens_kv is not None
378+ else None
379+ ),
380+ paged_attention_k_table = (
381+ cudnn_k_block_tables if block_tables is not None else None
382+ ),
383+ paged_attention_v_table = (
384+ cudnn_v_block_tables if block_tables is not None else None
385+ ),
386+ paged_attention_max_seq_len_kv = (
387+ graph_s_kv if block_tables is not None else None
388+ ),
389+ )
390+
391+ amax_s .set_uid (UIDs .S_AMAX_UID .value ).set_output (False ).set_dim (
392+ (1 , 1 , 1 , 1 )
393+ ).set_stride ((1 , 1 , 1 , 1 )).set_data_type (cudnn .data_type .FLOAT )
394+ amax_o .set_uid (UIDs .O_AMAX_UID .value ).set_output (False ).set_dim (
395+ (1 , 1 , 1 , 1 )
396+ ).set_stride ((1 , 1 , 1 , 1 )).set_data_type (cudnn .data_type .FLOAT )
267397
268398 if batch_offsets_o is not None :
269399 ragged_o = g .tensor_like (batch_offsets_o )
@@ -279,7 +409,7 @@ def _build_prefill_graph(
279409 [graph_b , h_qo , graph_s_qo , d_vo ]
280410 ).set_stride (
281411 [graph_s_qo * d_vo * h_qo , d_vo , d_vo * h_qo , 1 ]
282- ).set_data_type (cudnn . data_type . BFLOAT16 )
412+ ).set_data_type (cudnn_q_data_type )
283413
284414 if return_lse :
285415 Stats .set_uid (UIDs .STATS_UID .value ).set_output (
@@ -314,6 +444,9 @@ def _batch_prefill_with_kv_cache(
314444 block_tables : Optional [torch .Tensor ] = None ,
315445 causal : bool ,
316446 return_lse : bool ,
447+ q_scale : Optional [torch .Tensor ] = None ,
448+ k_scale : Optional [torch .Tensor ] = None ,
449+ v_scale : Optional [torch .Tensor ] = None ,
317450 batch_offsets_q : Optional [torch .Tensor ] = None ,
318451 batch_offsets_o : Optional [torch .Tensor ] = None ,
319452 batch_offsets_k : Optional [torch .Tensor ] = None ,
@@ -374,6 +507,17 @@ def _batch_prefill_with_kv_cache(
374507 if batch_offsets_stats is not None :
375508 var_map [UIDs .RAGGED_STATS_UID .value ] = batch_offsets_stats
376509
510+ if q_scale is not None :
511+ dummy_scale_tensor = _get_dummy_scale_tensor (q .device )
512+ var_map [UIDs .Q_SCALE_UID .value ] = q_scale
513+ var_map [UIDs .S_SCALE_UID .value ] = dummy_scale_tensor
514+ var_map [UIDs .S_DESCALE_UID .value ] = dummy_scale_tensor
515+ var_map [UIDs .O_SCALE_UID .value ] = dummy_scale_tensor
516+ if k_scale is not None :
517+ var_map [UIDs .K_SCALE_UID .value ] = k_scale
518+ if v_scale is not None :
519+ var_map [UIDs .V_SCALE_UID .value ] = v_scale
520+
377521 handle = _create_cudnn_handle (torch .cuda .current_stream (q .device ))
378522 graph .execute (var_map , workspace = workspace_buffer , handle = handle )
379523
@@ -397,6 +541,9 @@ def cudnn_batch_prefill_with_kv_cache(
397541 block_tables : Optional [torch .Tensor ] = None ,
398542 causal : bool ,
399543 return_lse : bool ,
544+ q_scale : Optional [torch .Tensor ] = None ,
545+ k_scale : Optional [torch .Tensor ] = None ,
546+ v_scale : Optional [torch .Tensor ] = None ,
400547 batch_offsets_q : Optional [torch .Tensor ] = None ,
401548 batch_offsets_o : Optional [torch .Tensor ] = None ,
402549 batch_offsets_k : Optional [torch .Tensor ] = None ,
@@ -425,6 +572,9 @@ def cudnn_batch_prefill_with_kv_cache(
425572 out: Optional pre-allocated output tensor
426573 lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
427574 is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph
575+ q_scale: Optional scale tensor for query tensor of shape (1, 1, 1, 1) on GPU
576+ k_scale: Optional scale tensor for key tensor of shape (1, 1, 1, 1) on GPU
577+ v_scale: Optional scale tensor for value tensor of shape (1, 1, 1, 1) on GPU
428578 batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
429579 batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
430580 batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -488,6 +638,9 @@ def cudnn_batch_prefill_with_kv_cache(
488638 block_tables = block_tables ,
489639 causal = causal ,
490640 return_lse = return_lse ,
641+ q_scale = q_scale ,
642+ k_scale = k_scale ,
643+ v_scale = v_scale ,
491644 batch_offsets_q = batch_offsets_q ,
492645 batch_offsets_o = batch_offsets_o ,
493646 batch_offsets_k = batch_offsets_k ,
0 commit comments