Skip to content

Commit 965b193

Browse files
committed
Add epilogue subtiling
stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent 1aaba3f commit 965b193

File tree

5 files changed

+118
-1
lines changed

5 files changed

+118
-1
lines changed

examples/matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def matmul(
4444
Returns:
4545
Tensor: Resulting matrix of shape [m, n].
4646
"""
47+
4748
m, k = x.size()
4849
k2, n = y.size()
4950
assert k == k2, f"size mismatch {k} != {k2}"

helion/_compiler/device_function.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,9 +420,14 @@ def tensor_arg(
420420
def tensor_descriptor_arg(
421421
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
422422
) -> TensorDescriptorArg:
423+
import re
423424
host_function = HostFunction.current()
424425
block_size_expr = ", ".join(map(self.literal_expr, block_size))
426+
pattern = r'triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)'
427+
replacement = r'\1 // \2'
428+
block_size_expr = re.sub(pattern, replacement, block_size_expr)
425429
key = (fake_value, block_size_expr)
430+
426431
if key not in self._tensor_descriptor_args:
427432
origin = host_function.tensor_to_origin[fake_value]
428433
desc_name = self.new_var(origin.suggest_var_name() + "_desc")

helion/_compiler/indexing_strategy.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .. import exc
1616
from .._compat import get_tensor_descriptor_fn_name
1717
from .ast_extension import expr_from_string
18+
from .ast_extension import statement_from_string
1819
from .compile_environment import CompileEnvironment
1920
from .device_function import DeviceFunction
2021
from .host_function import HostFunction
@@ -103,6 +104,12 @@ def _get_tile_with_offset_info(
103104

104105
return None
105106

107+
def _supports_epilogue_subtiling():
108+
env = CompileEnvironment.current()
109+
if env.device.type != "cuda" or not env.settings.allow_epilogue_subtiling:
110+
return False
111+
return torch.cuda.get_device_capability() >= (10, 0)
112+
106113

107114
class IndexingStrategy:
108115
def codegen_load(
@@ -384,6 +391,10 @@ def codegen_store(
384391
assert extra_mask is None
385392
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
386393

394+
config = DeviceFunction.current().config
395+
if _supports_epilogue_subtiling and config.epilogue_subtiling:
396+
return self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value)
397+
387398
# Apply permutation to the value being stored if needed
388399
desc_arg = indexing.tensor_descriptor_arg(state)
389400
store_value = indexing.reshape_store(state, value)
@@ -394,12 +405,102 @@ def codegen_store(
394405
f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
395406
store_val=store_value,
396407
)
397-
408+
398409
return expr_from_string(
399410
f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
400411
value=store_value,
401412
)
402413

414+
def _codegen_epilogue_subtile_store(
415+
self,
416+
state: CodegenState,
417+
fake_tensor: torch.Tensor,
418+
indexing: BlockedSubscriptIndexing,
419+
store_value: ast.AST,
420+
) -> ast.AST | None:
421+
# Currently support 2D tiles without permutations
422+
if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
423+
return None
424+
425+
env = CompileEnvironment.current()
426+
block_m, block_n = indexing.block_shape
427+
try:
428+
block_n_hint = env.size_hint(block_n)
429+
except Exception:
430+
return None
431+
432+
if block_n_hint % 2 != 0:
433+
return None
434+
435+
device_fn = state.device_function
436+
codegen = state.codegen
437+
438+
block_m_str = device_fn.literal_expr(block_m)
439+
block_n_str = device_fn.literal_expr(block_n)
440+
indexing.block_shape[1] //= 2
441+
desc_arg = indexing.tensor_descriptor_arg(state)
442+
443+
if desc_arg.permutation is not None:
444+
return None
445+
446+
447+
block_n_half_str = f"({block_n_str} // 2)"
448+
449+
# Lift the store value into a temporary variable for reuse
450+
acc_var = codegen.lift(store_value, prefix="acc")
451+
452+
reshape_expr = expr_from_string(
453+
"tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
454+
acc=acc_var,
455+
dim_m=expr_from_string(block_m_str),
456+
dim_half=expr_from_string(block_n_half_str),
457+
)
458+
reshape_var = codegen.lift(reshape_expr, prefix="acc")
459+
460+
permute_expr = expr_from_string(
461+
"tl.permute({acc}, [0, 2, 1])",
462+
acc=reshape_var,
463+
)
464+
permute_var = codegen.lift(permute_expr, prefix="acc")
465+
466+
acc0_name = codegen.tmpvar(prefix="acc")
467+
acc1_name = codegen.tmpvar(prefix="acc")
468+
codegen.add_statement(
469+
statement_from_string(
470+
f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
471+
acc=permute_var,
472+
)
473+
)
474+
acc0 = expr_from_string(acc0_name)
475+
acc1 = expr_from_string(acc1_name)
476+
477+
desc_name = indexing.tensor_descriptor(state)
478+
offset0 = expr_from_string(indexing.offsets[0])
479+
offset1 = expr_from_string(indexing.offsets[1])
480+
481+
# First subtile store
482+
codegen.add_statement(
483+
statement_from_string(
484+
f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
485+
off0=offset0,
486+
off1=offset1,
487+
value=acc0,
488+
)
489+
)
490+
491+
offset1_shifted = expr_from_string(
492+
"({offset} + {half})",
493+
offset=expr_from_string(indexing.offsets[1]),
494+
half=expr_from_string(block_n_half_str),
495+
)
496+
497+
# Emit second subtile store as the expression returned to the caller
498+
return expr_from_string(
499+
f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
500+
off0=offset0,
501+
off1=offset1_shifted,
502+
value=acc1,
503+
)
403504

404505
class StackIndexingStrategy:
405506
"""

helion/autotuner/config_spec.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"pid_type",
5353
"indexing",
5454
"load_eviction_policies",
55+
"epilogue_subtiling"
5556
]
5657
)
5758
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
@@ -105,6 +106,7 @@ class ConfigSpec:
105106
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
106107
)
107108
)
109+
epilogue_subtiling: bool = dataclasses.field(default=False)
108110

109111
@staticmethod
110112
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -224,6 +226,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
224226
config.setdefault(
225227
"load_eviction_policies", self.load_eviction_policies.default()
226228
)
229+
config.setdefault("epilogue_subtiling", False)
227230
# TODO(jansel): include num_ctas and max_nreg
228231

229232
for name, values in (

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
num_stages: int | None = None,
4040
pid_type: PidTypeLiteral | None = None,
4141
indexing: IndexingLiteral | None = None,
42+
epilogue_subtiling: bool | None = None,
4243
# For user-defined properties
4344
**kwargs: object,
4445
) -> None:
@@ -61,6 +62,7 @@ def __init__(
6162
num_stages: Number of stages for software pipelining.
6263
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
6364
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
65+
epilogue_subtiling: Whether to use subtiling for epilogue.
6466
**kwargs: Additional user-defined configuration parameters.
6567
"""
6668
self.config = {}
@@ -81,6 +83,7 @@ def __init__(
8183
"num_stages": num_stages,
8284
"indexing": indexing,
8385
"pid_type": pid_type,
86+
"epilogue_subtiling": epilogue_subtiling,
8487
}
8588
for key, value in core_props.items():
8689
if value is not None:
@@ -206,6 +209,10 @@ def load_eviction_policies(self) -> list[str]:
206209
def indexing(self) -> IndexingLiteral:
207210
return self.config.get("indexing", "pointer") # type: ignore[return-value]
208211

212+
@property
213+
def epilogue_subtiling(self) -> bool:
214+
return self.config.get("epilogue_subtiling", False) # type: ignore[return-value]
215+
209216

210217
def _list_to_tuple(x: object) -> object:
211218
if isinstance(x, list):

0 commit comments

Comments
 (0)