|
15 | 15 | from .. import exc |
16 | 16 | from .._compat import get_tensor_descriptor_fn_name |
17 | 17 | from .ast_extension import expr_from_string |
| 18 | +from .ast_extension import statement_from_string |
18 | 19 | from .compile_environment import CompileEnvironment |
19 | 20 | from .device_function import DeviceFunction |
20 | 21 | from .host_function import HostFunction |
21 | 22 | from .tile_strategy import DeviceLoopState |
| 23 | +from .utils import _allow_epilogue_subtiling |
22 | 24 | from .utils import compute_slice_size |
23 | 25 | from .variable_origin import BlockSizeOrigin |
24 | 26 |
|
@@ -352,7 +354,6 @@ def codegen_load( |
352 | 354 | ) |
353 | 355 | assert extra_mask is None |
354 | 356 | indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) |
355 | | - |
356 | 357 | # Load from tensor descriptor with permuted offsets |
357 | 358 | load_expr = expr_from_string( |
358 | 359 | f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})" |
@@ -382,23 +383,188 @@ def codegen_store( |
382 | 383 | ) |
383 | 384 | assert extra_mask is None |
384 | 385 | indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) |
| 386 | + store_value = indexing.reshape_store(state, value) |
385 | 387 |
|
| 388 | + config = DeviceFunction.current().config |
| 389 | + epilogue_subtiles = state.config.epilogue_subtiling |
386 | 390 | # Apply permutation to the value being stored if needed |
387 | 391 | desc_arg = indexing.tensor_descriptor_arg(state) |
388 | | - store_value = indexing.reshape_store(state, value) |
389 | 392 |
|
390 | 393 | if desc_arg.permutation is not None: |
391 | 394 | # Apply permutation to the value |
392 | 395 | store_value = expr_from_string( |
393 | 396 | f"tl.permute({{store_val}}, {desc_arg.permutation!r})", |
394 | 397 | store_val=store_value, |
395 | 398 | ) |
| 399 | + |
| 400 | + if _allow_epilogue_subtiling() and ( |
| 401 | + idx := state.device_function.device_store_index |
| 402 | + ) <= len(epilogue_subtiles): |
| 403 | + subtile_split = epilogue_subtiles[idx - 1] |
| 404 | + |
| 405 | + subtile_codegen = self._codegen_epilogue_subtile_store( |
| 406 | + state, |
| 407 | + fake_tensor, |
| 408 | + indexing, |
| 409 | + store_value, |
| 410 | + subtile_split, |
| 411 | + config, |
| 412 | + ) |
| 413 | + if subtile_codegen is not None: |
| 414 | + return subtile_codegen |
| 415 | + |
| 416 | + if "pointwise_in" in state.fx_node.meta: |
| 417 | + # We still need to codegen pointwise if subtile_codegen is None |
| 418 | + store_value = self._apply_pointwise_to_subtile( |
| 419 | + state, state.fx_node.meta["pointwise_in"], store_value |
| 420 | + ) |
396 | 421 |
|
397 | 422 | return expr_from_string( |
398 | 423 | f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", |
399 | 424 | value=store_value, |
400 | 425 | ) |
401 | 426 |
|
def _apply_pointwise_to_subtile(
    self, state: CodegenState, pointwise_node: torch.fx.Node, subtile_value: ast.AST
) -> ast.AST:
    """Re-emit a fused pointwise op so that it consumes ``subtile_value``.

    The lowering attached to ``pointwise_node`` (``meta["lowering"]``) must be
    a ``PointwiseLowering`` whose buffer wraps an inductor ``ir.Pointwise``;
    both conditions are asserted.

    Args:
        state: The codegen state; its ``codegen`` is used to lift the
            subtile into a temporary and to host the inductor handlers.
        pointwise_node: The FX node whose pointwise lowering is replayed.
        subtile_value: AST of the value fed to the pointwise operation.

    Returns:
        AST referring to the result produced by the pointwise body.
    """
    from torch._inductor import ir

    from .inductor_lowering import PointwiseLowering
    from .inductor_lowering import _unpack_opsvalue
    from .inductor_lowering import install_inductor_kernel_handlers

    pointwise_lowering = pointwise_node.meta["lowering"]
    assert isinstance(pointwise_lowering, PointwiseLowering)

    pointwise_buffer = pointwise_lowering.buffer
    assert isinstance(pointwise_buffer.data, ir.Pointwise)

    # Bind the subtile to a named temporary so the pointwise body can
    # reference it by name.
    generator = state.codegen
    subtile_name = generator.lift(subtile_value, prefix="subtile")

    # Only the first declared input of the lowering is rebound to the subtile.
    input_bindings = {pointwise_lowering.input_names[0]: subtile_name}

    with install_inductor_kernel_handlers(generator, input_bindings):
        # One symbolic index (i0, i1, ...) per range of the pointwise buffer.
        rank = len(pointwise_buffer.data.ranges)
        index_symbols = [sympy.Symbol(f"i{dim}") for dim in range(rank)]
        pointwise_result = _unpack_opsvalue(
            pointwise_buffer.data.inner_fn(index_symbols)
        )
        return expr_from_string(pointwise_result)
| 466 | + |
def _codegen_epilogue_subtile_store(
    self,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    indexing: BlockedSubscriptIndexing,
    store_value: ast.AST,
    subtile_split: int,
    config: Config,
) -> ast.AST | None:
    """Emit a tensor-descriptor store as two half-width subtile stores.

    The value is reshaped to ``[M, 2, N // 2]``, permuted to
    ``[M, N // 2, 2]`` and split with ``tl.split`` so that the first result
    holds columns ``[0, N // 2)`` and the second holds the remaining columns.
    If ``state.fx_node.meta`` carries a ``"pointwise_in"`` node, that
    pointwise op is applied to each half separately. The first store is
    added to the codegen as a statement; the second is returned.

    Args:
        state: The codegen state for the store node.
        fake_tensor: The tensor being stored to (not read by this method).
        indexing: Blocked indexing for the store. On success its
            ``block_shape[1]`` is divided in place by ``subtile_split``.
        store_value: AST of the full-tile value to store.
        subtile_split: Requested number of subtiles; only ``2`` is supported.
        config: Config used to resolve the concrete block size.

    Returns:
        AST of the second subtile store, or ``None`` when subtiling does not
        apply and the caller should emit the regular single store. No
        statements are emitted and ``indexing`` is left unmodified when
        ``None`` is returned.
    """
    # Structural guards come first so unsupported tiles fall back cleanly
    # instead of failing while unpacking ``block_shape``. The reshape below
    # hard-codes a factor of two, so any other split would yield a shape
    # whose element count does not match the tile.
    if (
        subtile_split != 2
        or len(indexing.block_shape) != 2
        or len(indexing.offsets) != 2
    ):
        return None

    # The stores below use the unpermuted offsets, so permuted descriptors
    # are left to the regular store path.
    # NOTE(review): assumes tensor_descriptor_arg() is safe to call again
    # here, as codegen_store already calls it before this method.
    if indexing.tensor_descriptor_arg(state).permutation is not None:
        return None

    if "pointwise_in" in state.fx_node.meta:
        fused_pointwise_node = state.fx_node.meta["pointwise_in"]
        assert fused_pointwise_node == state.fx_node.args[2]
    else:
        fused_pointwise_node = None

    env = CompileEnvironment.current()
    block_m, block_n = indexing.block_shape
    block_n_hint = env.size_hint(block_n)
    block_idx = env.get_block_id(block_n)
    block_size = env.block_sizes[block_idx].from_config(config)

    # Require an even split dimension and a block size above 16.
    if block_n_hint % 2 != 0 or block_size <= 16:
        return None

    device_fn = state.device_function
    codegen = state.codegen

    # Capture the full-tile dimension strings before shrinking the block.
    block_m_str = device_fn.literal_expr(block_m)
    block_n_str = device_fn.literal_expr(block_n)
    # Presumably shrinks the block so the descriptor requested below is
    # built for the subtile shape -- TODO confirm against tensor_descriptor().
    indexing.block_shape[1] //= subtile_split

    # TODO(PaulZhang12): Support more epilogue subtile configs besides 2
    block_n_half_str = f"({block_n_str} // {subtile_split})"

    # Lift the store value into a temporary variable for reuse
    acc_var = codegen.lift(store_value, prefix="acc")

    reshape_expr = expr_from_string(
        "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",
        acc=acc_var,
        dim_m=expr_from_string(block_m_str),
        dim_half=expr_from_string(block_n_half_str),
    )
    reshape_var = codegen.lift(reshape_expr, prefix="acc")

    acc0_name = codegen.tmpvar(prefix="acc")
    acc1_name = codegen.tmpvar(prefix="acc")
    codegen.add_statement(
        statement_from_string(
            f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
            acc=reshape_var,
        )
    )

    # Apply the fused pointwise operation per-subtile if there is one
    if fused_pointwise_node is not None:
        acc0 = self._apply_pointwise_to_subtile(
            state, fused_pointwise_node, expr_from_string(acc0_name)
        )
        acc1 = self._apply_pointwise_to_subtile(
            state, fused_pointwise_node, expr_from_string(acc1_name)
        )
    else:
        acc0 = expr_from_string(acc0_name)
        acc1 = expr_from_string(acc1_name)

    desc_name = indexing.tensor_descriptor(state)
    offset0 = expr_from_string(indexing.offsets[0])
    offset1 = expr_from_string(indexing.offsets[1])

    # First subtile store
    codegen.add_statement(
        statement_from_string(
            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
            off0=offset0,
            off1=offset1,
            value=acc0,
        )
    )

    # The second half starts half a block further along the second dim.
    offset1_shifted = expr_from_string(
        "({offset} + {half})",
        offset=expr_from_string(indexing.offsets[1]),
        half=expr_from_string(block_n_half_str),
    )

    # Emit second subtile store as the expression returned to the caller
    return expr_from_string(
        f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
        off0=offset0,
        off1=offset1_shifted,
        value=acc1,
    )
| 567 | + |
402 | 568 |
|
403 | 569 | class StackIndexingStrategy: |
404 | 570 | """ |
|
0 commit comments