| 
15 | 15 | from .. import exc  | 
16 | 16 | from .._compat import get_tensor_descriptor_fn_name  | 
17 | 17 | from .ast_extension import expr_from_string  | 
 | 18 | +from .ast_extension import statement_from_string  | 
18 | 19 | from .compile_environment import CompileEnvironment  | 
19 | 20 | from .device_function import DeviceFunction  | 
20 | 21 | from .host_function import HostFunction  | 
@@ -353,7 +354,6 @@ def codegen_load(  | 
353 | 354 |             )  | 
354 | 355 |         assert extra_mask is None  | 
355 | 356 |         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)  | 
356 |  | - | 
357 | 357 |         # Load from tensor descriptor with permuted offsets  | 
358 | 358 |         load_expr = expr_from_string(  | 
359 | 359 |             f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"  | 
@@ -383,23 +383,119 @@ def codegen_store(  | 
383 | 383 |             )  | 
384 | 384 |         assert extra_mask is None  | 
385 | 385 |         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)  | 
 | 386 | +        store_value = indexing.reshape_store(state, value)  | 
 | 387 | + | 
 | 388 | +        config = DeviceFunction.current().config  | 
 | 389 | +        epilogue_subtiles = state.config.epilogue_subtiling  | 
 | 390 | +        if torch.cuda.get_device_capability() >= (9, 0) and (idx := state.device_function.device_store_index) < len(epilogue_subtiles):  | 
 | 391 | +            subtile_split = epilogue_subtiles[idx]  | 
 | 392 | +            state.device_function.device_store_index += 1  | 
 | 393 | + | 
 | 394 | +            subtile_codegen = self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value, subtile_split, config)  | 
 | 395 | +            if subtile_codegen is not None:  | 
 | 396 | +                return subtile_codegen  | 
386 | 397 | 
 
  | 
387 | 398 |         # Apply permutation to the value being stored if needed  | 
388 | 399 |         desc_arg = indexing.tensor_descriptor_arg(state)  | 
389 |  | -        store_value = indexing.reshape_store(state, value)  | 
390 | 400 | 
 
  | 
391 | 401 |         if desc_arg.permutation is not None:  | 
392 | 402 |             # Apply permutation to the value  | 
393 | 403 |             store_value = expr_from_string(  | 
394 | 404 |                 f"tl.permute({{store_val}}, {desc_arg.permutation!r})",  | 
395 | 405 |                 store_val=store_value,  | 
396 | 406 |             )  | 
397 |  | - | 
 | 407 | +          | 
398 | 408 |         return expr_from_string(  | 
399 | 409 |             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",  | 
400 | 410 |             value=store_value,  | 
401 | 411 |         )  | 
402 | 412 | 
 
  | 
 | 413 | +    def _codegen_epilogue_subtile_store(  | 
 | 414 | +        self,  | 
 | 415 | +        state: CodegenState,  | 
 | 416 | +        fake_tensor: torch.Tensor,  | 
 | 417 | +        indexing: BlockedSubscriptIndexing,  | 
 | 418 | +        store_value: ast.AST,  | 
 | 419 | +        subtile_split: int,  | 
 | 420 | +        config: Config,  | 
 | 421 | +    ) -> ast.AST | None:  | 
 | 422 | +       # Currently support 2D tiles without permutations  | 
 | 423 | +        if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2 or subtile_split == 0:  | 
 | 424 | +            return None  | 
 | 425 | + | 
 | 426 | +        env = CompileEnvironment.current()  | 
 | 427 | +        block_m, block_n = indexing.block_shape  | 
 | 428 | +        try:  | 
 | 429 | +            block_n_hint = env.size_hint(block_n)  | 
 | 430 | +            block_idx = env.get_block_id(block_n)  | 
 | 431 | +            block_size = env.block_sizes[block_idx].from_config(config)  | 
 | 432 | +        except Exception:  | 
 | 433 | +            return None  | 
 | 434 | +          | 
 | 435 | +        if block_n_hint % 2 != 0 or block_size <= 16:  | 
 | 436 | +            return None  | 
 | 437 | + | 
 | 438 | +        device_fn = state.device_function  | 
 | 439 | +        codegen = state.codegen  | 
 | 440 | + | 
 | 441 | +        block_m_str = device_fn.literal_expr(block_m)  | 
 | 442 | +        block_n_str = device_fn.literal_expr(block_n)  | 
 | 443 | +        indexing.block_shape[1] //= subtile_split  | 
 | 444 | + | 
 | 445 | +        desc_arg = indexing.tensor_descriptor_arg(state)  | 
 | 446 | + | 
 | 447 | +        # TODO: Support more epilogue subtile configs besides 2  | 
 | 448 | +        block_n_half_str = f"({block_n_str} // {subtile_split})"  | 
 | 449 | + | 
 | 450 | +        # Lift the store value into a temporary variable for reuse  | 
 | 451 | +        acc_var = codegen.lift(store_value, prefix="acc")  | 
 | 452 | + | 
 | 453 | +        reshape_expr = expr_from_string(  | 
 | 454 | +            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",  | 
 | 455 | +            acc=acc_var,  | 
 | 456 | +            dim_m=expr_from_string(block_m_str),  | 
 | 457 | +            dim_half=expr_from_string(block_n_half_str),  | 
 | 458 | +        )  | 
 | 459 | +        reshape_var = codegen.lift(reshape_expr, prefix="acc")  | 
 | 460 | + | 
 | 461 | +        acc0_name = codegen.tmpvar(prefix="acc")  | 
 | 462 | +        acc1_name = codegen.tmpvar(prefix="acc")  | 
 | 463 | +        codegen.add_statement(  | 
 | 464 | +            statement_from_string(  | 
 | 465 | +                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",  | 
 | 466 | +                acc=reshape_var,  | 
 | 467 | +            )  | 
 | 468 | +        )  | 
 | 469 | +        acc0 = expr_from_string(acc0_name)  | 
 | 470 | +        acc1 = expr_from_string(acc1_name)  | 
 | 471 | + | 
 | 472 | +        desc_name = indexing.tensor_descriptor(state)  | 
 | 473 | +        offset0 = expr_from_string(indexing.offsets[0])  | 
 | 474 | +        offset1 = expr_from_string(indexing.offsets[1])  | 
 | 475 | + | 
 | 476 | +        # First subtile store  | 
 | 477 | +        codegen.add_statement(  | 
 | 478 | +            statement_from_string(  | 
 | 479 | +                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",  | 
 | 480 | +                off0=offset0,  | 
 | 481 | +                off1=offset1,  | 
 | 482 | +                value=acc0,  | 
 | 483 | +            )  | 
 | 484 | +        )  | 
 | 485 | + | 
 | 486 | +        offset1_shifted = expr_from_string(  | 
 | 487 | +            "({offset} + {half})",  | 
 | 488 | +            offset=expr_from_string(indexing.offsets[1]),  | 
 | 489 | +            half=expr_from_string(block_n_half_str),  | 
 | 490 | +        )  | 
 | 491 | + | 
 | 492 | +        # Emit second subtile store as the expression returned to the caller  | 
 | 493 | +        return expr_from_string(  | 
 | 494 | +            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",  | 
 | 495 | +            off0=offset0,  | 
 | 496 | +            off1=offset1_shifted,  | 
 | 497 | +            value=acc1,  | 
 | 498 | +        )  | 
403 | 499 | 
 
  | 
404 | 500 | class StackIndexingStrategy:  | 
405 | 501 |     """  | 
 | 
0 commit comments