1515from .. import exc
1616from .._compat import get_tensor_descriptor_fn_name
1717from .ast_extension import expr_from_string
18+ from .ast_extension import statement_from_string
1819from .compile_environment import CompileEnvironment
1920from .device_function import DeviceFunction
2021from .host_function import HostFunction
@@ -103,6 +104,12 @@ def _get_tile_with_offset_info(
103104
104105 return None
105106
def _supports_epilogue_subtiling() -> bool:
    """Report whether epilogue subtiling may be used in the current compile environment.

    Returns:
        True only when the environment's device is a CUDA device, the
        ``allow_epilogue_subtiling`` setting is enabled, and that device
        reports a compute capability of at least (10, 0). False otherwise.
    """
    env = CompileEnvironment.current()
    if env.device.type != "cuda" or not env.settings.allow_epilogue_subtiling:
        return False
    # Query the device being compiled for rather than whichever CUDA device is
    # current in this process; the two can differ on multi-GPU hosts.
    return torch.cuda.get_device_capability(env.device) >= (10, 0)
112+
106113
107114class IndexingStrategy :
108115 def codegen_load (
@@ -376,6 +383,7 @@ def codegen_store(
376383 subscript : list [object ],
377384 value : ast .AST ,
378385 extra_mask : ast .AST | None ,
386+ epilogue_subtile : int | None ,
379387 ) -> ast .AST :
380388 if not self .is_supported (state , fake_tensor , subscript , extra_mask ):
381389 return PointerIndexingStrategy ().codegen_store (
@@ -384,6 +392,10 @@ def codegen_store(
384392 assert extra_mask is None
385393 indexing = BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
386394
395+ config = DeviceFunction .current ().config
396+ if _supports_epilogue_subtiling and config .epilogue_subtiling :
397+ return self ._codegen_epilogue_subtile_store (state , fake_tensor , indexing , store_value )
398+
387399 # Apply permutation to the value being stored if needed
388400 desc_arg = indexing .tensor_descriptor_arg (state )
389401 store_value = indexing .reshape_store (state , value )
@@ -394,12 +406,102 @@ def codegen_store(
394406 f"tl.permute({{store_val}}, { desc_arg .permutation !r} )" ,
395407 store_val = store_value ,
396408 )
397-
409+
398410 return expr_from_string (
399411 f"{ indexing .tensor_descriptor (state )} .store({ indexing .offsets_str_permuted (state )} , {{value}})" ,
400412 value = store_value ,
401413 )
402414
def _codegen_epilogue_subtile_store(
    self,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    indexing: BlockedSubscriptIndexing,
    store_value: ast.AST,
) -> ast.AST | None:
    """Emit a tensor-descriptor store split into two half-width subtiles.

    The value is reshaped to ``[M, 2, N // 2]``, permuted to ``[M, N // 2, 2]``
    and split along the last axis, producing the left and right column halves
    of the tile. The first half is stored via an emitted statement; the store
    of the second half is returned as an expression.

    Args:
        state: Codegen state providing ``device_function`` and ``codegen``.
        fake_tensor: Destination tensor; not used by this method.
        indexing: Blocked indexing for the store. On success its
            ``block_shape[1]`` is left halved, since the descriptor is built
            for the half-width tile. On failure it is left unchanged.
        store_value: AST of the value to store.

    Returns:
        The AST expression for the second subtile store, or ``None`` when
        subtiling does not apply (tile is not 2D, the second block dimension
        has no size hint or an odd one, or the descriptor needs a permutation).
    """
    # Currently support 2D tiles without permutations
    if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
        return None

    env = CompileEnvironment.current()
    block_m, block_n = indexing.block_shape
    try:
        block_n_hint = env.size_hint(block_n)
    except Exception:
        # Best effort: without a usable size hint we cannot prove the tile
        # width is even, so decline to subtile.
        return None

    if block_n_hint % 2 != 0:
        return None

    device_fn = state.device_function
    codegen = state.codegen

    block_m_str = device_fn.literal_expr(block_m)
    block_n_str = device_fn.literal_expr(block_n)

    # The descriptor must be requested with the half-width block shape, so the
    # shape is halved before asking for it.
    indexing.block_shape[1] //= 2
    desc_arg = indexing.tensor_descriptor_arg(state)

    if desc_arg.permutation is not None:
        # Undo the in-place halving so the caller's indexing object is not
        # left describing a half-width tile after we decline to subtile.
        indexing.block_shape[1] = block_n
        return None

    block_n_half_str = f"({block_n_str} // 2)"

    # Lift the store value into a temporary variable for reuse
    acc_var = codegen.lift(store_value, prefix="acc")

    # [M, N] -> [M, 2, N // 2]: axis 1 selects the left/right column half.
    reshape_expr = expr_from_string(
        "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
        acc=acc_var,
        dim_m=expr_from_string(block_m_str),
        dim_half=expr_from_string(block_n_half_str),
    )
    reshape_var = codegen.lift(reshape_expr, prefix="acc")

    # Move the half-selector axis last, which is the axis tl.split divides.
    permute_expr = expr_from_string(
        "tl.permute({acc}, [0, 2, 1])",
        acc=reshape_var,
    )
    permute_var = codegen.lift(permute_expr, prefix="acc")

    acc0_name = codegen.tmpvar(prefix="acc")
    acc1_name = codegen.tmpvar(prefix="acc")
    codegen.add_statement(
        statement_from_string(
            f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
            acc=permute_var,
        )
    )
    acc0 = expr_from_string(acc0_name)
    acc1 = expr_from_string(acc1_name)

    desc_name = indexing.tensor_descriptor(state)

    # First subtile store
    codegen.add_statement(
        statement_from_string(
            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
            off0=expr_from_string(indexing.offsets[0]),
            off1=expr_from_string(indexing.offsets[1]),
            value=acc0,
        )
    )

    offset1_shifted = expr_from_string(
        "({offset} + {half})",
        offset=expr_from_string(indexing.offsets[1]),
        half=expr_from_string(block_n_half_str),
    )

    # Emit second subtile store as the expression returned to the caller.
    # The row offset node is built fresh so the two store trees do not share
    # an AST object.
    return expr_from_string(
        f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
        off0=expr_from_string(indexing.offsets[0]),
        off1=offset1_shifted,
        value=acc1,
    )
403505
404506class StackIndexingStrategy :
405507 """
0 commit comments