# Hints
FlagTree introduces a non-invasive performance hints injection mechanism that enables hardware-aware optimizations while maintaining full compatibility with native Triton code. The mechanism is simple: programmers add inline comments (`#@hint: <hint_name>`) to the corresponding Triton operations (e.g., `tl.load`) to provide hardware-aware optimization hints. These hints are encoded as MLIR attributes during compilation, enabling the mid-end and backend to apply hardware-aware optimizations and multi-platform dynamic adaptation based on an elastic verification strategy.
This mechanism has the following characteristics:

- Native Compatibility: Hints are optional; kernels remain valid Triton and run correctly with the original Triton compiler.
- Low Learning Overhead: Hints are added via lightweight comments (`flagtree_hints`) without changing core Triton syntax.
- Stronger Compiler Extensibility: New optimizations can be introduced by evolving hint schemas and MLIR attributes, avoiding language-level op/syntax extensions.
- Stronger Performance Capability: Hardware-aware hints unlock additional compiler optimizations to better utilize hardware features.
Please refer to the README in the corresponding branch for installation instructions.
Hints are added as inline comments in Triton code using the `#@hint:` prefix. Place the hint comment on the same line as the Triton operation:

```python
import triton
import triton.language as tl


@triton.jit
def kernel(x_ptr, y_ptr, N):
    pid = tl.program_id(0)
    x = tl.load(x_ptr + pid)  #@hint: shared_memory
    y = x + 1
    tl.store(y_ptr + pid, y)
```

The format is `#@hint: <hint_name>`, where `<hint_name>` is a specific hint supported by your backend. See the Appendix for available hints per backend.
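Because the hint lives entirely in a comment, the kernel is launched like any other Triton kernel. The snippet below is a minimal illustrative launch for the kernel above (the grid and tensor sizes are assumptions for illustration, not part of FlagTree's documentation):

```python
import torch

N = 1024
x = torch.arange(N, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)

# One program instance per element, matching the tl.program_id(0) indexing above.
kernel[(N,)](x, y, N)
```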
TBD
- Lowering and TLE Integration: Align the hints lowering pipeline with TLE, so that hint-driven transformations can seamlessly compose with existing TLE passes and codegen.
- Unified Registration and Lowering Interface: Introduce a backend-agnostic registry for hint definitions, verification, and lowering hooks. This provides a single place to manage per-backend hint support, schemas, and lowering behavior (a rough sketch follows this list).
- Hint AST for Richer Expressions: Design a structured hint AST that allows comments to reference variables (e.g., compile-time constants or derived values) instead of only raw strings, improving expressiveness while keeping the source code Triton-compatible.
- Debuggability: Add optional debug traces and diagnostics for hint parsing, propagation, and lowering, controlled via environment variables to keep the default path clean.
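
As a rough illustration of the registry idea, the sketch below pairs each hint name with per-backend verification and lowering hooks. All names and signatures here are hypothetical and do not describe an existing FlagTree API:

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Tuple

# Hypothetical hook types: "verify" decides whether a hint may be applied to an op
# on a given backend; "lower" performs the rewrite (e.g., attaching an MLIR attribute).
VerifyFn = Callable[[object], bool]
LowerFn = Callable[[object], None]


@dataclass
class HintSpec:
    name: str  # hint name as written after "#@hint:"
    backends: Dict[str, Tuple[VerifyFn, LowerFn]] = field(default_factory=dict)


class HintRegistry:
    def __init__(self) -> None:
        self._specs: Dict[str, HintSpec] = {}

    def register(self, name: str, backend: str, verify: VerifyFn, lower: LowerFn) -> None:
        spec = self._specs.setdefault(name, HintSpec(name))
        spec.backends[backend] = (verify, lower)

    def apply(self, name: str, backend: str, op: object) -> bool:
        """Apply a hint to `op` if this backend supports it and verification passes."""
        spec = self._specs.get(name)
        if spec is None or backend not in spec.backends:
            return False  # elastic behavior: unsupported hints are skipped, not rejected
        verify, lower = spec.backends[backend]
        if not verify(op):
            return False
        lower(op)
        return True
```

Keeping verification separate from lowering mirrors the elastic verification strategy described above: a hint that a backend does not support, or that fails verification, is simply ignored rather than treated as an error.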
| Hint Name | TL Operation | Description | Branch |
|---|---|---|---|
| `shared_memory` | `tl.load` | Converts a global memory load operation to an asynchronous copy to shared memory, then loads from shared memory. The load must be at least 4 bytes and convertible to an async load. | triton_v3.5.x |
| Hint Name | TL Operation | Description | Branch |
|---|---|---|---|
| `dot_pad_only_k` | `tl.load` | Optimizes matrix multiplication performance by padding only the K dimension for dot operations. Equivalent to `tl.compile_hint(tensor, "dot_pad_only_k")` in triton-ascend. | triton_v3.2.x_ascend_hints |
| `bind_sub_block` | `for` loop | Optimizes parallel execution by binding loop iterations to sub-blocks. Equivalent to `tl.parallel(..., bind_sub_block=True)` in triton-ascend. | triton_v3.2.x_ascend_hints |
| `multibuffer` | `tl.load` | Enables multi-buffering optimization to overlap data transfer and computation (fixed to 2 buffer copies). Equivalent to `tl.multibuffer(tensor, 2)` or `tl.compile_hint(tensor, "multi_buffer", 2)` in triton-ascend. | triton_v3.2.x_ascend_hints |
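
As a rough sketch of how these hints could be combined in an Ascend-targeted tiled matmul kernel (this example is not taken from the triton_v3.2.x_ascend_hints branch; the block sizes, pointer arithmetic, and the omission of boundary masks are simplifying assumptions):

```python
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):  #@hint: bind_sub_block
        offs_k = k + tl.arange(0, BLOCK_K)
        # Loads feeding tl.dot: pad only the K dimension / multi-buffer the B tile.
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  #@hint: dot_pad_only_k
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  #@hint: multibuffer
        acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```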
| Hint Name | TL Operation | Description | Branch |
|---|---|---|---|
| `dma` | `tl.load` | Enables asynchronous DMA transfers for improved performance. Lowers `memref.copy` operations to `memref.dma_start` and `memref.dma_wait` operations. Requires stride-1 memory access patterns. | triton_v3.3.x |
| `shared_memory` | `tl.load` | Loads data into shared memory for faster access. Allocates shared memory (memory space 8) sized at 4x the original tensor shape (for AIPUcore parallelism) and copies data from global memory. | triton_v3.3.x |
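
A minimal, hypothetical sketch of the `dma` hint on a contiguous block copy; it is illustrative only, not taken from the triton_v3.3.x branch, and assumes the total element count is a multiple of BLOCK so that no boundary mask is needed:

```python
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)  # consecutive offsets -> stride-1 access
    x = tl.load(src_ptr + offs)  #@hint: dma
    tl.store(dst_ptr + offs, x)
```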