|  | 
| 15 | 15 | from .. import exc | 
| 16 | 16 | from .._compat import get_tensor_descriptor_fn_name | 
| 17 | 17 | from .ast_extension import expr_from_string | 
|  | 18 | +from .ast_extension import statement_from_string | 
| 18 | 19 | from .compile_environment import CompileEnvironment | 
| 19 | 20 | from .device_function import DeviceFunction | 
| 20 | 21 | from .host_function import HostFunction | 
| @@ -353,7 +354,6 @@ def codegen_load( | 
| 353 | 354 |             ) | 
| 354 | 355 |         assert extra_mask is None | 
| 355 | 356 |         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) | 
| 356 |  | - | 
| 357 | 357 |         # Load from tensor descriptor with permuted offsets | 
| 358 | 358 |         load_expr = expr_from_string( | 
| 359 | 359 |             f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})" | 
| @@ -383,23 +383,119 @@ def codegen_store( | 
| 383 | 383 |             ) | 
| 384 | 384 |         assert extra_mask is None | 
| 385 | 385 |         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) | 
|  | 386 | +        store_value = indexing.reshape_store(state, value) | 
|  | 387 | + | 
|  | 388 | +        config = DeviceFunction.current().config | 
|  | 389 | +        epilogue_subtiles = state.config.epilogue_subtiling | 
|  | 390 | +        if torch.cuda.get_device_capability() >= (9, 0) and (idx := state.device_function.device_store_index) < len(epilogue_subtiles): | 
|  | 391 | +            subtile_split = epilogue_subtiles[idx] | 
|  | 392 | +            state.device_function.device_store_index += 1 | 
|  | 393 | + | 
|  | 394 | +            subtile_codegen = self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value, subtile_split, config) | 
|  | 395 | +            if subtile_codegen is not None: | 
|  | 396 | +                return subtile_codegen | 
| 386 | 397 | 
 | 
| 387 | 398 |         # Apply permutation to the value being stored if needed | 
| 388 | 399 |         desc_arg = indexing.tensor_descriptor_arg(state) | 
| 389 |  | -        store_value = indexing.reshape_store(state, value) | 
| 390 | 400 | 
 | 
| 391 | 401 |         if desc_arg.permutation is not None: | 
| 392 | 402 |             # Apply permutation to the value | 
| 393 | 403 |             store_value = expr_from_string( | 
| 394 | 404 |                 f"tl.permute({{store_val}}, {desc_arg.permutation!r})", | 
| 395 | 405 |                 store_val=store_value, | 
| 396 | 406 |             ) | 
| 397 |  | - | 
|  | 407 | +         | 
| 398 | 408 |         return expr_from_string( | 
| 399 | 409 |             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", | 
| 400 | 410 |             value=store_value, | 
| 401 | 411 |         ) | 
| 402 | 412 | 
 | 
|  | 413 | +    def _codegen_epilogue_subtile_store( | 
|  | 414 | +        self, | 
|  | 415 | +        state: CodegenState, | 
|  | 416 | +        fake_tensor: torch.Tensor, | 
|  | 417 | +        indexing: BlockedSubscriptIndexing, | 
|  | 418 | +        store_value: ast.AST, | 
|  | 419 | +        subtile_split: int, | 
|  | 420 | +        config: Config, | 
|  | 421 | +    ) -> ast.AST | None: | 
|  | 422 | +       # Currently support 2D tiles without permutations | 
|  | 423 | +        if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2 or subtile_split == 0: | 
|  | 424 | +            return None | 
|  | 425 | + | 
|  | 426 | +        env = CompileEnvironment.current() | 
|  | 427 | +        block_m, block_n = indexing.block_shape | 
|  | 428 | +        try: | 
|  | 429 | +            block_n_hint = env.size_hint(block_n) | 
|  | 430 | +            block_idx = env.get_block_id(block_n) | 
|  | 431 | +            block_size = env.block_sizes[block_idx].from_config(config) | 
|  | 432 | +        except Exception: | 
|  | 433 | +            return None | 
|  | 434 | +         | 
|  | 435 | +        if block_n_hint % 2 != 0 or block_size <= 16: | 
|  | 436 | +            return None | 
|  | 437 | + | 
|  | 438 | +        device_fn = state.device_function | 
|  | 439 | +        codegen = state.codegen | 
|  | 440 | + | 
|  | 441 | +        block_m_str = device_fn.literal_expr(block_m) | 
|  | 442 | +        block_n_str = device_fn.literal_expr(block_n) | 
|  | 443 | +        indexing.block_shape[1] //= subtile_split | 
|  | 444 | + | 
|  | 445 | +        desc_arg = indexing.tensor_descriptor_arg(state) | 
|  | 446 | + | 
|  | 447 | +        # TODO: Support more epilogue subtile configs besides 2 | 
|  | 448 | +        block_n_half_str = f"({block_n_str} // {subtile_split})" | 
|  | 449 | + | 
|  | 450 | +        # Lift the store value into a temporary variable for reuse | 
|  | 451 | +        acc_var = codegen.lift(store_value, prefix="acc") | 
|  | 452 | + | 
|  | 453 | +        reshape_expr = expr_from_string( | 
|  | 454 | +            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)", | 
|  | 455 | +            acc=acc_var, | 
|  | 456 | +            dim_m=expr_from_string(block_m_str), | 
|  | 457 | +            dim_half=expr_from_string(block_n_half_str), | 
|  | 458 | +        ) | 
|  | 459 | +        reshape_var = codegen.lift(reshape_expr, prefix="acc") | 
|  | 460 | + | 
|  | 461 | +        acc0_name = codegen.tmpvar(prefix="acc") | 
|  | 462 | +        acc1_name = codegen.tmpvar(prefix="acc") | 
|  | 463 | +        codegen.add_statement( | 
|  | 464 | +            statement_from_string( | 
|  | 465 | +                f"{acc0_name}, {acc1_name} = tl.split({{acc}})", | 
|  | 466 | +                acc=reshape_var, | 
|  | 467 | +            ) | 
|  | 468 | +        ) | 
|  | 469 | +        acc0 = expr_from_string(acc0_name) | 
|  | 470 | +        acc1 = expr_from_string(acc1_name) | 
|  | 471 | + | 
|  | 472 | +        desc_name = indexing.tensor_descriptor(state) | 
|  | 473 | +        offset0 = expr_from_string(indexing.offsets[0]) | 
|  | 474 | +        offset1 = expr_from_string(indexing.offsets[1]) | 
|  | 475 | + | 
|  | 476 | +        # First subtile store | 
|  | 477 | +        codegen.add_statement( | 
|  | 478 | +            statement_from_string( | 
|  | 479 | +                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", | 
|  | 480 | +                off0=offset0, | 
|  | 481 | +                off1=offset1, | 
|  | 482 | +                value=acc0, | 
|  | 483 | +            ) | 
|  | 484 | +        ) | 
|  | 485 | + | 
|  | 486 | +        offset1_shifted = expr_from_string( | 
|  | 487 | +            "({offset} + {half})", | 
|  | 488 | +            offset=expr_from_string(indexing.offsets[1]), | 
|  | 489 | +            half=expr_from_string(block_n_half_str), | 
|  | 490 | +        ) | 
|  | 491 | + | 
|  | 492 | +        # Emit second subtile store as the expression returned to the caller | 
|  | 493 | +        return expr_from_string( | 
|  | 494 | +            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", | 
|  | 495 | +            off0=offset0, | 
|  | 496 | +            off1=offset1_shifted, | 
|  | 497 | +            value=acc1, | 
|  | 498 | +        ) | 
| 403 | 499 | 
 | 
| 404 | 500 | class StackIndexingStrategy: | 
| 405 | 501 |     """ | 
|  | 
0 commit comments