1515from  .. import  exc 
1616from  .._compat  import  get_tensor_descriptor_fn_name 
1717from  .ast_extension  import  expr_from_string 
18+ from  .ast_extension  import  statement_from_string 
1819from  .compile_environment  import  CompileEnvironment 
1920from  .device_function  import  DeviceFunction 
2021from  .host_function  import  HostFunction 
@@ -353,7 +354,6 @@ def codegen_load(
353354            )
354355        assert  extra_mask  is  None 
355356        indexing  =  BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
356- 
357357        # Load from tensor descriptor with permuted offsets 
358358        load_expr  =  expr_from_string (
359359            f"{ indexing .tensor_descriptor (state )} { indexing .offsets_str_permuted (state )}  
@@ -383,10 +383,24 @@ def codegen_store(
383383            )
384384        assert  extra_mask  is  None 
385385        indexing  =  BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
386+         store_value  =  indexing .reshape_store (state , value )
387+ 
388+         config  =  DeviceFunction .current ().config 
389+         epilogue_subtiles  =  state .config .epilogue_subtiling 
390+         if  torch .cuda .get_device_capability () >=  (9 , 0 ) and  (
391+             idx  :=  state .device_function .device_store_index 
392+         ) <  len (epilogue_subtiles ):
393+             subtile_split  =  epilogue_subtiles [idx ]
394+             state .device_function .device_store_index  +=  1 
395+ 
396+             subtile_codegen  =  self ._codegen_epilogue_subtile_store (
397+                 state , fake_tensor , indexing , store_value , subtile_split , config 
398+             )
399+             if  subtile_codegen  is  not None :
400+                 return  subtile_codegen 
386401
387402        # Apply permutation to the value being stored if needed 
388403        desc_arg  =  indexing .tensor_descriptor_arg (state )
389-         store_value  =  indexing .reshape_store (state , value )
390404
391405        if  desc_arg .permutation  is  not None :
392406            # Apply permutation to the value 
@@ -400,6 +414,95 @@ def codegen_store(
400414            value = store_value ,
401415        )
402416
417+     def  _codegen_epilogue_subtile_store (
418+         self ,
419+         state : CodegenState ,
420+         fake_tensor : torch .Tensor ,
421+         indexing : BlockedSubscriptIndexing ,
422+         store_value : ast .AST ,
423+         subtile_split : int ,
424+         config : Config ,
425+     ) ->  ast .AST  |  None :
426+         # Currently support 2D tiles without permutations 
427+         if  (
428+             len (indexing .block_shape ) !=  2 
429+             or  len (indexing .offsets ) !=  2 
430+             or  subtile_split  ==  0 
431+         ):
432+             return  None 
433+ 
434+         env  =  CompileEnvironment .current ()
435+         block_m , block_n  =  indexing .block_shape 
436+         try :
437+             block_n_hint  =  env .size_hint (block_n )
438+             block_idx  =  env .get_block_id (block_n )
439+             block_size  =  env .block_sizes [block_idx ].from_config (config )
440+         except  Exception :
441+             return  None 
442+ 
443+         if  block_n_hint  %  2  !=  0  or  block_size  <=  16 :
444+             return  None 
445+ 
446+         device_fn  =  state .device_function 
447+         codegen  =  state .codegen 
448+ 
449+         block_m_str  =  device_fn .literal_expr (block_m )
450+         block_n_str  =  device_fn .literal_expr (block_n )
451+         indexing .block_shape [1 ] //=  subtile_split 
452+ 
453+         # TODO(PaulZhang12): Support more epilogue subtile configs besides 2 
454+         block_n_half_str  =  f"({ block_n_str } { subtile_split }  
455+ 
456+         # Lift the store value into a temporary variable for reuse 
457+         acc_var  =  codegen .lift (store_value , prefix = "acc" )
458+ 
459+         reshape_expr  =  expr_from_string (
460+             "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)" ,
461+             acc = acc_var ,
462+             dim_m = expr_from_string (block_m_str ),
463+             dim_half = expr_from_string (block_n_half_str ),
464+         )
465+         reshape_var  =  codegen .lift (reshape_expr , prefix = "acc" )
466+ 
467+         acc0_name  =  codegen .tmpvar (prefix = "acc" )
468+         acc1_name  =  codegen .tmpvar (prefix = "acc" )
469+         codegen .add_statement (
470+             statement_from_string (
471+                 f"{ acc0_name } { acc1_name }  ,
472+                 acc = reshape_var ,
473+             )
474+         )
475+         acc0  =  expr_from_string (acc0_name )
476+         acc1  =  expr_from_string (acc1_name )
477+ 
478+         desc_name  =  indexing .tensor_descriptor (state )
479+         offset0  =  expr_from_string (indexing .offsets [0 ])
480+         offset1  =  expr_from_string (indexing .offsets [1 ])
481+ 
482+         # First subtile store 
483+         codegen .add_statement (
484+             statement_from_string (
485+                 f"{ desc_name }  ,
486+                 off0 = offset0 ,
487+                 off1 = offset1 ,
488+                 value = acc0 ,
489+             )
490+         )
491+ 
492+         offset1_shifted  =  expr_from_string (
493+             "({offset} + {half})" ,
494+             offset = expr_from_string (indexing .offsets [1 ]),
495+             half = expr_from_string (block_n_half_str ),
496+         )
497+ 
498+         # Emit second subtile store as the expression returned to the caller 
499+         return  expr_from_string (
500+             f"{ desc_name }  ,
501+             off0 = offset0 ,
502+             off1 = offset1_shifted ,
503+             value = acc1 ,
504+         )
505+ 
403506
404507class  StackIndexingStrategy :
405508    """ 
0 commit comments