3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -12,10 +12,9 @@ The table below lists the operators ordered by their backend.
|--------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Linear Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layouts supported |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
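For reference, a minimal usage sketch of the unified op documented above (shapes are illustrative, and the `import` that registers the `auto_deploy` ops is an assumption, not taken from this diff):

```python
import torch

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401  # assumed to register the custom ops

# Illustrative GQA shapes: batch=2, 8 query heads, 2 KV heads, seq=16, head_dim=64.
q = torch.randn(2, 8, 16, 64)  # bnsd: [batch, num_heads, seq_len, head_dim]
k = torch.randn(2, 2, 16, 64)
v = torch.randn(2, 2, 16, 64)

# Default layout is "bnsd"; the output uses the same layout as the inputs.
out_bnsd = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True)

# Same computation with "bsnd" inputs: [batch, seq_len, num_heads, head_dim].
out_bsnd = torch.ops.auto_deploy.torch_attention(
    q.transpose(1, 2).contiguous(),
    k.transpose(1, 2).contiguous(),
    v.transpose(1, 2).contiguous(),
    is_causal=True,
    layout="bsnd",
)
assert torch.allclose(out_bnsd, out_bsnd.transpose(1, 2), atol=1e-5)
```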
@@ -355,7 +355,7 @@ def get_num_qkv_args(cls) -> int:
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
@@ -399,6 +399,21 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:

@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
# Prefer kwargs; fall back to the final positional arg if it's a string.
layout = source_attn_node.kwargs.get("layout", None)
if (
layout is None
and len(source_attn_node.args) > 0
and isinstance(source_attn_node.args[-1], str)
):
layout = source_attn_node.args[-1]
if layout != "bsnd":
raise RuntimeError(
f"Expected torch_attention layout='bsnd' but got {layout!r} "
f"for node: {source_attn_node.format_node()}"
)

# Double check other arguments
attn_mask, dropout_p, is_causal = extract_op_args(
source_attn_node, "attn_mask", "dropout_p", "is_causal"
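The kwargs-then-positional fallback above is needed because the exported graph may carry `layout` either as a keyword or as the trailing positional argument. A standalone sketch of the same lookup (the helper name and the `"bnsd"` default are assumptions, not part of this diff):

```python
from torch.fx import Node


def get_torch_attention_layout(node: Node) -> str:
    """Prefer node.kwargs["layout"]; otherwise accept a trailing positional string.

    Node shapes this has to handle, e.g.:
      keyword call:    args == (q, k, v), kwargs == {"is_causal": True, "layout": "bsnd"}
      positional call: args == (q, k, v, None, 0.0, True, None, None, None, None, "bsnd"), kwargs == {}
    """
    layout = node.kwargs.get("layout")
    if layout is None and node.args and isinstance(node.args[-1], str):
        layout = node.args[-1]
    # If neither is present, the call relied on the op's default layout ("bnsd").
    return layout if layout is not None else "bnsd"
```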
84 changes: 28 additions & 56 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
@@ -91,8 +91,9 @@ def scaled_dot_product_attention_fake(
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=())
def grouped_sdpa(
# Unified attention op
@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=())
def torch_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -103,8 +104,25 @@ def grouped_sdpa(
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
layout: str = "bnsd", # "bnsd" or "bsnd"
) -> torch.Tensor:
"""SDPA attention that can handle GQA. Expects bnsd format inputs."""
"""
SDPA attention (with optional GQA) that supports two memory layouts via `layout`:
- "bnsd": [batch, num_heads, seq_len, head_dim]
- "bsnd": [batch, seq_len, num_heads, head_dim]

The `attn_mask` is always interpreted as [b, n, s_q, s_k].

Returns a tensor in the SAME layout as inputs specified by `layout`.
"""
if layout not in ("bnsd", "bsnd"):
raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")

if layout == "bsnd":
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()

b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim]
_, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim]

@@ -188,72 +206,26 @@ def grouped_sdpa(
if dropout_p > 0.0:
attn_out = F.dropout(attn_out, p=dropout_p, training=False)

# Return in bnsd format (same as input format)
return attn_out
if layout == "bsnd":
return attn_out.transpose(1, 2).contiguous()
else:
return attn_out.contiguous()


@grouped_sdpa.register_fake
def grouped_sdpa_fake(
@torch_attention.register_fake
def torch_attention_fake(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
):
"""Fake implementation of grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
query: torch.Tensor, # layout: [b, s_q, n, d]
key: torch.Tensor, # layout: [b, s_k, n, d]
value: torch.Tensor, # layout: [b, s_k, n, d]
attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k]
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""Attention that assumes the input layout is bsnd.

Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
original sdpa op!
"""
# Transpose inputs to bnsd format for grouped_sdpa
query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d]
key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]

# Call grouped_sdpa with bnsd inputs
out = grouped_sdpa(
query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
)
# Transpose back to bsnd format
return out.transpose(1, 2).contiguous()


@bsnd_grouped_sdpa.register_fake
def bsnd_grouped_sdpa_fake(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
layout: str = "bnsd",
):
"""Fake implementation of bnsd grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


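One subtlety the new docstring calls out: `attn_mask` keeps the `[b, n, s_q, s_k]` convention even when the tensors use the `"bsnd"` layout. A sketch of that combination, assuming the additive float-mask convention of `torch.nn.functional.scaled_dot_product_attention` (the mask-handling code itself is elided in the hunk above, and the custom ops are assumed registered as in the earlier sketch):

```python
import torch

b, n, s, d = 2, 8, 16, 64
q = torch.randn(b, s, n, d)  # bsnd: [batch, seq_len, num_heads, head_dim]
k = torch.randn(b, s, n, d)
v = torch.randn(b, s, n, d)

# The mask stays in [b, n, s_q, s_k] regardless of the tensor layout.
causal_mask = torch.zeros(b, n, s, s)
causal_mask.masked_fill_(torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1), float("-inf"))

out = torch.ops.auto_deploy.torch_attention(q, k, v, attn_mask=causal_mask, layout="bsnd")
assert out.shape == (b, s, n, d)  # output comes back in the input ("bsnd") layout
```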
@@ -409,7 +409,7 @@ def get_num_qkv_args(cls) -> int:

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
@@ -460,6 +460,21 @@ def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitial

@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
# Prefer kwargs; fall back to the final positional arg if it's a string.
layout = source_attn_node.kwargs.get("layout", None)
if (
layout is None
and len(source_attn_node.args) > 0
and isinstance(source_attn_node.args[-1], str)
):
layout = source_attn_node.args[-1]
if layout != "bsnd":
raise RuntimeError(
f"Expected torch_attention layout='bsnd' but got {layout!r} "
f"for node: {source_attn_node.format_node()}"
)

# Check other arguments
attn_mask, dropout_p, is_causal = extract_op_args(
source_attn_node, "attn_mask", "dropout_p", "is_causal"
17 changes: 16 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
@@ -339,7 +339,7 @@ def get_num_qkv_args(cls) -> int:

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
@@ -390,6 +390,21 @@ def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitial

@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
# Prefer kwargs; fall back to the final positional arg if it's a string.
layout = source_attn_node.kwargs.get("layout", None)
if (
layout is None
and len(source_attn_node.args) > 0
and isinstance(source_attn_node.args[-1], str)
):
layout = source_attn_node.args[-1]
if layout != "bsnd":
raise RuntimeError(
f"Expected torch_attention layout='bsnd' but got {layout!r} "
f"for node: {source_attn_node.format_node()}"
)

# retrieve head_dim from k_fake
attn_mask, dropout_p, is_causal = extract_op_args(
source_attn_node, "attn_mask", "dropout_p", "is_causal"
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py
@@ -59,7 +59,7 @@ def gpt_oss_attention(
sinks = self.sinks

# Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim)
attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention(
query_states,
key_states,
value_states,
@@ -69,6 +69,7 @@
scale=self.scaling,
sinks=sinks,
sliding_window=sliding_window,
layout="bsnd",
)

# Reshape back to original input shape