Skip to content

Commit 2bc36d0

Browse files
committed
Add epilogue subtiling
stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent 1aaba3f commit 2bc36d0

File tree

5 files changed

+117
-2
lines changed

5 files changed

+117
-2
lines changed

examples/matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def matmul(
4444
Returns:
4545
Tensor: Resulting matrix of shape [m, n].
4646
"""
47+
4748
m, k = x.size()
4849
k2, n = y.size()
4950
assert k == k2, f"size mismatch {k} != {k2}"

helion/_compiler/device_function.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,9 +420,14 @@ def tensor_arg(
420420
def tensor_descriptor_arg(
421421
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
422422
) -> TensorDescriptorArg:
423+
import re
423424
host_function = HostFunction.current()
424425
block_size_expr = ", ".join(map(self.literal_expr, block_size))
426+
pattern = r'triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)'
427+
replacement = r'\1 // \2'
428+
block_size_expr = re.sub(pattern, replacement, block_size_expr)
425429
key = (fake_value, block_size_expr)
430+
426431
if key not in self._tensor_descriptor_args:
427432
origin = host_function.tensor_to_origin[fake_value]
428433
desc_name = self.new_var(origin.suggest_var_name() + "_desc")

helion/_compiler/indexing_strategy.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .. import exc
1616
from .._compat import get_tensor_descriptor_fn_name
1717
from .ast_extension import expr_from_string
18+
from .ast_extension import statement_from_string
1819
from .compile_environment import CompileEnvironment
1920
from .device_function import DeviceFunction
2021
from .host_function import HostFunction
@@ -384,22 +385,116 @@ def codegen_store(
384385
assert extra_mask is None
385386
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
386387

388+
config = DeviceFunction.current().config
389+
store_value = indexing.reshape_store(state, value)
390+
if config.epilogue_subtiling:
391+
return self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value)
392+
387393
# Apply permutation to the value being stored if needed
388394
desc_arg = indexing.tensor_descriptor_arg(state)
389-
store_value = indexing.reshape_store(state, value)
390395

391396
if desc_arg.permutation is not None:
392397
# Apply permutation to the value
393398
store_value = expr_from_string(
394399
f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
395400
store_val=store_value,
396401
)
397-
402+
398403
return expr_from_string(
399404
f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
400405
value=store_value,
401406
)
402407

408+
def _codegen_epilogue_subtile_store(
409+
self,
410+
state: CodegenState,
411+
fake_tensor: torch.Tensor,
412+
indexing: BlockedSubscriptIndexing,
413+
store_value: ast.AST,
414+
) -> ast.AST | None:
415+
# Currently support 2D tiles without permutations
416+
if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
417+
return None
418+
419+
env = CompileEnvironment.current()
420+
block_m, block_n = indexing.block_shape
421+
try:
422+
block_n_hint = env.size_hint(block_n)
423+
except Exception:
424+
return None
425+
426+
if block_n_hint % 2 != 0:
427+
return None
428+
429+
device_fn = state.device_function
430+
codegen = state.codegen
431+
432+
block_m_str = device_fn.literal_expr(block_m)
433+
block_n_str = device_fn.literal_expr(block_n)
434+
indexing.block_shape[1] //= 2
435+
desc_arg = indexing.tensor_descriptor_arg(state)
436+
437+
if desc_arg.permutation is not None:
438+
return None
439+
440+
441+
block_n_half_str = f"({block_n_str} // 2)"
442+
443+
# Lift the store value into a temporary variable for reuse
444+
acc_var = codegen.lift(store_value, prefix="acc")
445+
446+
reshape_expr = expr_from_string(
447+
"tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
448+
acc=acc_var,
449+
dim_m=expr_from_string(block_m_str),
450+
dim_half=expr_from_string(block_n_half_str),
451+
)
452+
reshape_var = codegen.lift(reshape_expr, prefix="acc")
453+
454+
permute_expr = expr_from_string(
455+
"tl.permute({acc}, [0, 2, 1])",
456+
acc=reshape_var,
457+
)
458+
permute_var = codegen.lift(permute_expr, prefix="acc")
459+
460+
acc0_name = codegen.tmpvar(prefix="acc")
461+
acc1_name = codegen.tmpvar(prefix="acc")
462+
codegen.add_statement(
463+
statement_from_string(
464+
f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
465+
acc=permute_var,
466+
)
467+
)
468+
acc0 = expr_from_string(acc0_name)
469+
acc1 = expr_from_string(acc1_name)
470+
471+
desc_name = indexing.tensor_descriptor(state)
472+
offset0 = expr_from_string(indexing.offsets[0])
473+
offset1 = expr_from_string(indexing.offsets[1])
474+
475+
# First subtile store
476+
codegen.add_statement(
477+
statement_from_string(
478+
f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
479+
off0=offset0,
480+
off1=offset1,
481+
value=acc0,
482+
)
483+
)
484+
485+
offset1_shifted = expr_from_string(
486+
"({offset} + {half})",
487+
offset=expr_from_string(indexing.offsets[1]),
488+
half=expr_from_string(block_n_half_str),
489+
)
490+
491+
# Emit second subtile store as the expression returned to the caller
492+
return expr_from_string(
493+
f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
494+
off0=offset0,
495+
off1=offset1_shifted,
496+
value=acc1,
497+
)
403498

404499
class StackIndexingStrategy:
405500
"""

helion/autotuner/config_spec.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"pid_type",
5353
"indexing",
5454
"load_eviction_policies",
55+
"epilogue_subtiling"
5556
]
5657
)
5758
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
@@ -105,6 +106,7 @@ class ConfigSpec:
105106
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
106107
)
107108
)
109+
epilogue_subtiling: bool = dataclasses.field(default=False)
108110

109111
@staticmethod
110112
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -224,6 +226,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
224226
config.setdefault(
225227
"load_eviction_policies", self.load_eviction_policies.default()
226228
)
229+
config.setdefault("epilogue_subtiling", False)
227230
# TODO(jansel): include num_ctas and max_nreg
228231

229232
for name, values in (
@@ -238,6 +241,9 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
238241
else:
239242
config[name] = values[0]
240243

244+
if config["indexing"] != "tensor_descriptor" or any(block_id < 16 for block_id in config["block_sizes"]):
245+
config["epilogue_subtiling"] = False
246+
241247
# Set default values for grid indices when pid_type is not persistent
242248
pid_type = config["pid_type"]
243249
if pid_type in ("flat", "xyz") and self.grid_block_ids:
@@ -279,6 +285,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
279285
"indexing": fn(EnumFragment(self._valid_indexing_types())),
280286
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
281287
"load_eviction_policies": fn(self.load_eviction_policies),
288+
"epilogue_subtiling": fn(BooleanFragment()),
282289
}
283290
# Add tunable parameters
284291
config.update(

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
num_stages: int | None = None,
4040
pid_type: PidTypeLiteral | None = None,
4141
indexing: IndexingLiteral | None = None,
42+
epilogue_subtiling: bool | None = None,
4243
# For user-defined properties
4344
**kwargs: object,
4445
) -> None:
@@ -61,6 +62,7 @@ def __init__(
6162
num_stages: Number of stages for software pipelining.
6263
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
6364
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
65+
epilogue_subtiling: Whether to use subtiling for epilogue.
6466
**kwargs: Additional user-defined configuration parameters.
6567
"""
6668
self.config = {}
@@ -81,6 +83,7 @@ def __init__(
8183
"num_stages": num_stages,
8284
"indexing": indexing,
8385
"pid_type": pid_type,
86+
"epilogue_subtiling": epilogue_subtiling,
8487
}
8588
for key, value in core_props.items():
8689
if value is not None:
@@ -206,6 +209,10 @@ def load_eviction_policies(self) -> list[str]:
206209
def indexing(self) -> IndexingLiteral:
207210
return self.config.get("indexing", "pointer") # type: ignore[return-value]
208211

212+
@property
213+
def epilogue_subtiling(self) -> bool:
214+
return self.config.get("epilogue_subtiling", False) # type: ignore[return-value]
215+
209216

210217
def _list_to_tuple(x: object) -> object:
211218
if isinstance(x, list):

0 commit comments

Comments
 (0)