@@ -282,11 +282,7 @@ def block_size_var(self, block_id: int) -> str | None:
282282
283283            var_name  =  self .new_var (f"_BLOCK_SIZE_{ block_id }  " )
284284            self .block_size_var_cache [key ] =  var_name 
285-             host_expr  =  HostFunction .current ().literal_expr (block_value )
286-             if  self .constexpr_arg (var_name , host_expr ):
287-                 self .codegen .host_statements .append (
288-                     statement_from_string (f"{ var_name }   = { host_expr }  " )
289-                 )
285+             self .constexpr_arg_with_host_def (var_name , block_value )
290286
291287        return  self .block_size_var_cache [key ]
292288
@@ -484,14 +480,50 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484480            self ._expr_args [sym ] =  arg 
485481        return  self ._expr_args [sym ]
486482
487-     def  constexpr_arg (self , name : str , host_str :  str  |  None  =  None ) ->  bool :
483+     def  constexpr_arg (self , name : str , value :  object  |  None  =  None ) ->  bool :
488484        """Create a constexpr argument, returns True if created, False if already exists.""" 
489485        if  name  in  self ._constexpr_args :
490486            return  False 
491-         self ._constexpr_args [name ] =  rv  =  ConstExprArg (name , host_str  or  name )
487+         host_str  =  name  if  value  is  None  else  self ._format_constexpr_value (value )
488+         self ._constexpr_args [name ] =  rv  =  ConstExprArg (name , host_str )
492489        self .arguments .append (rv )
493490        return  True 
494491
492+     def  constexpr_arg_with_host_def (self , name : str , value : object ) ->  None :
493+         """Create a constexpr argument and add its host-side definition if needed.""" 
494+         if  self .constexpr_arg (name , value ):
495+             host_expr  =  self ._constexpr_args [name ].host_str ()
496+             self .codegen .host_statements .append (
497+                 statement_from_string (f"{ name }   = { host_expr }  " )
498+             )
499+ 
500+     def  _format_constexpr_value (self , value : object ) ->  str :
501+         if  isinstance (value , str ):
502+             return  value 
503+         if  isinstance (value , (int , float , bool )):
504+             return  repr (value )
505+ 
506+         # Extract sympy expression from torch symbolic types 
507+         if  isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
508+             value  =  value ._sympy_ ()
509+ 
510+         # Handle sympy expressions (sanitize by replacing triton_helpers functions) 
511+         if  isinstance (value , sympy .Expr ):
512+             expr  =  value .replace (
513+                 lambda  node : isinstance (node , sympy .Function )
514+                 and  getattr (node .func , "__name__" , "" )
515+                 ==  "triton_helpers.div_floor_integer" ,
516+                 lambda  node : sympy .floor (node .args [0 ] /  node .args [1 ]),  # pyright: ignore[reportAttributeAccessIssue] 
517+             ).replace (
518+                 lambda  node : isinstance (node , sympy .Function )
519+                 and  getattr (node .func , "__name__" , "" )
520+                 ==  "triton_helpers.remainder_integer" ,
521+                 lambda  node : sympy .Mod (node .args [0 ], node .args [1 ]),  # pyright: ignore[reportAttributeAccessIssue] 
522+             )
523+             return  HostFunction .current ().sympy_expr (expr )
524+ 
525+         return  HostFunction .current ().literal_expr (value )
526+ 
495527    def  _tensor_property (
496528        self ,
497529        prop_cls : type [_P ],
0 commit comments