 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
 from .ast_extension import expr_from_string
+from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
 from .host_function import HostFunction
@@ -353,7 +354,6 @@ def codegen_load(
         )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
-
         # Load from tensor descriptor with permuted offsets
         load_expr = expr_from_string(
             f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"
@@ -383,10 +383,12 @@ def codegen_store(
         )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
+        store_value = indexing.reshape_store(state, value)
 
+        config = DeviceFunction.current().config
+        epilogue_subtiles = state.config.epilogue_subtiling
         # Apply permutation to the value being stored if needed
         desc_arg = indexing.tensor_descriptor_arg(state)
-        store_value = indexing.reshape_store(state, value)
 
         if desc_arg.permutation is not None:
             # Apply permutation to the value
@@ -395,11 +397,204 @@ def codegen_store(
                 store_val=store_value,
             )
 
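+        # epilogue_subtiling is indexed by the running device_store_index, giving
+        # each device store its own subtile-split choice; stores beyond the end of
+        # the list fall through to a plain descriptor store below.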
+        if (idx := state.device_function.device_store_index) < len(epilogue_subtiles):
+            subtile_split = epilogue_subtiles[idx]
+            state.device_function.device_store_index += 1
+
+            # Check if we should fuse a pointwise operation into the epilogue store
+            fused_pointwise_node = self._get_fusable_pointwise_node(state)
+
+            subtile_codegen = self._codegen_epilogue_subtile_store(
+                state,
+                fake_tensor,
+                indexing,
+                store_value,
+                subtile_split,
+                config,
+                fused_pointwise_node,
+            )
+            if subtile_codegen is not None:
+                return subtile_codegen
+
         return expr_from_string(
             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
             value=store_value,
         )
 
+    def _get_fusable_pointwise_node(self, state: CodegenState) -> torch.fx.Node | None:
+        """Find a pointwise node feeding into this store that can be fused.
+
+        Returns the pointwise FX node if found, None otherwise.
+        """
+        if state.fx_node is None:
+            return None
+
+        # Get the value being stored (3rd argument to store)
+        if len(state.fx_node.args) < 3:
+            return None
+
+        value_node = state.fx_node.args[2]
+        if not isinstance(value_node, torch.fx.Node):
+            return None
+
+        # Check if this is a pointwise node
+        from .inductor_lowering import PointwiseLowering
+
+        lowering = value_node.meta.get("lowering")
+        if not isinstance(lowering, PointwiseLowering):
+            return None
+
+        # Check if this node only has one user (the store)
+        if len(list(value_node.users)) != 1:
+            return None
+
+        return value_node
+
+    def _apply_pointwise_to_subtile(
+        self, state: CodegenState, pointwise_node: torch.fx.Node, subtile_value: ast.AST
+    ) -> ast.AST:
+        """Apply a pointwise operation to a subtile value.
+
+        Args:
+            state: The codegen state
+            pointwise_node: The FX node representing the pointwise operation
+            subtile_value: The AST for the subtile value to apply the operation to
+
+        Returns:
+            AST for the result after applying the pointwise operation
+        """
+        from torch._inductor import ir
+
+        from .inductor_lowering import PointwiseLowering
+        from .inductor_lowering import install_inductor_kernel_handlers
+
+        lowering = pointwise_node.meta["lowering"]
+        assert isinstance(lowering, PointwiseLowering)
+
+        # Get the pointwise buffer
+        buffer = lowering.buffer
+        assert isinstance(buffer.data, ir.Pointwise)
+
+        # Create a temporary variable for the subtile
+        codegen = state.codegen
+        subtile_var = codegen.lift(subtile_value, prefix="subtile")
+
+        # Set up the inductor kernel handlers with the subtile as input
+        with install_inductor_kernel_handlers(
+            codegen, {lowering.input_names[0]: subtile_var}
+        ):
+            # Generate the pointwise operation
+            indices = [sympy.Symbol(f"i{n}") for n in range(len(buffer.data.ranges))]
+            from .inductor_lowering import _unpack_opsvalue
+
+            result_name = _unpack_opsvalue(buffer.data.inner_fn(indices))
+            return expr_from_string(result_name)
+
+    def _codegen_epilogue_subtile_store(
+        self,
+        state: CodegenState,
+        fake_tensor: torch.Tensor,
+        indexing: BlockedSubscriptIndexing,
+        store_value: ast.AST,
+        subtile_split: int,
+        config: Config,
+        fused_pointwise_node: torch.fx.Node | None = None,
+    ) -> ast.AST | None:
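+        # The emitted Triton looks roughly like this (illustrative names, and
+        # assuming subtile_split == 2):
+        #   acc_r = tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1)
+        #   acc0, acc1 = tl.split(acc_r)
+        #   desc.store([off_m, off_n], acc0)
+        #   desc.store([off_m, off_n + BLOCK_N // 2], acc1)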
+        # Currently support 2D tiles without permutations
+        if (
+            len(indexing.block_shape) != 2
+            or len(indexing.offsets) != 2
+            or subtile_split == 0
+        ):
+            return None
+
+        env = CompileEnvironment.current()
+        block_m, block_n = indexing.block_shape
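+        # Bail out unless the N block size can be resolved from the config, is
+        # even, and is large enough for the split to be worthwhile.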
+        try:
+            block_n_hint = env.size_hint(block_n)
+            block_idx = env.get_block_id(block_n)
+            block_size = env.block_sizes[block_idx].from_config(config)
+        except Exception:
+            return None
+
+        if block_n_hint % 2 != 0 or block_size <= 16:
+            return None
+
+        device_fn = state.device_function
+        codegen = state.codegen
+
+        block_m_str = device_fn.literal_expr(block_m)
+        block_n_str = device_fn.literal_expr(block_n)
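+        # Narrow the indexing's N block to the subtile width so the descriptor
+        # used below matches the half-width stores.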
+        indexing.block_shape[1] //= subtile_split
+
+        # TODO(PaulZhang12): Support more epilogue subtile configs besides 2
+        block_n_half_str = f"({block_n_str} // {subtile_split})"
+
+        # If we have a fused pointwise operation, mark it to skip normal codegen
+        # and get its input value instead
+        if fused_pointwise_node is not None:
+            fused_pointwise_node.meta["fused_into_store"] = True
+
+        # Lift the store value into a temporary variable for reuse
+        acc_var = codegen.lift(store_value, prefix="acc")
+
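+        # View acc as [BLOCK_M, 2, BLOCK_N // 2] and move the split axis last, so
+        # tl.split below yields the left and right column halves of the tile.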
+        reshape_expr = expr_from_string(
+            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",
+            acc=acc_var,
+            dim_m=expr_from_string(block_m_str),
+            dim_half=expr_from_string(block_n_half_str),
+        )
+        reshape_var = codegen.lift(reshape_expr, prefix="acc")
+
+        acc0_name = codegen.tmpvar(prefix="acc")
+        acc1_name = codegen.tmpvar(prefix="acc")
+        codegen.add_statement(
+            statement_from_string(
+                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
+                acc=reshape_var,
+            )
+        )
+
+        # Now apply the pointwise operation per-subtile if we have one
+        if fused_pointwise_node is not None:
+            acc0 = self._apply_pointwise_to_subtile(
+                state, fused_pointwise_node, expr_from_string(acc0_name)
+            )
+            acc1 = self._apply_pointwise_to_subtile(
+                state, fused_pointwise_node, expr_from_string(acc1_name)
+            )
+        else:
+            acc0 = expr_from_string(acc0_name)
+            acc1 = expr_from_string(acc1_name)
+
+        desc_name = indexing.tensor_descriptor(state)
+        offset0 = expr_from_string(indexing.offsets[0])
+        offset1 = expr_from_string(indexing.offsets[1])
+
+        # First subtile store
+        codegen.add_statement(
+            statement_from_string(
+                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+                off0=offset0,
+                off1=offset1,
+                value=acc0,
+            )
+        )
+
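+        # The second subtile lands BLOCK_N // subtile_split columns to the right
+        # of the first.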
+        offset1_shifted = expr_from_string(
+            "({offset} + {half})",
+            offset=expr_from_string(indexing.offsets[1]),
+            half=expr_from_string(block_n_half_str),
+        )
+
+        # Emit second subtile store as the expression returned to the caller
+        return expr_from_string(
+            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+            off0=offset0,
+            off1=offset1_shifted,
+            value=acc1,
+        )
+
 
 class StackIndexingStrategy:
     """