|
15 | 15 | from .. import exc |
16 | 16 | from .._compat import get_tensor_descriptor_fn_name |
17 | 17 | from .ast_extension import expr_from_string |
| 18 | +from .ast_extension import statement_from_string |
18 | 19 | from .compile_environment import CompileEnvironment |
19 | 20 | from .device_function import DeviceFunction |
20 | 21 | from .host_function import HostFunction |
@@ -384,22 +385,116 @@ def codegen_store( |
384 | 385 | assert extra_mask is None |
385 | 386 | indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) |
386 | 387 |
|
| 388 | + config = DeviceFunction.current().config |
| 389 | + store_value = indexing.reshape_store(state, value) |
| 390 | + if config.epilogue_subtiling: |
| 391 | + return self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value) |
| 392 | + |
387 | 393 | # Apply permutation to the value being stored if needed |
388 | 394 | desc_arg = indexing.tensor_descriptor_arg(state) |
389 | | - store_value = indexing.reshape_store(state, value) |
390 | 395 |
|
391 | 396 | if desc_arg.permutation is not None: |
392 | 397 | # Apply permutation to the value |
393 | 398 | store_value = expr_from_string( |
394 | 399 | f"tl.permute({{store_val}}, {desc_arg.permutation!r})", |
395 | 400 | store_val=store_value, |
396 | 401 | ) |
397 | | - |
| 402 | + |
398 | 403 | return expr_from_string( |
399 | 404 | f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", |
400 | 405 | value=store_value, |
401 | 406 | ) |
402 | 407 |
|
def _codegen_epilogue_subtile_store(
    self,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    indexing: BlockedSubscriptIndexing,
    store_value: ast.AST,
) -> ast.AST | None:
    """Emit a tensor-descriptor store as two half-width subtile stores.

    The value is reshaped to ``[M, 2, N // 2]``, permuted to
    ``[M, N // 2, 2]`` and split with ``tl.split`` so that the first
    result holds the first ``N // 2`` columns and the second result holds
    the remaining ones.  The first half is stored through a statement
    added to ``state.codegen``; the store of the second half (at the
    column offset shifted by ``N // 2``) is returned as an expression.

    Args:
        state: Code generation state; supplies ``device_function`` and
            ``codegen``.
        fake_tensor: Tensor being stored to.  Not read by this helper.
        indexing: Blocked indexing for the store.  On success its
            ``block_shape[1]`` is left halved; on every ``None`` return it
            is left (or restored to) its original value.
        store_value: Expression for the already-reshaped value to store.

    Returns:
        The expression for the second subtile store, or ``None`` when the
        store is not eligible: not exactly two block dimensions/offsets,
        no size hint available for the second block dimension, an odd
        hint, or a tensor descriptor that carries a permutation.

    NOTE(review): callers must handle the ``None`` result themselves; no
    store of any kind is emitted in that case.
    """
    # Currently support 2D tiles without permutations
    if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
        return None

    env = CompileEnvironment.current()
    block_m, block_n = indexing.block_shape
    try:
        block_n_hint = env.size_hint(block_n)
    except Exception:
        # NOTE(review): kept broad because the exception types raised by
        # size_hint are not visible here -- narrow once confirmed.
        return None

    # The tile is cut into two equal halves along the second dimension.
    if block_n_hint % 2 != 0:
        return None

    device_fn = state.device_function
    codegen = state.codegen

    # Render both dimensions before block_shape is modified below.
    block_m_str = device_fn.literal_expr(block_m)
    block_n_str = device_fn.literal_expr(block_n)

    # Presumably the descriptor argument must be requested with the halved
    # block shape so it describes a subtile -- TODO confirm.
    indexing.block_shape[1] //= 2
    desc_arg = indexing.tensor_descriptor_arg(state)

    if desc_arg.permutation is not None:
        # Undo the halving so the caller's indexing object is not left in
        # a modified state when no subtiled store is generated.
        indexing.block_shape[1] = block_n
        return None

    block_n_half_str = f"({block_n_str} // 2)"

    # Lift the store value into a temporary variable for reuse
    acc_var = codegen.lift(store_value, prefix="acc")

    reshape_expr = expr_from_string(
        "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
        acc=acc_var,
        dim_m=expr_from_string(block_m_str),
        dim_half=expr_from_string(block_n_half_str),
    )
    reshape_var = codegen.lift(reshape_expr, prefix="acc")

    # Move the size-2 axis last, which is the axis tl.split divides on.
    permute_expr = expr_from_string(
        "tl.permute({acc}, [0, 2, 1])",
        acc=reshape_var,
    )
    permute_var = codegen.lift(permute_expr, prefix="acc")

    acc0_name = codegen.tmpvar(prefix="acc")
    acc1_name = codegen.tmpvar(prefix="acc")
    codegen.add_statement(
        statement_from_string(
            f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
            acc=permute_var,
        )
    )
    acc0 = expr_from_string(acc0_name)
    acc1 = expr_from_string(acc1_name)

    desc_name = indexing.tensor_descriptor(state)
    row_offset, col_offset = indexing.offsets
    store_template = f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})"

    # First subtile store.  Offset nodes are built fresh for every use so
    # that no AST node object is shared between two generated trees.
    codegen.add_statement(
        statement_from_string(
            store_template,
            off0=expr_from_string(row_offset),
            off1=expr_from_string(col_offset),
            value=acc0,
        )
    )

    offset1_shifted = expr_from_string(
        "({offset} + {half})",
        offset=expr_from_string(col_offset),
        half=expr_from_string(block_n_half_str),
    )

    # Emit second subtile store as the expression returned to the caller
    return expr_from_string(
        store_template,
        off0=expr_from_string(row_offset),
        off1=offset1_shifted,
        value=acc1,
    )
403 | 498 |
|
404 | 499 | class StackIndexingStrategy: |
405 | 500 | """ |
|
0 commit comments