TLE
Triton is an operator programming language in the form of a Python DSL. Built on the block programming concept, it shields hardware details such as the storage hierarchy, layout, pipelining, and synchronization, and delivers high-performance operators through compiler optimization. These advantages have attracted a large number of developers, forming a large community and ecosystem.
However, in recent years the further development of Triton has run into difficulties. On the one hand, adaptation to DSAs and new GPU architectures has been relatively slow. On the other hand, compared with emerging languages such as TileLang, Triton lacks abstractions for fine-grained control of the storage hierarchy and parallel granularity, leaving it at a performance disadvantage. In response to these challenges, we propose TLE (Triton Language Extensions), which extends Triton at three levels to meet the needs of different kinds of users of operator programming languages.
TLE sits in the middle layer of the AI ecosystem: the upper layer connects to AI frameworks through graph compilers and operator libraries, and the lower layer interfaces with the runtimes of various hardware.
TLE-Lite is a lightweight extension of Triton. All of its features are compatible with the various hardware backends, and only minor modifications to existing Triton kernels are needed to achieve significant performance improvements. It is mainly aimed at algorithm engineers and rapid performance-tuning scenarios.
TLE-Struct abstracts backends by clustering them according to hardware architecture (e.g., GPGPU, DSA) and provides extensions for each cluster, meeting the need for further performance optimization. Developers need some understanding of the target hardware's characteristics and optimization techniques.
TLE-Raw provides the most direct control over the hardware and can use the hardware vendor's native programming language to achieve ultimate performance. It requires in-depth knowledge of the target hardware and is mainly aimed at performance-optimization experts.
Among them, TLE-Lite and TLE-Struct are ultimately lowered to LLVM IR via FLIR, while TLE-Raw is lowered to LLVM IR through the compilation pipeline of the corresponding language (such as the vendor's proprietary compiler). Finally, the results are linked together to produce a complete kernel for the runtime to load and execute.
Design Philosophy: Write once, run anywhere.
Core Concept: By introducing high-level semantic hints rather than mandatory constraints, it guides the compiler to perform heuristic optimizations. It emphasizes backward compatibility, allowing developers to achieve cross-platform performance improvements with minimal code intrusion and without disrupting the original Triton programming paradigm.
An extension of tl.load that supports asynchronous hints:
x = tle.load(..., is_async=True)
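For illustration, a minimal sketch of how the hint drops into an otherwise unchanged Triton kernel. The kernel name, pointers, and block size are hypothetical, and the import path of tle is assumed to sit alongside triton.language:

import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Only this line differs from plain Triton: the is_async hint lets the
    # compiler overlap the load with later work where the backend supports it;
    # being a hint rather than a constraint, it does not change the kernel's semantics.
    x = tle.load(src_ptr + offsets, mask, other=0.0, is_async=True)
    tl.store(dst_ptr + offsets, x, mask)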
Design Philosophy: Architecture Awareness, Fine-Tuning.
Core Concept: Based on the topological characteristics of the hardware, backends are divided into clusters such as GPGPU and DSA, each exposing a general-purpose hierarchical parallel and storage structure. Developers can explicitly define the structured mapping between computation and data (such as warp-group control and pipeline orchestration), decoupling the algorithm logic from the physical implementation of specific hardware at the abstraction level.
Allocate memory:

a_smem = tle.gpu.alloc([XBLOCK, YBLOCK], dtype=tl.float32,
                       layout=None, scope=tle.gpu.storage_kind.smem)

Copy between memory spaces:

tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK])

Load a tensor from local memory:

aval = tle.gpu.local_load(a_smem)
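For illustration, a minimal sketch of how these primitives might be combined to stage one tile through shared memory before computing on it. Only tle.gpu.alloc, tle.gpu.copy, and tle.gpu.local_load are taken from the snippets above; the kernel name, decorator, strides, pointer arithmetic, and the assumption that the grid evenly covers the tensor are hypothetical:

@triton.jit
def scale_tile_kernel(a_ptr, out_ptr, scale,
                      xstride_a, ystride_a,
                      XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    xoffs = tl.arange(0, XBLOCK)
    yoffs = pid * YBLOCK + tl.arange(0, YBLOCK)
    a_ptrs = a_ptr + xstride_a * xoffs[:, None]
    # Stage one [XBLOCK, YBLOCK] tile in shared memory, then read it back
    # into registers before operating on it.
    a_smem = tle.gpu.alloc([XBLOCK, YBLOCK], dtype=tl.float32,
                           layout=None, scope=tle.gpu.storage_kind.smem)
    tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK])
    aval = tle.gpu.local_load(a_smem)
    out_ptrs = out_ptr + xstride_a * xoffs[:, None] + ystride_a * yoffs[None, :]
    tl.store(out_ptrs, aval * scale)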
Design Philosophy: Native pass-through, ultimate control.
Core Concept: Break the abstraction boundary of the DSL and support inlining vendor-native code. It allows target instructions to be generated directly through the vendor's proprietary compilation pipeline, bypassing the intermediate-layer conversion overhead of the general-purpose compiler, and gives expert users absolute control over instruction scheduling, register allocation, and low-level synchronization primitives.
from typing import Annotated
from mlir import ir
from mlir.dialects import arith, nvvm, tensor
import triton.language as tl
from triton.experimental.flagtree.edsl import dialect
import triton.experimental.flagtree.language as fl
# 1. Dialect declaration
@tle.raw.language(name="mlir")
# 2. Hardware constraints
@tle.hardware_constraint(threads_dim=1, sync_scope="block")
# 3. Function implementation
def vector_add_tile(
    x: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
    y: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
    output: Annotated[ir.RankedTensorType, "tensor<1024xf32>"]
):
    # Write the low-level operations directly with the MLIR Python bindings
    tidx = nvvm.ThreadIdXOp(ir.IntegerType.get_signless(32)).res
    bidx = nvvm.BlockIdXOp(ir.IntegerType.get_signless(32)).res
    bdimx = nvvm.BlockDimXOp(ir.IntegerType.get_signless(32)).res
    idx = arith.addi(arith.muli(bidx, bdimx), tidx)
    idx = arith.index_cast(ir.IndexType.get(), idx)
    xval = tensor.extract(x, [idx])
    yval = tensor.extract(y, [idx])
    result = arith.addf(xval, yval)
    tensor.insert(result, output, [idx])
@tle.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Tile-language kernel body
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = tl.zeros_like(x)
    # 4. Function call
    tle.call(
        vector_add_tile,
        args=[x, y, output],
        hardware={
            "threads": (BLOCK_SIZE,),  # must satisfy threads_dim=1
        },
        layout={
            x: {"space": "shared", "order": [0]},  # shared memory, 1-D layout (used to optimize the hand-off)
            y: {"space": "shared", "order": [0]},
            output: {"space": "shared", "order": [0]}
        }
    )
    tl.store(output_ptr + offsets, output, mask=mask)

Currently, optimization and testing have been carried out on the SparseMLA operator in DSA on RTX 5060Ti, H20, and H800.
The TileLang version used is v0.1.7.post1.
@triton.jit
def triton_sparse_mla_fwd(
    q,
    kv,
    indices,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb, stride_qh, stride_qm, stride_qd,
    stride_kvb, stride_kvg, stride_kvn, stride_kvd,
    stride_tb, stride_tg, stride_tm, stride_tt,  # topk, corresponds to indices
    stride_ob, stride_oh, stride_om, stride_od,
    stride_lb, stride_lh, stride_lm,
    B: tl.constexpr,
    SQ: tl.constexpr,  # seqlen
    SKV: tl.constexpr,
    K: tl.constexpr,  # topk
    D: tl.constexpr,  # QKV dim
    TD: tl.constexpr,  # tail dim
    DP: tl.constexpr,
    TDP: tl.constexpr,
    H: tl.constexpr,  # q_head_dim
    G: tl.constexpr,  # group_size
    VG: tl.constexpr,  # H/G KV groups
    BK: tl.constexpr,
    BH: tl.constexpr,
    # BD: tl.constexpr, # block of output dim
    is_causal: tl.constexpr
):
    i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_g, i_bh = i_gbh // G, i_gbh % G
    q_base = q + i_b*stride_qb + i_sq*stride_qm + i_gbh*(BH*stride_qh)  # keep two dims; loaded block by block later
    tq_base = q_base + D*stride_qd
    kv_base = kv + i_b*stride_kvb + i_g*stride_kvg
    tkv_base = kv_base + D*stride_kvd
    t_base = indices + i_b*stride_tb + i_sq*stride_tm + i_g*stride_tg
    o_base = output + i_b*stride_ob + i_sq*stride_om + i_gbh*(BH*stride_oh)
    l_base = lse + i_b*stride_lb + i_sq*stride_lm + i_gbh*(BH*stride_lh)
    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP)
    offs_od = tl.arange(0, DP)
    offs_t = tl.arange(0, BK)
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD
    mask_od = mask_d
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, q_msk, other=0.0)
    tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
    tq_msk = mask_h[:, None] & mask_td[None, :]
    tq_blk = tl.load(tq_ptr, tq_msk, other=0.0)
    max_log = tl.full([BH], float('-inf'), dtype=tl.bfloat16)
    sum_exp = tl.full([BH], 1.0, dtype=tl.float32)
    acc = tl.zeros([BH, DP], dtype=tl.float32)
    log_scale: tl.constexpr = sm_scale * 1.44269504
    max_col = i_sq if is_causal else SQ-1
    NK = tl.cdiv(K, BK)
    for ck in tl.range(NK, num_stages=0):
        if ck * BK <= max_col:
            t_ptr = (BK * ck + offs_t) * stride_tt
            t_msk = t_ptr < K
            t_ptr += t_base
            kv_ids = tl.load(t_ptr, t_msk, other=-1)
            mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)
            kv_ptr = kv_base + offs_d[:, None]*stride_kvd + kv_ids[None, :]*stride_kvn
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            # kv_blk = tl.load(kv_ptr, kv_msk, other=0.0)  # [DP, BK]
            # Replace the above line
            kv_blk = tle.load(kv_ptr, kv_msk, other=0.0, is_async=True)  # [DP, BK]
            tkv_ptr = tkv_base + offs_td[:, None]*stride_kvd + kv_ids[None, :]*stride_kvn
            tkv_msk = mask_td[:, None] & mask_ids[None, :]
            tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0)  # [TDP, BK]
            qk = tl.dot(tq_blk, tkv_blk, out_dtype=tl.float32)
            qk = tl.dot(q_blk, kv_blk, qk, out_dtype=tl.float32) * log_scale
            qk = tl.where(mask_ids[None, :], qk, float('-inf'))  # [BH, BK]
            new_max = tl.maximum(max_log, tl.max(qk, axis=1))
            exp_qk = tl.math.exp2(qk - new_max[:, None])
            sum_qk = tl.sum(exp_qk, axis=1)
            alpha = tl.math.exp2(max_log - new_max)
            sum_exp = sum_exp*alpha + sum_qk
            acc = acc*alpha[:, None]
            acc = tl.dot(exp_qk.to(tl.bfloat16), kv_blk.trans(), acc, out_dtype=tl.float32)  # [BH, BK] @ [BK, DP] = [BH, DP]
            max_log = new_max.to(tl.bfloat16)
    out_vals = acc / sum_exp[:, None]
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
    o_msk = mask_h[:, None] & mask_od[None, :]
    tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk)
    fin_log = max_log + tl.math.log2(sum_exp.to(tl.float32))  # returns lse / ln2
    l_ptr = l_base + offs_h * stride_lh
    l_msk = mask_h
    tl.store(l_ptr, fin_log.to(q_blk.dtype), l_msk)

| Device (TFLOPS) | Triton | TileLang | TLE | TLE over Triton |
|---|---|---|---|---|
| RTX 5060Ti | 30.7 | Not supported | 32.8 | 7% |
| H20 | 81.0 | 110.2 | 93.2 | 15% |
| H800 | 165.5 | 355.0 | 210.6 | 27% |