Skip to content

Commit a9e954b

Browse files
committed
wip
1 parent 9d8f78f commit a9e954b

File tree

3 files changed

+41
-15
lines changed

3 files changed

+41
-15
lines changed

helion/_compiler/device_function.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,7 @@ def block_size_var(self, block_id: int) -> str | None:
282282

283283
var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
284284
self.block_size_var_cache[key] = var_name
285-
host_expr = HostFunction.current().literal_expr(block_value)
286-
if self.constexpr_arg(var_name, host_expr):
287-
self.codegen.host_statements.append(
288-
statement_from_string(f"{var_name} = {host_expr}")
289-
)
285+
self.constexpr_arg_with_host_def(var_name, block_value)
290286

291287
return self.block_size_var_cache[key]
292288

@@ -484,14 +480,50 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484480
self._expr_args[sym] = arg
485481
return self._expr_args[sym]
486482

487-
def constexpr_arg(self, name: str, host_str: str | None = None) -> bool:
483+
def constexpr_arg(self, name: str, value: object | None = None) -> bool:
488484
"""Create a constexpr argument, returns True if created, False if already exists."""
489485
if name in self._constexpr_args:
490486
return False
491-
self._constexpr_args[name] = rv = ConstExprArg(name, host_str or name)
487+
host_str = name if value is None else self._format_constexpr_value(value)
488+
self._constexpr_args[name] = rv = ConstExprArg(name, host_str)
492489
self.arguments.append(rv)
493490
return True
494491

492+
def constexpr_arg_with_host_def(self, name: str, value: object) -> None:
493+
"""Create a constexpr argument and add its host-side definition if needed."""
494+
if self.constexpr_arg(name, value):
495+
host_expr = self._constexpr_args[name].host_str()
496+
self.codegen.host_statements.append(
497+
statement_from_string(f"{name} = {host_expr}")
498+
)
499+
500+
def _format_constexpr_value(self, value: object) -> str:
501+
if isinstance(value, str):
502+
return value
503+
if isinstance(value, (int, float, bool)):
504+
return repr(value)
505+
506+
# Extract sympy expression from torch symbolic types
507+
if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
508+
value = value._sympy_()
509+
510+
# Handle sympy expressions (sanitize by replacing triton_helpers functions)
511+
if isinstance(value, sympy.Expr):
512+
expr = value.replace(
513+
lambda node: isinstance(node, sympy.Function)
514+
and getattr(node.func, "__name__", "")
515+
== "triton_helpers.div_floor_integer",
516+
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
517+
).replace(
518+
lambda node: isinstance(node, sympy.Function)
519+
and getattr(node.func, "__name__", "")
520+
== "triton_helpers.remainder_integer",
521+
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
522+
)
523+
return HostFunction.current().sympy_expr(expr)
524+
525+
return HostFunction.current().literal_expr(value)
526+
495527
def _tensor_property(
496528
self,
497529
prop_cls: type[_P],

helion/_compiler/inductor_lowering.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,8 +1241,7 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:
12411241
):
12421242
# This expression is used in tl.arange, make it a constexpr
12431243
name = self.cg.device_function.new_var(node.name)
1244-
host_expr = self.cg.device_function.sympy_expr(val._sympy_())
1245-
self.cg.device_function.constexpr_arg(name, host_expr)
1244+
self.cg.device_function.constexpr_arg(name, val._sympy_())
12461245
return name
12471246

12481247
# If the lowering produced a named value that is already defined elsewhere

helion/_compiler/tile_strategy.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,7 @@ def _setup_block_size_constexpr(
244244
self, state: CodegenState, block_size_var: str, block_size: SymIntLike
245245
) -> None:
246246
"""Helper to setup constexpr block size variable on host."""
247-
if state.device_function.constexpr_arg(block_size_var):
248-
state.codegen.host_statements.append(
249-
statement_from_string(
250-
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
251-
)
252-
)
247+
state.device_function.constexpr_arg_with_host_def(block_size_var, block_size)
253248

254249

255250
class BlockSizeTileStrategy(TileStrategy):

0 commit comments

Comments
 (0)