@@ -55,6 +55,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -710,7 +711,7 @@ def __init__(
             self._jit_module = get_batch_prefill_jit_module(
                 jit_args[0],
                 gen_customize_batch_prefill_module(
-                    "fa2", *jit_args
+                    backend, *jit_args
                 ).build_and_load(),
             )
         else:
@@ -822,6 +823,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -869,6 +871,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to ``torch.float16`` or ``torch.bfloat16``.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
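With this parameter, callers can keep queries and the KV cache in FP8 while receiving a higher-precision output. A minimal usage sketch, assuming the change lands in flashinfer's `BatchDecodeWithPagedKVCacheWrapper` (the wrapper name, page-table values, and all sizes below are illustrative, not part of this diff):

import torch
import flashinfer

# Illustrative paged-KV layout: 4 requests, 2 full pages each.
batch_size, page_size, head_dim = 4, 16, 128
num_qo_heads, num_kv_heads = 32, 8
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * 2
kv_indices = torch.arange(batch_size * 2, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.bfloat16,  # FP8 in, bf16 out, per the docstring above
)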
@@ -964,6 +969,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -973,6 +982,7 @@ def plan(

         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1012,7 +1022,7 @@ def plan(
             self._cached_module = get_trtllm_gen_decode_module(
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 indptr.dtype,
                 head_dim,
                 head_dim,
@@ -1027,11 +1037,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
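The hunk above resolves a new "auto" backend once at plan() time and caches the concrete choice in `self._backend`. A hedged sketch of the dispatch rule, assuming `determine_attention_backend` gates fa3 on Hopper-class (SM90+) devices and falls back to fa2 otherwise; the real helper also consults the positional-encoding mode, the use_fp16_qk_reduction/use_custom_mask flags, and the q/kv dtypes, as the call above shows:

import torch

def resolve_backend_sketch(backend: str, device: torch.device) -> str:
    # Hypothetical stand-in for determine_attention_backend.
    if backend != "auto":
        return backend
    major, _ = torch.cuda.get_device_capability(device)
    return "fa3" if major >= 9 else "fa2"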
@@ -1041,7 +1060,13 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )

-            self._plan_info = self._cached_module.plan(
+            if self._backend == "fa3" and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                num_heads = max(num_qo_heads, num_kv_heads)
+                self._dummy_scales = torch.ones(num_heads, device=self.device, dtype=torch.float32)
+            else:
+                self._dummy_scales = None
+
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1058,9 +1083,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1069,7 +1098,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1278,9 +1307,13 @@ def run(
             )

         if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
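The allocation above now takes the output dtype from the cached `o_data_type` and the trailing dimension from the value cache, so it stays correct even if head_dim_vo were to differ from head_dim_qk. A shape-only illustration (all sizes made up):

import torch

q = torch.empty(3, 32, 192)            # (num_tokens, num_qo_heads, head_dim_qk)
v_cache = torch.empty(10, 16, 8, 128)  # paged layout ending in head_dim_vo
out_shape = q.shape[:-1] + v_cache.shape[-1:]
assert out_shape == (3, 32, 128)       # head_dim_vo comes from v_cache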
@@ -1308,6 +1341,20 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
+            if fp8_scale_q is None:
+                fp8_scale_q = self._dummy_scales
+            if fp8_scale_k is None:
+                fp8_scale_k = self._dummy_scales
+            if fp8_scale_v is None:
+                fp8_scale_v = self._dummy_scales
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
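When the caller passes no calibrated scales, the `_dummy_scales` of ones created at plan() time stand in for them, making FP8 dequantization a no-op. An illustration of what a per-head scale tensor means (not the kernel's actual code):

import torch

num_heads, head_dim = 8, 128
scale_q = torch.ones(num_heads, dtype=torch.float32)  # dummy: identity scale
q_fp8 = torch.randn(1, num_heads, head_dim).to(torch.float8_e4m3fn)
# One dequantization factor per head, broadcast across the head dimension.
q_dequant = q_fp8.float() * scale_q.view(1, num_heads, 1)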
@@ -1317,9 +1364,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -2921,8 +2968,8 @@ def fast_decode_plan(
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

     try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for the fa2 backend and 16 for fa3
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -2939,9 +2986,13 @@ def fast_decode_plan(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in standard plan: {e}") from e
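Both plan() call sites now follow the same pattern: build the 16 arguments shared by all backends, then append the three fa2-only trailing ones. A factored sketch of that pattern (the diff inlines it at each call site; the helper name here is hypothetical):

def build_plan_args(common_args, backend, fixed_split_size, disable_split_kv):
    # 16 shared args for fa3; fa2 expects 3 extra trailing args (19 total).
    args = list(common_args)
    if backend == "fa2":
        args += [fixed_split_size, disable_split_kv, 0]  # 0 = num_colocated_ctas
    return args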