diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md index 74258bbcd27..5a65d610493 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md @@ -12,10 +12,9 @@ The table below lists the operators ordered by their backend. |--------------|-------------| | `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support | | `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation | -| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format | | `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Linear Attention) | | `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation | -| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation | +| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported | | `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention | | `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation | | `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 708859dc697..3200a21937d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -355,7 +355,7 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: """Get the source attention op that we target for replacement.""" - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention @classmethod def get_cached_attention_op(cls) -> MHACallable: @@ -399,6 +399,21 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor: @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: + # Sanity check: layout == "bsnd" + # Prefer kwargs; fall back to the final positional arg if it's a string. + layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + # Double check other arguments attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index b55bbe6bfd9..6c54103814b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -91,8 +91,9 @@ def scaled_dot_product_attention_fake( return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() -@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=()) -def grouped_sdpa( +# Unified attention op +@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=()) +def torch_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -103,8 +104,25 @@ def grouped_sdpa( sinks: Optional[torch.Tensor] = None, sliding_window: Optional[int] = None, logit_cap: Optional[float] = None, + layout: str = "bnsd", # "bnsd" or "bsnd" ) -> torch.Tensor: - """SDPA attention that can handle GQA. Expects bnsd format inputs.""" + """ + SDPA attention (with optional GQA) that supports two memory layouts via `layout`: + - "bnsd": [batch, num_heads, seq_len, head_dim] + - "bsnd": [batch, seq_len, num_heads, head_dim] + + The `attn_mask` is always interpreted as [b, n, s_q, s_k]. + + Returns a tensor in the SAME layout as inputs specified by `layout`. + """ + if layout not in ("bnsd", "bsnd"): + raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}") + + if layout == "bsnd": + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim] _, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim] @@ -188,72 +206,26 @@ def grouped_sdpa( if dropout_p > 0.0: attn_out = F.dropout(attn_out, p=dropout_p, training=False) - # Return in bnsd format (same as input format) - return attn_out + if layout == "bsnd": + return attn_out.transpose(1, 2).contiguous() + else: + return attn_out.contiguous() -@grouped_sdpa.register_fake -def grouped_sdpa_fake( +@torch_attention.register_fake +def torch_attention_fake( query, key, value, attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - sinks=None, - sliding_window=None, - logit_cap=None, -): - """Fake implementation of grouped SDPA.""" - return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() - - -@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=()) -def bsnd_grouped_sdpa( - query: torch.Tensor, # layout: [b, s_q, n, d] - key: torch.Tensor, # layout: [b, s_k, n, d] - value: torch.Tensor, # layout: [b, s_k, n, d] - attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k] dropout_p: float = 0.0, is_causal: bool = False, - scale: Optional[float] = None, - sinks: Optional[torch.Tensor] = None, - sliding_window: Optional[int] = None, - logit_cap: Optional[float] = None, -) -> torch.Tensor: - """Attention that assumes the input layout is bsnd. - - Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the - original sdpa op! - """ - # Transpose inputs to bnsd format for grouped_sdpa - query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d] - key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d] - value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d] - - # Call grouped_sdpa with bnsd inputs - out = grouped_sdpa( - query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap - ) - # Transpose back to bsnd format - return out.transpose(1, 2).contiguous() - - -@bsnd_grouped_sdpa.register_fake -def bsnd_grouped_sdpa_fake( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, scale=None, sinks=None, sliding_window=None, logit_cap=None, + layout: str = "bnsd", ): - """Fake implementation of bnsd grouped SDPA.""" return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 8a44b20e4cd..6eadb4b4466 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -409,7 +409,7 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention @classmethod def get_cached_attention_op(cls) -> MHACallable: @@ -460,6 +460,21 @@ def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitial @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: + # Sanity check: layout == "bsnd" + # Prefer kwargs; fall back to the final positional arg if it's a string. + layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + # Check other arguments attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 0edaac9837e..56aad993a3c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -339,7 +339,7 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention @classmethod def get_cached_attention_op(cls) -> MHACallable: @@ -390,6 +390,21 @@ def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitial @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: + # Sanity check: layout == "bsnd" + # Prefer kwargs; fall back to the final positional arg if it's a string. + layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + # retrieve head_dim from k_fake attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py b/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py index 3aecdc5eccd..8309e8a6de6 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py @@ -59,7 +59,7 @@ def gpt_oss_attention( sinks = self.sinks # Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim) - attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( query_states, key_states, value_states, @@ -69,6 +69,7 @@ def gpt_oss_attention( scale=self.scaling, sinks=sinks, sliding_window=sliding_window, + layout="bsnd", ) # Reshape back to original input shape diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 951d9bbf01e..9468220c59e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -1,5 +1,7 @@ """Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models.""" +from inspect import Parameter, Signature +from itertools import product from typing import Any, Callable, Dict, List, Tuple, Type import torch @@ -11,7 +13,6 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger -from ...utils.node_utils import is_op from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern from ..interface import ( BaseTransform, @@ -270,170 +271,6 @@ def causal_mask(): return patterns -def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -# Only expose torch_attention_grouped_sdpa after the transformation -def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -# Only expose torch_attention_grouped_sdpa after the transformation -def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, attn_mask) - - -def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask) - - -def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=False, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=True, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_pattern_8(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=False, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_8(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_pattern_9(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_9(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_pattern_10(q, k, v, n_rep, dropout_p): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - ) - - -def _grouped_attn_replacement_10(q, k, v, n_rep, dropout_p): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - ) - - @TransformRegistry.register("match_repeat_kv") class MatchRepeatKV(BaseTransform): """ @@ -508,11 +345,194 @@ def register_eager_attention(patterns: ADPatternMatcherPass): return gm, info +def _attach_signature(fn: Callable, argnames: List[str]) -> Callable: + # Make FX "see" q,k,v[,attn_mask][,dropout_p][,scale] even though fn(*args) internally + params = [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD) for n in argnames] + fn.__signature__ = Signature(parameters=params) + return fn + + +def _call_sdpa( + q, k, v, *, is_causal: bool, enable_gqa: bool, attn_mask=None, dropout_p=None, scale=None +): + kwargs = {"is_causal": is_causal} + if attn_mask is not None: + kwargs["attn_mask"] = attn_mask + if dropout_p is not None: + kwargs["dropout_p"] = dropout_p + if scale is not None: + kwargs["scale"] = scale + if enable_gqa: + kwargs["enable_gqa"] = True + return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, **kwargs) + + +def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scale=None): + kwargs = {"is_causal": is_causal} + if attn_mask is not None: + kwargs["attn_mask"] = attn_mask + if dropout_p is not None: + kwargs["dropout_p"] = dropout_p + if scale is not None: + kwargs["scale"] = scale + return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs) + + +def make_grouped_attn_pair( + *, + repeat_kv: bool, + is_causal: bool, + has_scale: bool, + enable_gqa: bool, + has_attn_mask: bool, + has_dropout: bool, +) -> Tuple[Callable, Callable, List[str]]: + """ + Returns (pattern_fn, replacement_fn, argnames) with exact positional parity. + + Arg order rules: + Base: (q, k, v) + +repeat_kv -> insert n_rep after (q, k, v) + +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v) + +dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask + +scale -> include scale last + """ + argnames: List[str] = ["q", "k", "v"] + if repeat_kv: + argnames.append("n_rep") + if has_attn_mask: + argnames.append("attn_mask") + if has_dropout: + argnames.append("dropout_p") + if has_scale: + argnames.append("scale") + + def pattern_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + + q = m["q"] + k = m["k"] + v = m["v"] + + if repeat_kv: + n_rep = m["n_rep"] + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) + + return _call_sdpa( + q, + k, + v, + is_causal=is_causal, + enable_gqa=enable_gqa, + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), + ) + + # Replacement: torch_attention.default mirroring the positional signature exactly. + # We do NOT pass enable_gqa here (it’s SDPA-only). We accept n_rep to mirror signature, + # but we don’t need to use it in the replacement graph. + def replacement_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + return _call_attn( + m["q"], + m["k"], + m["v"], + is_causal=is_causal, + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), + ) + + # Pattern matcher needs to see explicit arg names + _attach_signature(pattern_fn, argnames) + _attach_signature(replacement_fn, argnames) + + return pattern_fn, replacement_fn, argnames + + +def generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern: Callable): + """ + Auto-generate all grouped attention patterns across these axes: + 1) repeat_kv: [False, True] + 2) is_causal: [False, True] + 3) has_scale: [False, True] + 4) enable_gqa: [False, True] (only a kwarg to SDPA side) + 5) has_attn_mask: [False, True] + 6) has_dropout: [False, True] + + For each valid combo, we: + - build pattern/replacement functions with exact-arg parity + - build dummy args matching the signature (with CUDA fp16 tensors etc.) + - build scalar_workaround dict for any scalars/n_rep present + - call register_ad_pattern(...) + """ + q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16) + k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) + v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) + attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16) + + dropout_val = 0.12345 + scale_val = 0.56789 + n_rep_val = 7 + + total = 0 + axes = ((False, True),) * 6 + for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes): + pat_fn, rep_fn, argnames = make_grouped_attn_pair( + repeat_kv=repeat_kv, + is_causal=is_causal, + has_scale=has_scale, + enable_gqa=enable_gqa, + has_attn_mask=has_attn_mask, + has_dropout=has_dropout, + ) + + # Build dummy args in the same positional order + value_map = { + "q": q, + "k": k1, + "v": v1, + "n_rep": n_rep_val, + "attn_mask": attn_mask_tensor, + "dropout_p": dropout_val, + "scale": scale_val, + } + dummy_args: List[object] = [] + for name in argnames: + try: + dummy_args.append(value_map[name]) + except KeyError: + raise RuntimeError(f"Unexpected arg name: {name}") + + scalar_names = {"n_rep", "dropout_p", "scale"} + scalar_workaround: Dict[str, object] = { + n: value_map[n] for n in argnames if n in scalar_names + } + if not scalar_workaround: + scalar_workaround = None + + register_ad_pattern( + search_fn=pat_fn, + replace_fn=rep_fn, + patterns=patterns, + dummy_args=dummy_args, + scalar_workaround=scalar_workaround, + ) + total += 1 + return total + + @TransformRegistry.register("match_grouped_attention") class MatchGroupedAttention(BaseTransform): """ Match and replace the grouped attention pattern with - torch.ops.auto_deploy.torch_attention_grouped_sdpa. + torch.ops.auto_deploy.torch_attention. """ def _apply( @@ -523,99 +543,10 @@ def _apply( shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: def register_grouped_attention(patterns: ADPatternMatcherPass): - q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16) - k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) - v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) - attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16) - dropout = 0.12345 - scale = 0.56789 - n_rep = 7 - - dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale] - dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale] - dummy_args_3 = [q, k1, v1, n_rep, attn_mask] - dummy_args_4 = [q, k1, v1, dropout, scale] - dummy_args_5 = [q, k1, v1, n_rep, dropout] - - register_ad_pattern( - search_fn=_grouped_attn_pattern_1, - replace_fn=_grouped_attn_replacement_1, - patterns=patterns, - dummy_args=dummy_args_1, - scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_2, - replace_fn=_grouped_attn_replacement_2, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={ - "scale": scale, - "dropout_p": dropout, - }, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_3, - replace_fn=_grouped_attn_replacement_3, - patterns=patterns, - dummy_args=dummy_args_1, - scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_4, - replace_fn=_grouped_attn_replacement_4, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={ - "scale": scale, - "dropout_p": dropout, - }, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_5, - replace_fn=_grouped_attn_replacement_5, - patterns=patterns, - dummy_args=dummy_args_3, - scalar_workaround={"n_rep": n_rep}, - ) - - register_ad_pattern( - search_fn=_grouped_attn_pattern_6, - replace_fn=_grouped_attn_replacement_6, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_7, - replace_fn=_grouped_attn_replacement_7, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_8, - replace_fn=_grouped_attn_replacement_8, - patterns=patterns, - dummy_args=dummy_args_4, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_9, - replace_fn=_grouped_attn_replacement_9, - patterns=patterns, - dummy_args=dummy_args_4, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_10, - replace_fn=_grouped_attn_replacement_10, - patterns=patterns, - dummy_args=dummy_args_5, - scalar_workaround={"dropout_p": dropout, "n_rep": n_rep}, - ) + return generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern) num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention) + if num_grouped_patterns == 0: ad_logger.warning( "Fail to find any Group Attention Pattern, output or performance may be incorrect" @@ -627,10 +558,146 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): is_clean=False, has_valid_shapes=False, ) - return gm, info +def _call_torch_attention( + q, k, v, *, is_causal, layout, attn_mask=None, dropout_p=None, scale=None +): + kwargs = {"is_causal": is_causal, "layout": layout} + if attn_mask is not None: + kwargs["attn_mask"] = attn_mask + if dropout_p is not None: + kwargs["dropout_p"] = dropout_p + if scale is not None: + kwargs["scale"] = scale + return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs) + + +def make_attn_bnsd_pair( + *, + has_attn_mask: bool, + has_dropout: bool, + is_causal: bool, + has_scale: bool, +) -> Tuple[Callable, Callable, List[str], str, str]: + argnames: List[str] = ["q", "k", "v"] + if has_attn_mask: + argnames.append("attn_mask") + if has_dropout: + argnames.append("dropout_p") + if has_scale: + argnames.append("scale") + + def pattern_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + return _call_torch_attention( + m["q"], + m["k"], + m["v"], + is_causal=is_causal, + layout="bnsd", + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), + ) + + def replacement_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + q_b = torch.ops.aten.transpose.int(m["q"], 1, 2) + k_b = torch.ops.aten.transpose.int(m["k"], 1, 2) + v_b = torch.ops.aten.transpose.int(m["v"], 1, 2) + out_b = _call_torch_attention( + q_b, + k_b, + v_b, + is_causal=is_causal, + layout="bsnd", + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), + ) + return torch.ops.aten.transpose.int(out_b, 1, 2) + + # Pattern matcher needs to see explicit arg names + _attach_signature(pattern_fn, argnames) + _attach_signature(replacement_fn, argnames) + + return pattern_fn, replacement_fn, argnames + + +def generate_and_register_attn_layout_patterns(patterns, register_ad_pattern: Callable): + """ + Enumerate all combinations across: + - has_attn_mask in {False, True} + - has_dropout in {False, True} + - is_causal in {False, True} + - has_scale in {False, True} + Register each pattern/replacement with appropriate dummy args and scalar workarounds. + """ + # Dummy tensors in BNSD + bs, n_heads, s_q, head_dim = 8, 8, 16, 64 + q = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) + attn_mask = torch.randn(bs, n_heads, 1, s_q, device="cuda", dtype=torch.float16) + + dropout_p = 0.12345 + scale_val = 0.56789 + + total = 0 + axes = ((False, True),) * 4 + for has_attn_mask, has_dropout, is_causal, has_scale in product(*axes): + pat_fn, rep_fn, argnames = make_attn_bnsd_pair( + has_attn_mask=has_attn_mask, + has_dropout=has_dropout, + is_causal=is_causal, + has_scale=has_scale, + ) + + # Build dummy args following positional signature + value_map = { + "q": q, + "k": k, + "v": v, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "scale": scale_val, + } + dummy_args: List[object] = [] + for name in argnames: + try: + dummy_args.append(value_map[name]) + except KeyError: + raise RuntimeError(f"Unexpected arg name: {name}") + + # Scalar workaround for present scalars only + scalar_names = {"dropout_p", "scale"} + scalar_workaround: Dict[str, object] = { + n: value_map[n] for n in argnames if n in scalar_names + } + if not scalar_workaround: + scalar_workaround = None + + register_ad_pattern( + search_fn=pat_fn, + replace_fn=rep_fn, + patterns=patterns, + dummy_args=dummy_args, + scalar_workaround=scalar_workaround, + ) + total += 1 + return total + + +def register_match_attn_layout(patterns: ADPatternMatcherPass): + return generate_and_register_attn_layout_patterns(patterns, register_ad_pattern) + + class MatchAttentionLayoutConfig(TransformConfig): """Configuration for the insert cached attention transform.""" @@ -640,13 +707,8 @@ class MatchAttentionLayoutConfig(TransformConfig): @TransformRegistry.register("match_attention_layout") class MatchAttentionLayout(BaseTransform): """ - Match and transform attention operations to match the layout expected by the attention backend. - - If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which - is the default for SDPA operations, we don't need to transform anything. - - If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert - appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa. + Convert unified torch_attention calls from layout='bnsd' (explicit, positional or default) + into layout='bsnd' + correct Q/K/V transposes, and transpose the output back to bnsd. """ config: MatchAttentionLayoutConfig @@ -662,82 +724,26 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - # Get attention layout from attention_op attention_layout = self.config.attention_op.get_attention_layout() - # List of SDPA operations to look for - sdpa_ops = { - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - } + if attention_layout not in ("bnsd", "bsnd"): + raise ValueError(f"Unsupported attention layout: {attention_layout}") - graph = gm.graph - num_bsnd_patterns = 0 - - # Look for SDPA operations - for sdpa_node in list(graph.nodes): - if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops): - continue - - ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}") - - # Extract q, k, v inputs - q, k, v = sdpa_node.args[:3] - - # Check if we need to transpose the inputs - if attention_layout == "bsnd": - # Add transposes before the node (from bnsd to bsnd) - with graph.inserting_before(sdpa_node): - q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2)) - k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2)) - v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2)) - - # Preserve fake tensor in meta["val"] for the transposed inputs - q_updated.meta["val"] = q.meta["val"].transpose(1, 2) - k_updated.meta["val"] = k.meta["val"].transpose(1, 2) - v_updated.meta["val"] = v.meta["val"].transpose(1, 2) - elif attention_layout == "bnsd": - # we don't need to do anything... - q_updated = q - k_updated = k - v_updated = v - else: - raise ValueError(f"Unsupported attention layout: {attention_layout}") - - # Create bsnd_grouped_sdpa node with the same args as the original node - # but using the transposed inputs - with graph.inserting_before(sdpa_node): - source_sdpa_node = graph.call_function( - self.config.attention_op.get_source_attention_op(), - args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:], - kwargs=sdpa_node.kwargs, - ) - - # Check if need to update the output node to match the layout - if attention_layout == "bsnd": - # Add transpose for the output (from bsnd back to bnsd) - with graph.inserting_after(source_sdpa_node): - output_updated = graph.call_function( - torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2) - ) - - # Preserve fake tensor in meta["val"] for the transposed inputs - source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous() - output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2) - elif attention_layout == "bnsd": - output_updated = source_sdpa_node - else: - raise ValueError(f"Unsupported attention layout: {attention_layout}") - - # Replace the old node with the transposed output - sdpa_node.replace_all_uses_with(output_updated) - - num_bsnd_patterns += 1 + # If backend expects bnsd, nothing to do. + if attention_layout == "bnsd": + return gm, TransformInfo( + skipped=False, num_matches=0, is_clean=False, has_valid_shapes=False + ) + + num_matches = _apply_pattern( + gm, "MatchAttentionLayout(bnsd→bsnd)", register_match_attn_layout + ) + # If we changed any attention calls, the shapes may change around the transposes; flag for shape prop. info = TransformInfo( skipped=False, - num_matches=num_bsnd_patterns, + num_matches=num_matches, is_clean=False, has_valid_shapes=False, ) - return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index 575bca14777..ed218061ef4 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -60,8 +60,9 @@ def fake_profiler_mha( v_fake = graph.placeholder(name="v_fake") v_fake.meta["val"] = torch.empty_like(value.transpose(1, 2), device="meta", dtype=value.dtype) + node_kwargs["layout"] = "bsnd" module._node_ref = graph.call_function( - torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, + torch.ops.auto_deploy.torch_attention, args=(q_fake, k_fake, v_fake), kwargs=node_kwargs, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 07f8df00e29..cfbf97eca90 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -479,8 +479,7 @@ def detect_column_row_shard( # acceptable attention nodes between sharded GEMMs shardable_attention_nodes = { torch.ops.auto_deploy.torch_attention_sdpa, - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, + torch.ops.auto_deploy.torch_attention, } # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index c4554bf89b0..76d48669d61 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -90,7 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k = self.k_proj(x).view(b, s, -1, self.head_dim) v = self.v_proj(x).view(b, s, -1, self.head_dim) - y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True) + y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd") y = y.contiguous().view(b, s, -1) return self.o_proj(y) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py index fd452eb36bb..e43e746ff0a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py @@ -741,14 +741,12 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask): x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) dynamic_shapes = model.get_dynamic_shapes() - # We should find 1 instance of torch_attention_grouped_sdpa + # We should find 1 instance of torch_attention expected_matches = 1 def verify_matcher(gm): grouped_sdpa_nodes = [ - n - for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention) ] # Check that we have the expected number of replacements @@ -879,7 +877,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Choose the appropriate attention implementation if self.use_grouped_sdpa: - attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( q, k, v, @@ -985,7 +983,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Choose the appropriate attention implementation if self.use_grouped_sdpa: - attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( q, k, v, @@ -1089,7 +1087,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply scaled dot product attention if self.use_grouped_sdpa: - attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( q, k, v, @@ -1136,7 +1134,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn_mask = self._get_attn_mask(x) if self.has_mask else None # Apply bsnd_grouped_sdpa directly - attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default( + attn_output = torch.ops.auto_deploy.torch_attention.default( q, k, v, @@ -1144,6 +1142,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dropout_p=0.0, is_causal=True, scale=1.0 / (self.head_dim**0.5), + layout="bsnd", ) # Reshape output for the linear projection (no transpose needed) @@ -1173,11 +1172,11 @@ def test_match_attention_layout(layout, model_config, has_mask): MockAttentionDescriptor.layout = layout if layout == "bnsd": if model_config.get("use_grouped_sdpa"): - source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa + source_op = torch.ops.auto_deploy.torch_attention else: source_op = torch.ops.auto_deploy.torch_attention_sdpa else: - source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + source_op = torch.ops.auto_deploy.torch_attention MockAttentionDescriptor.source_attention_op = source_op # Create appropriate model based on model_config @@ -1210,7 +1209,8 @@ def verify_matcher(gm): original_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) + if is_op(n, torch.ops.auto_deploy.torch_attention) + and not (isinstance(n.args[-1], str) and n.args[-1] == "bsnd") ] else: original_nodes = [ @@ -1224,7 +1224,11 @@ def verify_matcher(gm): bsnd_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa) + if ( + is_op(n, torch.ops.auto_deploy.torch_attention) + and isinstance(n.args[-1], str) + and n.args[-1] == "bsnd" + ) ] transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py index 93cfec18b50..7fbca61eac3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py @@ -16,6 +16,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op torch.manual_seed(0) @@ -31,7 +32,7 @@ def get_attention_layout(cls) -> str: @classmethod def get_source_attention_op(cls) -> Callable: - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention class HFWrapper(nn.Module): @@ -82,12 +83,18 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str) pytest.skip("https://nvbugspro.nvidia.com/bug/5170222") def verify_matcher(gm: GraphModule): - """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention (layout="bsnd") call in the graph. Also check that there is no repeat_kv pattern left. """ - nodes = gm.graph.find_nodes( - op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa - ) + nodes = [ + n + for n in gm.graph.nodes + if ( + is_op(n, torch.ops.auto_deploy.torch_attention) + and isinstance(n.args[-1], str) + and n.args[-1] == "bsnd" + ) + ] assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph" # TODO: check non-qkv args of node diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index b4159bd4828..ded0f820af7 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -74,8 +74,18 @@ def forward( v = v.view(b, s, self.num_kv_heads, self.head_dim) # Use grouped SDPA in bsnd layout - attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa( - q, k, v, None, 0.0, True, None + attn_output = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=None, + sinks=None, + sliding_window=None, + logit_cap=None, + layout="bsnd", ) # SDPA output is already in [b, s, n, h_d] format