@@ -250,6 +250,9 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
250250        self .rng_seed_count  =  0 
251251        self .device_load_index  =  0   # Track which load in device code we're generating (for eviction policy tuning) 
252252        # Name of the RNG seed buffer parameter in kernel signature 
253+         self .device_store_index  =  (
254+             0   # Track which store in device code we're generating (for subtiling) 
255+         )
253256        self .rng_seed_buffer_param_name  =  None 
254257
255258    def  has_rng_ops (self ) ->  bool :
@@ -421,8 +424,9 @@ def tensor_descriptor_arg(
421424        self , fake_value : torch .Tensor , block_size : list [int  |  torch .SymInt ]
422425    ) ->  TensorDescriptorArg :
423426        host_function  =  HostFunction .current ()
424-         block_size_expr  =  ", " .join (map ( self .literal_expr ,  block_size ) )
427+         block_size_expr  =  ", " .join (self .literal_expr ( dim )  for   dim   in   block_size )
425428        key  =  (fake_value , block_size_expr )
429+ 
426430        if  key  not  in self ._tensor_descriptor_args :
427431            origin  =  host_function .tensor_to_origin [fake_value ]
428432            desc_name  =  self .new_var (origin .suggest_var_name () +  "_desc" )
@@ -515,22 +519,6 @@ def _format_constexpr_value(self, value: object) -> str:
515519        if  isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
516520            value  =  value ._sympy_ ()
517521
518-         # Handle sympy expressions (sanitize by replacing triton_helpers functions) 
519-         if  isinstance (value , sympy .Expr ):
520-             sanitized  =  value .replace (  # pyright: ignore[reportAttributeAccessIssue] 
521-                 lambda  node : isinstance (node , sympy .Function )
522-                 and  getattr (node .func , "__name__" , "" )
523-                 ==  "triton_helpers.div_floor_integer" ,
524-                 lambda  node : sympy .floor (node .args [0 ] /  node .args [1 ]),  # pyright: ignore[reportAttributeAccessIssue] 
525-             ).replace (  # pyright: ignore[reportAttributeAccessIssue] 
526-                 lambda  node : isinstance (node , sympy .Function )
527-                 and  getattr (node .func , "__name__" , "" )
528-                 ==  "triton_helpers.remainder_integer" ,
529-                 lambda  node : sympy .Mod (node .args [0 ], node .args [1 ]),  # pyright: ignore[reportAttributeAccessIssue] 
530-             )
531-             expr  =  cast ("sympy.Expr" , sanitized )
532-             return  HostFunction .current ().sympy_expr (expr )
533- 
534522        return  HostFunction .current ().literal_expr (value )
535523
536524    def  _tensor_property (
@@ -708,11 +696,19 @@ def current() -> DeviceFunction:
708696
709697
710698class  HelionTritonPrinter (TritonPrinter ):
711-     """Custom Triton printer that avoids wrapping float literals in tl.full(). 
712- 
713-     Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value 
714-     via tl.full([], <val>, tl.float64). We override this to emit the raw numeric 
715-     literal, letting downstream type promotion and casts handle dtype. 
699+     """Custom Triton printer that does the following: 
700+ 
701+     - Avoids wrapping float literals in tl.full(). 
702+      Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value 
703+      via tl.full([], <val>, tl.float64). We override this to emit the raw numeric 
704+      literal, letting downstream type promotion and casts handle dtype. 
705+ 
706+     - Avoids triton_helpers.div_floor_integer(...) calls when both operands are 
707+       provably non-negative integers. TritonPrinter by default converts 
708+       floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to 
709+       emit u1 // 2 only when the numerator is known to be non-negative and the 
710+       denominator is a positive integer, so that we keep helper calls for cases 
711+       that rely on floor semantics with mixed signs. 
716712    """ 
717713
718714    def  _print_Float (self , expr : sympy .Expr ) ->  str :
@@ -721,6 +717,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
721717    def  _print_ToFloat (self , expr : sympy .Expr ) ->  str :
722718        return  f"{ expr }  
723719
720+     def  _is_nonnegative (self , expr : sympy .Expr ) ->  bool :
721+         if  expr .is_nonnegative  is  True  or  expr .is_zero  is  True :
722+             return  True 
723+         if  expr .is_positive  is  True :
724+             return  True 
725+         try :
726+             host_fn  =  HostFunction .current ()
727+         except  NoCurrentFunction :
728+             host_fn  =  None 
729+         if  host_fn  is  not None :
730+             origin_info  =  host_fn .expr_to_origin .get (expr )
731+             if  origin_info  and  isinstance (
732+                 origin_info .origin , (BlockSizeOrigin , TensorSizeOrigin )
733+             ):
734+                 return  True 
735+         if  isinstance (expr , sympy .Symbol ) and  expr .name .startswith ("_BLOCK_SIZE_" ):
736+             return  True 
737+         if  isinstance (expr , sympy .Number ):
738+             return  bool (expr  >=  0 )
739+         return  False 
740+ 
741+     def  _format_trunc_div (self , lhs : sympy .Expr , rhs : sympy .Expr ) ->  str :
742+         lhs_str  =  self ._print (lhs )
743+         rhs_str  =  self ._print (rhs )
744+         if  not  (lhs .is_Integer  or  lhs .is_Symbol ):
745+             lhs_str  =  f"({ lhs_str }  
746+         if  not  (rhs .is_Integer  or  rhs .is_Symbol ):
747+             rhs_str  =  f"({ rhs_str }  
748+         return  f"{ lhs_str } { rhs_str }  
749+ 
750+     def  _print_floor (self , expr : sympy .Expr ) ->  str :
751+         inner  =  expr .args [0 ]
752+         numer , denom  =  inner .as_numer_denom ()
753+         if  (
754+             isinstance (denom , sympy .Integer )
755+             and  denom  >  1 
756+             and  self ._is_nonnegative (numer )
757+         ):
758+             return  self ._format_trunc_div (numer , denom )
759+         return  super ()._print_floor (expr )
760+ 
761+     def  _print_FloorDiv (self , expr : sympy .Expr ) ->  str :
762+         lhs , rhs  =  expr .args 
763+         if  isinstance (rhs , sympy .Integer ) and  rhs  >  0  and  self ._is_nonnegative (lhs ):
764+             return  self ._format_trunc_div (lhs , rhs )
765+         return  super ()._print_FloorDiv (expr )
766+ 
724767
def texpr(expr: sympy.Expr) -> str:
    """Print ``expr`` with a fresh ``HelionTritonPrinter`` and return the text."""
    printer = HelionTritonPrinter()
    return printer.doprint(expr)
0 commit comments