1515from .. import exc
1616from .._compat import get_tensor_descriptor_fn_name
1717from .ast_extension import expr_from_string
18+ from .ast_extension import statement_from_string
1819from .compile_environment import CompileEnvironment
1920from .device_function import DeviceFunction
2021from .host_function import HostFunction
@@ -103,6 +104,12 @@ def _get_tile_with_offset_info(
103104
104105 return None
105106
def _supports_epilogue_subtiling() -> bool:
    """Return whether epilogue subtiling may be used in the current compile.

    All of the following must hold:

    * the current ``CompileEnvironment`` targets a ``"cuda"`` device,
    * ``settings.allow_epilogue_subtiling`` is truthy, and
    * that device reports a compute capability of at least ``(10, 0)``.

    NOTE(review): this is a function and must be *called*; testing the bare
    name ``_supports_epilogue_subtiling`` in a condition is always truthy.
    """
    env = CompileEnvironment.current()
    if env.device.type != "cuda" or not env.settings.allow_epilogue_subtiling:
        return False
    # Query the device being compiled for rather than whichever CUDA device
    # happens to be current in this process; the two can differ on hosts with
    # more than one GPU.
    return torch.cuda.get_device_capability(env.device) >= (10, 0)
112+
106113
107114class IndexingStrategy :
108115 def codegen_load (
@@ -384,6 +391,10 @@ def codegen_store(
384391 assert extra_mask is None
385392 indexing = BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
386393
394+ config = DeviceFunction .current ().config
395+ if _supports_epilogue_subtiling and config .epilogue_subtiling :
396+ return self ._codegen_epilogue_subtile_store (state , fake_tensor , indexing , store_value )
397+
387398 # Apply permutation to the value being stored if needed
388399 desc_arg = indexing .tensor_descriptor_arg (state )
389400 store_value = indexing .reshape_store (state , value )
@@ -394,12 +405,102 @@ def codegen_store(
394405 f"tl.permute({{store_val}}, { desc_arg .permutation !r} )" ,
395406 store_val = store_value ,
396407 )
397-
408+
398409 return expr_from_string (
399410 f"{ indexing .tensor_descriptor (state )} .store({ indexing .offsets_str_permuted (state )} , {{value}})" ,
400411 value = store_value ,
401412 )
402413
def _codegen_epilogue_subtile_store(
    self,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    indexing: BlockedSubscriptIndexing,
    store_value: ast.AST,
) -> ast.AST | None:
    """Emit a tensor-descriptor store split into two half-width subtiles.

    The ``[block_m, block_n]`` value is reshaped to ``[block_m, 2, block_n // 2]``,
    permuted to ``[block_m, block_n // 2, 2]`` and split with ``tl.split`` so
    that the first result holds columns ``[0, block_n // 2)`` and the second
    holds columns ``[block_n // 2, block_n)``.  The first half is stored via a
    statement added to ``state.codegen``; the second half's store is returned
    as an expression for the caller to emit.

    Args:
        state: Codegen state providing ``device_function`` and ``codegen``.
        fake_tensor: Destination tensor (not read by this method).
        indexing: Blocked indexing for the store.  On success its
            ``block_shape[1]`` is left halved so the descriptor obtained from
            it matches the subtile shape; on every ``None`` return it is left
            exactly as it was passed in.
        store_value: AST of the full tile value to store.

    Returns:
        The AST for the second subtile store, or ``None`` when subtiling does
        not apply (tile is not 2D, no size hint is available for ``block_n``,
        the hint is odd, or the descriptor carries a permutation).  Callers
        must fall back to the regular store path on ``None``.
    """
    # Currently support 2D tiles without permutations
    if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
        return None

    env = CompileEnvironment.current()
    block_m, block_n = indexing.block_shape
    try:
        block_n_hint = env.size_hint(block_n)
    except Exception:
        # Best effort: without a usable hint we cannot prove the tile splits
        # evenly, so decline rather than fail the whole compile.
        return None

    if block_n_hint % 2 != 0:
        return None

    device_fn = state.device_function
    codegen = state.codegen

    # Capture the full-size extents as source text before the shape is halved.
    block_m_str = device_fn.literal_expr(block_m)
    block_n_str = device_fn.literal_expr(block_n)

    # The descriptor must be requested with the halved shape so that it
    # describes one subtile.
    indexing.block_shape[1] //= 2
    desc_arg = indexing.tensor_descriptor_arg(state)

    if desc_arg.permutation is not None:
        # Undo the in-place halving so the caller's fallback store still sees
        # the original block shape.
        indexing.block_shape[1] = block_n
        return None

    block_n_half_str = f"({block_n_str} // 2)"

    # Lift the store value into a temporary variable for reuse
    acc_var = codegen.lift(store_value, prefix="acc")

    reshape_expr = expr_from_string(
        "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
        acc=acc_var,
        dim_m=expr_from_string(block_m_str),
        dim_half=expr_from_string(block_n_half_str),
    )
    reshape_var = codegen.lift(reshape_expr, prefix="acc")

    # Move the size-2 axis last; tl.split divides along the final dimension.
    permute_expr = expr_from_string(
        "tl.permute({acc}, [0, 2, 1])",
        acc=reshape_var,
    )
    permute_var = codegen.lift(permute_expr, prefix="acc")

    acc0_name = codegen.tmpvar(prefix="acc")
    acc1_name = codegen.tmpvar(prefix="acc")
    codegen.add_statement(
        statement_from_string(
            f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
            acc=permute_var,
        )
    )
    acc0 = expr_from_string(acc0_name)
    acc1 = expr_from_string(acc1_name)

    desc_name = indexing.tensor_descriptor(state)

    # First subtile store
    codegen.add_statement(
        statement_from_string(
            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
            off0=expr_from_string(indexing.offsets[0]),
            off1=expr_from_string(indexing.offsets[1]),
            value=acc0,
        )
    )

    offset1_shifted = expr_from_string(
        "({offset} + {half})",
        offset=expr_from_string(indexing.offsets[1]),
        half=expr_from_string(block_n_half_str),
    )

    # Emit second subtile store as the expression returned to the caller.
    # A fresh row-offset node is built here so the two stores do not share
    # AST nodes.
    return expr_from_string(
        f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
        off0=expr_from_string(indexing.offsets[0]),
        off1=offset1_shifted,
        value=acc1,
    )
403504
404505class StackIndexingStrategy :
405506 """
0 commit comments