From c8c1916ddc0c89a30b16063ef40f449dcbc44c9d Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Sat, 27 Sep 2025 01:13:36 +0000 Subject: [PATCH 1/5] WIP: unify bsnd_group_attention and group_attention, update match_attention_layout to use pattern matcher Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- .../custom_ops/flashinfer_attention.py | 2 +- .../auto_deploy/custom_ops/torch_attention.py | 138 +++++++ .../custom_ops/torch_backend_attention.py | 2 +- .../custom_ops/triton_attention.py | 2 +- .../auto_deploy/models/patches/gptoss.py | 3 +- .../transform/library/attention.py | 360 +++++++++++++----- .../transform/library/kvcache_transformers.py | 3 +- .../auto_deploy/transform/library/sharding.py | 3 +- .../library/test_tp_sharding.py | 2 +- .../library/test_attention_matcher.py | 28 +- .../library/test_attention_matcher_hf.py | 17 +- .../transformations/library/test_kv_cache.py | 14 +- 12 files changed, 461 insertions(+), 113 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 708859dc697..78535008092 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -355,7 +355,7 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: """Get the source attention op that we target for replacement.""" - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention @classmethod def get_cached_attention_op(cls) -> MHACallable: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index b55bbe6bfd9..afa0e5b1ad9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -257,6 +257,144 @@ def bsnd_grouped_sdpa_fake( return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() +# Unified attention op +@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=()) +def torch_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + sinks: Optional[torch.Tensor] = None, + sliding_window: Optional[int] = None, + logit_cap: Optional[float] = None, + layout: str = "bnsd", # "bnsd" or "bsnd" +) -> torch.Tensor: + """ + SDPA attention (with optional GQA) that supports two memory layouts via `layout`: + - "bnsd": [batch, num_heads, seq_len, head_dim] + - "bsnd": [batch, seq_len, num_heads, head_dim] + + The `attn_mask` is always interpreted as [b, n, s_q, s_k]. + + Returns a tensor in the SAME layout as inputs specified by `layout`. 
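+
+    Example (illustrative sketch; shapes and values are assumed for demonstration only):
+
+        q = torch.randn(2, 8, 16, 64)  # [b, n, s, d] when layout="bnsd"
+        k = torch.randn(2, 2, 16, 64)  # GQA: 2 KV heads are repeated to match the 8 query heads
+        v = torch.randn(2, 2, 16, 64)
+        out = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bnsd")
+        # out keeps the input layout: [2, 8, 16, 64]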
+ """ + if layout not in ("bnsd", "bsnd"): + raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}") + + if layout == "bsnd": + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim] + _, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim] + + # Inputs are already in bnsd format, no need to transpose + query_t = query # [b, n_heads, s_q, head_dim] + key_t = key # [b, n_kv_heads, s_k, head_dim] + value_t = value # [b, n_kv_heads, s_k, v_head_dim] + + # Handle GQA by repeating KV if needed + if n_heads != n_kv_heads: + n_rep = n_heads // n_kv_heads + key_t = repeat_kv(key_t, n_rep) + value_t = repeat_kv(value_t, n_rep) + + # Set scale + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + # Compute attention scores: Q @ K^T + attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale # [b, n_heads, s_q, s_k] + + # Apply attention mask if provided + if attn_mask is not None: + # Convert boolean mask to float if needed + attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype) + attn_scores = attn_scores + attn_mask + + # Apply causal mask if specified and only during the context phase + if is_causal and s_q == s_k: # Only apply causal mask during context processing + causal_mask = torch.triu( + torch.ones(s_q, s_k, device=query.device, dtype=torch.bool), + diagonal=1, # Use diagonal=1 for standard causal masking + ) + attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + # Apply sliding window mask if specified + if sliding_window is not None and sliding_window > 0: + # Handle position calculation for both context and generation phases + if s_q == s_k: + # Context phase: standard position calculation + query_positions = torch.arange(s_q, device=query.device) + key_positions = torch.arange(s_k, device=query.device) + else: + # Generation phase: query is at position s_k (after the cache) + query_positions = torch.arange(s_k, s_k + s_q, device=query.device) # [s_k] for s_q=1 + key_positions = torch.arange(s_k, device=query.device) # [0,1,2,...,s_k-1] + + # Create position difference matrix: query_pos - key_pos + pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) # [s_q, s_k] + + # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size + sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) # [s_q, s_k] + attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + # Apply logit softcapping if enabled + attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) + + # Apply sinks if provided + if sinks is not None: + # Concatenate sinks to attention scores following the reference implementation + # sinks should have n_heads elements, each head gets its own sink value + # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head + sinks_expanded = sinks.reshape(1, -1, 1, 1).expand( + b, n_heads, s_q, 1 + ) # [b, n_heads, s_q, 1] + + # Concatenate along the key dimension (last dimension) + logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values + sinks = torch.exp(sinks_expanded - logits_max) + unnormalized_scores = torch.exp(attn_scores - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + # Use only the non-sink portion for computing output + # 
We added exactly 1 column, so remove exactly 1 column + attn_out = torch.matmul(scores, value_t) # [b, n_heads, s_q, v_head_dim] + else: + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype) + attn_out = torch.matmul(attn_weights, value_t) # [b, n_heads, s_q, v_head_dim] + + # Apply dropout if specified + if dropout_p > 0.0: + attn_out = F.dropout(attn_out, p=dropout_p, training=False) + + if layout == "bsnd": + return attn_out.transpose(1, 2).contiguous() + else: + return attn_out.contiguous() + + +@torch_attention.register_fake +def torch_attention_fake( + query, + key, + value, + attn_mask=None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale=None, + sinks=None, + sliding_window=None, + logit_cap=None, + layout: str = "bnsd", +): + return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() + + def update_kv_cache( key_states: torch.Tensor, value_states: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 8a44b20e4cd..949817da963 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -409,7 +409,7 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention @classmethod def get_cached_attention_op(cls) -> MHACallable: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 0edaac9837e..a0d8693fe39 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -339,7 +339,7 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention @classmethod def get_cached_attention_op(cls) -> MHACallable: diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py b/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py index 3aecdc5eccd..8309e8a6de6 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py @@ -59,7 +59,7 @@ def gpt_oss_attention( sinks = self.sinks # Use custom op to capture attention. 
This layout is bsnd (batch, seq, num_heads, head_dim) - attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( query_states, key_states, value_states, @@ -69,6 +69,7 @@ def gpt_oss_attention( scale=self.scaling, sinks=sinks, sliding_window=sliding_window, + layout="bsnd", ) # Reshape back to original input shape diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 951d9bbf01e..1516a0f6e2c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -11,7 +11,6 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger -from ...utils.node_utils import is_op from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern from ..interface import ( BaseTransform, @@ -279,12 +278,12 @@ def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale): def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale ) -# Only expose torch_attention_grouped_sdpa after the transformation +# Only expose torch_attention after the transformation def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale): return torch.ops.auto_deploy.torch_attention_sdpa.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale @@ -292,7 +291,7 @@ def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale): def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale ) @@ -306,12 +305,12 @@ def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale): def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale ) -# Only expose torch_attention_grouped_sdpa after the transformation +# Only expose torch_attention after the transformation def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale): return torch.ops.auto_deploy.torch_attention_sdpa.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale @@ -319,7 +318,7 @@ def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale): def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale ) @@ -331,7 +330,7 @@ def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask): def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask) + return torch.ops.auto_deploy.torch_attention.default(q, k, v, attn_mask) def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale): @@ -512,7 +511,7 @@ def 
register_eager_attention(patterns: ADPatternMatcherPass): class MatchGroupedAttention(BaseTransform): """ Match and replace the grouped attention pattern with - torch.ops.auto_deploy.torch_attention_grouped_sdpa. + torch.ops.auto_deploy.torch_attention. """ def _apply( @@ -631,6 +630,256 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): return gm, info +def _attn_bnsd_pattern_1(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=scale, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_2(q, k, v, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_3(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_4(q, k, v, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_5(q, k, v, attn_mask, dropout_p): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=None, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_6(q, k, v, dropout_p): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=None, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_7(q, k, v, attn_mask, dropout_p): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=True, + scale=None, + layout="bnsd", + ) + + +def _attn_bnsd_pattern_8(q, k, v, dropout_p): + return torch.ops.auto_deploy.torch_attention.default( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True, + scale=None, + layout="bnsd", + ) + + +def _attn_bnsd_to_bnsd_via_bsnd(q, k, v, *, attn_mask, dropout_p, is_causal, scale): + q_bsnd = torch.ops.aten.transpose.int(q, 1, 2) + k_bsnd = torch.ops.aten.transpose.int(k, 1, 2) + v_bsnd = torch.ops.aten.transpose.int(v, 1, 2) + + out_bsnd = torch.ops.auto_deploy.torch_attention.default( + q_bsnd, + k_bsnd, + v_bsnd, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + layout="bsnd", + ) + return torch.ops.aten.transpose.int(out_bsnd, 1, 2) + + +# 1) is_causal=False, mask present, scale present +def _attn_bnsd_replacement_1(q, k, v, attn_mask, dropout_p, scale): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +# 2) is_causal=False, mask None, scale present +def _attn_bnsd_replacement_2(q, k, v, dropout_p, scale): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +# 3) is_causal=True, mask present, scale present +def _attn_bnsd_replacement_3(q, k, v, attn_mask, dropout_p, scale): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +# 4) is_causal=True, mask None, scale present +def _attn_bnsd_replacement_4(q, k, v, dropout_p, scale): + return _attn_bnsd_to_bnsd_via_bsnd( 
+ q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +# 5) is_causal=False, mask present, scale=None +def _attn_bnsd_replacement_5(q, k, v, attn_mask, dropout_p): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=None + ) + + +# 6) is_causal=False, mask None, scale=None +def _attn_bnsd_replacement_6(q, k, v, dropout_p): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=None + ) + + +# 7) is_causal=True, mask present, scale=None +def _attn_bnsd_replacement_7(q, k, v, attn_mask, dropout_p): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=None + ) + + +# 8) is_causal=True, mask None, scale=None +def _attn_bnsd_replacement_8(q, k, v, dropout_p): + return _attn_bnsd_to_bnsd_via_bsnd( + q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=None + ) + + +def register_match_attn_layout(patterns: ADPatternMatcherPass): + # Dummy tensors in BNSD (we match bnsd calls) + bs, n_heads, s_q, head_dim = 8, 8, 16, 64 + q = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) + attn_mask = torch.randn(bs, n_heads, 1, s_q, device="cuda", dtype=torch.float16) + + dropout_p = 0.12345 + scale_val = 0.56789 + + # 1..4 (scale present) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_1, + replace_fn=_attn_bnsd_replacement_1, + patterns=patterns, + dummy_args=[q, k, v, attn_mask, dropout_p, scale_val], + scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, + ) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_2, + replace_fn=_attn_bnsd_replacement_2, + patterns=patterns, + dummy_args=[q, k, v, dropout_p, scale_val], + scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, + ) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_3, + replace_fn=_attn_bnsd_replacement_3, + patterns=patterns, + dummy_args=[q, k, v, attn_mask, dropout_p, scale_val], + scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, + ) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_4, + replace_fn=_attn_bnsd_replacement_4, + patterns=patterns, + dummy_args=[q, k, v, dropout_p, scale_val], + scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, + ) + + # 5..8 (scale None) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_5, + replace_fn=_attn_bnsd_replacement_5, + patterns=patterns, + dummy_args=[q, k, v, attn_mask, dropout_p], + scalar_workaround={"dropout_p": dropout_p}, + ) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_6, + replace_fn=_attn_bnsd_replacement_6, + patterns=patterns, + dummy_args=[q, k, v, dropout_p], + scalar_workaround={"dropout_p": dropout_p}, + ) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_7, + replace_fn=_attn_bnsd_replacement_7, + patterns=patterns, + dummy_args=[q, k, v, attn_mask, dropout_p], + scalar_workaround={"dropout_p": dropout_p}, + ) + register_ad_pattern( + search_fn=_attn_bnsd_pattern_8, + replace_fn=_attn_bnsd_replacement_8, + patterns=patterns, + dummy_args=[q, k, v, dropout_p], + scalar_workaround={"dropout_p": dropout_p}, + ) + + class MatchAttentionLayoutConfig(TransformConfig): """Configuration for the insert cached attention transform.""" @@ -640,13 +889,8 @@ class MatchAttentionLayoutConfig(TransformConfig): 
@TransformRegistry.register("match_attention_layout") class MatchAttentionLayout(BaseTransform): """ - Match and transform attention operations to match the layout expected by the attention backend. - - If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which - is the default for SDPA operations, we don't need to transform anything. - - If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert - appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa. + Convert unified torch_attention calls from layout='bnsd' (explicit, positional or default) + into layout='bsnd' + correct Q/K/V transposes, and transpose the output back to bnsd. """ config: MatchAttentionLayoutConfig @@ -662,82 +906,26 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - # Get attention layout from attention_op attention_layout = self.config.attention_op.get_attention_layout() - # List of SDPA operations to look for - sdpa_ops = { - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - } - - graph = gm.graph - num_bsnd_patterns = 0 - - # Look for SDPA operations - for sdpa_node in list(graph.nodes): - if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops): - continue - - ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}") - - # Extract q, k, v inputs - q, k, v = sdpa_node.args[:3] - - # Check if we need to transpose the inputs - if attention_layout == "bsnd": - # Add transposes before the node (from bnsd to bsnd) - with graph.inserting_before(sdpa_node): - q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2)) - k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2)) - v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2)) - - # Preserve fake tensor in meta["val"] for the transposed inputs - q_updated.meta["val"] = q.meta["val"].transpose(1, 2) - k_updated.meta["val"] = k.meta["val"].transpose(1, 2) - v_updated.meta["val"] = v.meta["val"].transpose(1, 2) - elif attention_layout == "bnsd": - # we don't need to do anything... 
- q_updated = q - k_updated = k - v_updated = v - else: - raise ValueError(f"Unsupported attention layout: {attention_layout}") - - # Create bsnd_grouped_sdpa node with the same args as the original node - # but using the transposed inputs - with graph.inserting_before(sdpa_node): - source_sdpa_node = graph.call_function( - self.config.attention_op.get_source_attention_op(), - args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:], - kwargs=sdpa_node.kwargs, - ) - - # Check if need to update the output node to match the layout - if attention_layout == "bsnd": - # Add transpose for the output (from bsnd back to bnsd) - with graph.inserting_after(source_sdpa_node): - output_updated = graph.call_function( - torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2) - ) - - # Preserve fake tensor in meta["val"] for the transposed inputs - source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous() - output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2) - elif attention_layout == "bnsd": - output_updated = source_sdpa_node - else: - raise ValueError(f"Unsupported attention layout: {attention_layout}") - - # Replace the old node with the transposed output - sdpa_node.replace_all_uses_with(output_updated) - - num_bsnd_patterns += 1 + if attention_layout not in ("bnsd", "bsnd"): + raise ValueError(f"Unsupported attention layout: {attention_layout}") + + # If backend expects bnsd, nothing to do. + if attention_layout == "bnsd": + return gm, TransformInfo( + skipped=False, num_matches=0, is_clean=False, has_valid_shapes=False + ) + + num_matches = _apply_pattern( + gm, "MatchAttentionLayout(bnsd→bsnd)", register_match_attn_layout + ) + # If we changed any attention calls, the shapes may change around the transposes; flag for shape prop. info = TransformInfo( skipped=False, - num_matches=num_bsnd_patterns, + num_matches=num_matches, is_clean=False, has_valid_shapes=False, ) - return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index 575bca14777..8c47fb4002a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -46,6 +46,7 @@ def fake_profiler_mha( "sinks": "sinks", "sliding_window": "sliding_window", "logit_cap": "logit_cap", + "layout": "bsnd", } for k_kwargs, k_op_kwargs in kwargs_to_op.items(): if k_kwargs in kwargs: @@ -61,7 +62,7 @@ def fake_profiler_mha( v_fake.meta["val"] = torch.empty_like(value.transpose(1, 2), device="meta", dtype=value.dtype) module._node_ref = graph.call_function( - torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, + torch.ops.auto_deploy.torch_attention, args=(q_fake, k_fake, v_fake), kwargs=node_kwargs, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 07f8df00e29..cfbf97eca90 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -479,8 +479,7 @@ def detect_column_row_shard( # acceptable attention nodes between sharded GEMMs shardable_attention_nodes = { torch.ops.auto_deploy.torch_attention_sdpa, - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, + torch.ops.auto_deploy.torch_attention, } # This is a heuristic. 
Basically, we assume those are okay to shard if we also encounter an diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index c4554bf89b0..76d48669d61 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -90,7 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k = self.k_proj(x).view(b, s, -1, self.head_dim) v = self.v_proj(x).view(b, s, -1, self.head_dim) - y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True) + y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd") y = y.contiguous().view(b, s, -1) return self.o_proj(y) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py index fd452eb36bb..e43e746ff0a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py @@ -741,14 +741,12 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask): x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) dynamic_shapes = model.get_dynamic_shapes() - # We should find 1 instance of torch_attention_grouped_sdpa + # We should find 1 instance of torch_attention expected_matches = 1 def verify_matcher(gm): grouped_sdpa_nodes = [ - n - for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention) ] # Check that we have the expected number of replacements @@ -879,7 +877,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Choose the appropriate attention implementation if self.use_grouped_sdpa: - attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( q, k, v, @@ -985,7 +983,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Choose the appropriate attention implementation if self.use_grouped_sdpa: - attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( q, k, v, @@ -1089,7 +1087,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply scaled dot product attention if self.use_grouped_sdpa: - attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention( q, k, v, @@ -1136,7 +1134,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn_mask = self._get_attn_mask(x) if self.has_mask else None # Apply bsnd_grouped_sdpa directly - attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default( + attn_output = torch.ops.auto_deploy.torch_attention.default( q, k, v, @@ -1144,6 +1142,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dropout_p=0.0, is_causal=True, scale=1.0 / (self.head_dim**0.5), + layout="bsnd", ) # Reshape output for the linear projection (no transpose needed) @@ -1173,11 +1172,11 @@ def test_match_attention_layout(layout, model_config, has_mask): MockAttentionDescriptor.layout = layout if layout == "bnsd": if 
model_config.get("use_grouped_sdpa"): - source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa + source_op = torch.ops.auto_deploy.torch_attention else: source_op = torch.ops.auto_deploy.torch_attention_sdpa else: - source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + source_op = torch.ops.auto_deploy.torch_attention MockAttentionDescriptor.source_attention_op = source_op # Create appropriate model based on model_config @@ -1210,7 +1209,8 @@ def verify_matcher(gm): original_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) + if is_op(n, torch.ops.auto_deploy.torch_attention) + and not (isinstance(n.args[-1], str) and n.args[-1] == "bsnd") ] else: original_nodes = [ @@ -1224,7 +1224,11 @@ def verify_matcher(gm): bsnd_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa) + if ( + is_op(n, torch.ops.auto_deploy.torch_attention) + and isinstance(n.args[-1], str) + and n.args[-1] == "bsnd" + ) ] transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py index 93cfec18b50..7fbca61eac3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py @@ -16,6 +16,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op torch.manual_seed(0) @@ -31,7 +32,7 @@ def get_attention_layout(cls) -> str: @classmethod def get_source_attention_op(cls) -> Callable: - return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention class HFWrapper(nn.Module): @@ -82,12 +83,18 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str) pytest.skip("https://nvbugspro.nvidia.com/bug/5170222") def verify_matcher(gm: GraphModule): - """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention (layout="bsnd") call in the graph. Also check that there is no repeat_kv pattern left. 
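+        (After the layout-matching transform, the "bsnd" layout is carried as the
+        node's last positional argument, which is what this check inspects.)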
""" - nodes = gm.graph.find_nodes( - op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa - ) + nodes = [ + n + for n in gm.graph.nodes + if ( + is_op(n, torch.ops.auto_deploy.torch_attention) + and isinstance(n.args[-1], str) + and n.args[-1] == "bsnd" + ) + ] assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph" # TODO: check non-qkv args of node diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index b4159bd4828..ded0f820af7 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -74,8 +74,18 @@ def forward( v = v.view(b, s, self.num_kv_heads, self.head_dim) # Use grouped SDPA in bsnd layout - attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa( - q, k, v, None, 0.0, True, None + attn_output = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=None, + sinks=None, + sliding_window=None, + logit_cap=None, + layout="bsnd", ) # SDPA output is already in [b, s, n, h_d] format From 7dbd0c635d00a3f5499ac5a8363993ebb6472a04 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 30 Sep 2025 18:15:53 +0000 Subject: [PATCH 2/5] remove old attention ops, update after rebase Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- .../_torch/auto_deploy/custom_ops/README.md | 3 +- .../auto_deploy/custom_ops/torch_attention.py | 166 ------------------ .../transform/library/attention.py | 10 +- 3 files changed, 6 insertions(+), 173 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md index 74258bbcd27..5a65d610493 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md @@ -12,10 +12,9 @@ The table below lists the operators ordered by their backend. 
|--------------|-------------| | `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support | | `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation | -| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format | | `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Linear Attention) | | `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation | -| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation | +| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported | | `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention | | `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation | | `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index afa0e5b1ad9..6c54103814b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -91,172 +91,6 @@ def scaled_dot_product_attention_fake( return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() -@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=()) -def grouped_sdpa( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - sinks: Optional[torch.Tensor] = None, - sliding_window: Optional[int] = None, - logit_cap: Optional[float] = None, -) -> torch.Tensor: - """SDPA attention that can handle GQA. 
Expects bnsd format inputs.""" - b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim] - _, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim] - - # Inputs are already in bnsd format, no need to transpose - query_t = query # [b, n_heads, s_q, head_dim] - key_t = key # [b, n_kv_heads, s_k, head_dim] - value_t = value # [b, n_kv_heads, s_k, v_head_dim] - - # Handle GQA by repeating KV if needed - if n_heads != n_kv_heads: - n_rep = n_heads // n_kv_heads - key_t = repeat_kv(key_t, n_rep) - value_t = repeat_kv(value_t, n_rep) - - # Set scale - if scale is None: - scale = 1.0 / math.sqrt(head_dim) - - # Compute attention scores: Q @ K^T - attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale # [b, n_heads, s_q, s_k] - - # Apply attention mask if provided - if attn_mask is not None: - # Convert boolean mask to float if needed - attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype) - attn_scores = attn_scores + attn_mask - - # Apply causal mask if specified and only during the context phase - if is_causal and s_q == s_k: # Only apply causal mask during context processing - causal_mask = torch.triu( - torch.ones(s_q, s_k, device=query.device, dtype=torch.bool), - diagonal=1, # Use diagonal=1 for standard causal masking - ) - attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) - - # Apply sliding window mask if specified - if sliding_window is not None and sliding_window > 0: - # Handle position calculation for both context and generation phases - if s_q == s_k: - # Context phase: standard position calculation - query_positions = torch.arange(s_q, device=query.device) - key_positions = torch.arange(s_k, device=query.device) - else: - # Generation phase: query is at position s_k (after the cache) - query_positions = torch.arange(s_k, s_k + s_q, device=query.device) # [s_k] for s_q=1 - key_positions = torch.arange(s_k, device=query.device) # [0,1,2,...,s_k-1] - - # Create position difference matrix: query_pos - key_pos - pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) # [s_q, s_k] - - # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size - sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) # [s_q, s_k] - attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf")) - - # Apply logit softcapping if enabled - attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) - - # Apply sinks if provided - if sinks is not None: - # Concatenate sinks to attention scores following the reference implementation - # sinks should have n_heads elements, each head gets its own sink value - # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head - sinks_expanded = sinks.reshape(1, -1, 1, 1).expand( - b, n_heads, s_q, 1 - ) # [b, n_heads, s_q, 1] - - # Concatenate along the key dimension (last dimension) - logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values - sinks = torch.exp(sinks_expanded - logits_max) - unnormalized_scores = torch.exp(attn_scores - logits_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer - # Use only the non-sink portion for computing output - # We added exactly 1 column, so remove exactly 1 column - attn_out = torch.matmul(scores, value_t) # [b, n_heads, s_q, v_head_dim] - else: - attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype) - attn_out = 
torch.matmul(attn_weights, value_t) # [b, n_heads, s_q, v_head_dim] - - # Apply dropout if specified - if dropout_p > 0.0: - attn_out = F.dropout(attn_out, p=dropout_p, training=False) - - # Return in bnsd format (same as input format) - return attn_out - - -@grouped_sdpa.register_fake -def grouped_sdpa_fake( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - sinks=None, - sliding_window=None, - logit_cap=None, -): - """Fake implementation of grouped SDPA.""" - return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() - - -@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=()) -def bsnd_grouped_sdpa( - query: torch.Tensor, # layout: [b, s_q, n, d] - key: torch.Tensor, # layout: [b, s_k, n, d] - value: torch.Tensor, # layout: [b, s_k, n, d] - attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k] - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - sinks: Optional[torch.Tensor] = None, - sliding_window: Optional[int] = None, - logit_cap: Optional[float] = None, -) -> torch.Tensor: - """Attention that assumes the input layout is bsnd. - - Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the - original sdpa op! - """ - # Transpose inputs to bnsd format for grouped_sdpa - query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d] - key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d] - value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d] - - # Call grouped_sdpa with bnsd inputs - out = grouped_sdpa( - query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap - ) - # Transpose back to bsnd format - return out.transpose(1, 2).contiguous() - - -@bsnd_grouped_sdpa.register_fake -def bsnd_grouped_sdpa_fake( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - sinks=None, - sliding_window=None, - logit_cap=None, -): - """Fake implementation of bnsd grouped SDPA.""" - return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() - - # Unified attention op @torch.library.custom_op("auto_deploy::torch_attention", mutates_args=()) def torch_attention( diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 1516a0f6e2c..35b4c75558a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -347,7 +347,7 @@ def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale): def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale ) @@ -366,7 +366,7 @@ def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale): def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale ) @@ -385,7 +385,7 @@ def _grouped_attn_pattern_8(q, k, v, dropout_p, scale): def _grouped_attn_replacement_8(q, k, v, dropout_p, scale): - return 
torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale ) @@ -404,7 +404,7 @@ def _grouped_attn_pattern_9(q, k, v, dropout_p, scale): def _grouped_attn_replacement_9(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale ) @@ -423,7 +423,7 @@ def _grouped_attn_pattern_10(q, k, v, n_rep, dropout_p): def _grouped_attn_replacement_10(q, k, v, n_rep, dropout_p): - return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + return torch.ops.auto_deploy.torch_attention.default( q, k, v, From 6736c9c06a0ceab39b9cc4421afa4eae0460db7c Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 30 Sep 2025 21:09:30 +0000 Subject: [PATCH 3/5] auto-generate patterns for MatchGroupAttention and MatchAttentionLayout Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- .../transform/library/attention.py | 822 +++++++----------- 1 file changed, 332 insertions(+), 490 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 35b4c75558a..d3e50d2d447 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -269,170 +269,6 @@ def causal_mask(): return patterns -def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -# Only expose torch_attention after the transformation -def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -# Only expose torch_attention after the transformation -def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - 
-def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, attn_mask) - - -def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask): - return torch.ops.auto_deploy.torch_attention.default(q, k, v, attn_mask) - - -def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=False, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=True, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_pattern_8(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=False, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_8(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -def _grouped_attn_pattern_9(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - scale=scale, - enable_gqa=True, - ) - - -def _grouped_attn_replacement_9(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -def _grouped_attn_pattern_10(q, k, v, n_rep, dropout_p): - k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) - v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) - return torch.ops.auto_deploy.torch_attention_sdpa.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - ) - - -def _grouped_attn_replacement_10(q, k, v, n_rep, dropout_p): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - ) - - @TransformRegistry.register("match_repeat_kv") class MatchRepeatKV(BaseTransform): """ @@ -507,6 +343,195 @@ def register_eager_attention(patterns: ADPatternMatcherPass): return gm, info +def make_grouped_attn_pair( + *, + repeat_kv: bool, + is_causal: bool, + has_scale: bool, + enable_gqa: bool, + has_attn_mask: bool, + has_dropout: bool, +): + """ + Returns (pattern_fn, replacement_fn, argnames) such that: + - pattern_fn(*args) calls torch_attention_sdpa.default with the specified knobs + and optional pre-repeat kv + - replacement_fn(*args) calls torch_attention.default mirroring 
signature exactly + - argnames is the ordered arg list for constructing dummy args + + Arg order rules: + - Base: (q, k, v) + - +repeat_kv -> insert n_rep after (q, k, v) + - +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v) + - +dropout -> include dropout_p after attn_mask or after n_rep / base if no attn_mask + - +scale -> include scale last (after dropout_p if present, else after attn_mask/n_rep/base) + """ + # build signature + argnames: list[str] = ["q", "k", "v"] + if repeat_kv: + argnames.append("n_rep") + if has_attn_mask: + argnames.append("attn_mask") + if has_dropout: + argnames.append("dropout_p") + if has_scale: + argnames.append("scale") + + # helper to build call kwargs source strings + def _build_sdpa_kw_src(): + parts = [] + if has_attn_mask: + parts.append("attn_mask=attn_mask") + if has_dropout: + parts.append("dropout_p=dropout_p") + if has_scale: + parts.append("scale=scale") + parts.append(f"is_causal={str(is_causal)}") + if enable_gqa: + parts.append("enable_gqa=True") + return ", ".join(parts) + + def _build_attn_kw_src(): + parts = [] + if has_attn_mask: + parts.append("attn_mask=attn_mask") + if has_dropout: + parts.append("dropout_p=dropout_p") + if has_scale: + parts.append("scale=scale") + parts.append(f"is_causal={str(is_causal)}") + return ", ".join(parts) + + sdpa_kw_src = _build_sdpa_kw_src() + attn_kw_src = _build_attn_kw_src() + + # factories that also return the source we exec + def pattern_factory(argnames=tuple(argnames), repeat_kv=repeat_kv): + args_sig = ", ".join(argnames) + fn_name = ( + f"ga_pat_r{int(repeat_kv)}_c{int(is_causal)}_s{int(has_scale)}_" + f"g{int(enable_gqa)}_m{int(has_attn_mask)}_d{int(has_dropout)}" + ) + body_lines = [f"def {fn_name}({args_sig}):"] + if repeat_kv: + body_lines.append(" k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)") + body_lines.append(" v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)") + call_line = ( + " return torch.ops.auto_deploy.torch_attention_sdpa.default(" + f"q, k, v{', ' if sdpa_kw_src else ''}{sdpa_kw_src})" + ) + body_lines.append(call_line) + src = "\n".join(body_lines) + scope = {"torch": torch} + exec(src, scope) + return scope[fn_name] + + def replacement_factory(argnames=tuple(argnames)): + args_sig = ", ".join(argnames) + fn_name = ( + f"ga_rep_r{int(repeat_kv)}_c{int(is_causal)}_s{int(has_scale)}_" + f"g{int(enable_gqa)}_m{int(has_attn_mask)}_d{int(has_dropout)}" + ) + body = [f"def {fn_name}({args_sig}):"] + call_line = ( + " return torch.ops.auto_deploy.torch_attention.default(" + f"q, k, v{', ' if attn_kw_src else ''}{attn_kw_src})" + ) + body.append(call_line) + src = "\n".join(body) + scope = {"torch": torch} + exec(src, scope) + return scope[fn_name] + + pat_fn = pattern_factory() + rep_fn = replacement_factory() + + return pat_fn, rep_fn, argnames + + +def generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern: Callable): + """ + Auto-generate all grouped attention patterns across these axes: + 1) repeat_kv: [False, True] + 2) is_causal: [False, True] + 3) has_scale: [False, True] + 4) enable_gqa: [False, True] (only a kwarg to SDPA side) + 5) has_attn_mask: [False, True] + 6) has_dropout: [False, True] + + For each valid combo, we: + - build pattern/replacement functions with exact-arg parity + - build dummy args matching the signature (with CUDA fp16 tensors etc.) + - build scalar_workaround dict for any scalars/n_rep present + - call register_ad_pattern(...) 
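+
+    With six independent boolean axes, this registers 2**6 = 64 pattern/replacement
+    pairs per call.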
+ """ + q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16) + k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) + v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) + attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16) + + dropout_val = 0.12345 + scale_val = 0.56789 + n_rep_val = 7 + + total = 0 + for repeat_kv in (False, True): + for is_causal in (False, True): + for has_scale in (False, True): + for enable_gqa in (False, True): + for has_attn_mask in (False, True): + for has_dropout in (False, True): + # Build functions + pat_fn, rep_fn, argnames = make_grouped_attn_pair( + repeat_kv=repeat_kv, + is_causal=is_causal, + has_scale=has_scale, + enable_gqa=enable_gqa, + has_attn_mask=has_attn_mask, + has_dropout=has_dropout, + ) + + # Build dummy args in the same positional order + dummy_args: List[object] = [] + for name in argnames: + if name == "q": + dummy_args.append(q) + elif name == "k": + dummy_args.append(k1) + elif name == "v": + dummy_args.append(v1) + elif name == "n_rep": + dummy_args.append(n_rep_val) + elif name == "attn_mask": + dummy_args.append(attn_mask_tensor) + elif name == "dropout_p": + dummy_args.append(dropout_val) + elif name == "scale": + dummy_args.append(scale_val) + else: + raise RuntimeError(f"Unexpected arg name: {name}") + + # scalar_workaround mirrors only the scalar args present by name + scalar_workaround: Dict[str, object] = {} + if "n_rep" in argnames: + scalar_workaround["n_rep"] = n_rep_val + if "dropout_p" in argnames: + scalar_workaround["dropout_p"] = dropout_val + if "scale" in argnames: + scalar_workaround["scale"] = scale_val + + # Register + register_ad_pattern( + search_fn=pat_fn, + replace_fn=rep_fn, + patterns=patterns, + dummy_args=dummy_args, + scalar_workaround=scalar_workaround if scalar_workaround else None, + ) + total += 1 + return total + + @TransformRegistry.register("match_grouped_attention") class MatchGroupedAttention(BaseTransform): """ @@ -522,99 +547,10 @@ def _apply( shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: def register_grouped_attention(patterns: ADPatternMatcherPass): - q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16) - k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) - v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) - attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16) - dropout = 0.12345 - scale = 0.56789 - n_rep = 7 - - dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale] - dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale] - dummy_args_3 = [q, k1, v1, n_rep, attn_mask] - dummy_args_4 = [q, k1, v1, dropout, scale] - dummy_args_5 = [q, k1, v1, n_rep, dropout] - - register_ad_pattern( - search_fn=_grouped_attn_pattern_1, - replace_fn=_grouped_attn_replacement_1, - patterns=patterns, - dummy_args=dummy_args_1, - scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_2, - replace_fn=_grouped_attn_replacement_2, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={ - "scale": scale, - "dropout_p": dropout, - }, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_3, - replace_fn=_grouped_attn_replacement_3, - patterns=patterns, - dummy_args=dummy_args_1, - scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_4, - replace_fn=_grouped_attn_replacement_4, - 
patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={ - "scale": scale, - "dropout_p": dropout, - }, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_5, - replace_fn=_grouped_attn_replacement_5, - patterns=patterns, - dummy_args=dummy_args_3, - scalar_workaround={"n_rep": n_rep}, - ) - - register_ad_pattern( - search_fn=_grouped_attn_pattern_6, - replace_fn=_grouped_attn_replacement_6, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_7, - replace_fn=_grouped_attn_replacement_7, - patterns=patterns, - dummy_args=dummy_args_2, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_8, - replace_fn=_grouped_attn_replacement_8, - patterns=patterns, - dummy_args=dummy_args_4, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_9, - replace_fn=_grouped_attn_replacement_9, - patterns=patterns, - dummy_args=dummy_args_4, - scalar_workaround={"scale": scale, "dropout_p": dropout}, - ) - register_ad_pattern( - search_fn=_grouped_attn_pattern_10, - replace_fn=_grouped_attn_replacement_10, - patterns=patterns, - dummy_args=dummy_args_5, - scalar_workaround={"dropout_p": dropout, "n_rep": n_rep}, - ) + return generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern) num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention) + if num_grouped_patterns == 0: ad_logger.warning( "Fail to find any Group Attention Pattern, output or performance may be incorrect" @@ -626,190 +562,103 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): is_clean=False, has_valid_shapes=False, ) - return gm, info -def _attn_bnsd_pattern_1(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=False, - scale=scale, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_2(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=False, - scale=scale, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_3(q, k, v, attn_mask, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=True, - scale=scale, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_4(q, k, v, dropout_p, scale): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - scale=scale, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_5(q, k, v, attn_mask, dropout_p): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=False, - scale=None, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_6(q, k, v, dropout_p): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=False, - scale=None, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_7(q, k, v, attn_mask, dropout_p): - return torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=True, - scale=None, - layout="bnsd", - ) - - -def _attn_bnsd_pattern_8(q, k, v, dropout_p): - return 
torch.ops.auto_deploy.torch_attention.default( - q, - k, - v, - attn_mask=None, - dropout_p=dropout_p, - is_causal=True, - scale=None, - layout="bnsd", - ) - - -def _attn_bnsd_to_bnsd_via_bsnd(q, k, v, *, attn_mask, dropout_p, is_causal, scale): - q_bsnd = torch.ops.aten.transpose.int(q, 1, 2) - k_bsnd = torch.ops.aten.transpose.int(k, 1, 2) - v_bsnd = torch.ops.aten.transpose.int(v, 1, 2) - - out_bsnd = torch.ops.auto_deploy.torch_attention.default( - q_bsnd, - k_bsnd, - v_bsnd, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - layout="bsnd", - ) - return torch.ops.aten.transpose.int(out_bsnd, 1, 2) - - -# 1) is_causal=False, mask present, scale present -def _attn_bnsd_replacement_1(q, k, v, attn_mask, dropout_p, scale): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -# 2) is_causal=False, mask None, scale present -def _attn_bnsd_replacement_2(q, k, v, dropout_p, scale): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale - ) - - -# 3) is_causal=True, mask present, scale present -def _attn_bnsd_replacement_3(q, k, v, attn_mask, dropout_p, scale): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -# 4) is_causal=True, mask None, scale present -def _attn_bnsd_replacement_4(q, k, v, dropout_p, scale): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale - ) - - -# 5) is_causal=False, mask present, scale=None -def _attn_bnsd_replacement_5(q, k, v, attn_mask, dropout_p): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=None - ) - - -# 6) is_causal=False, mask None, scale=None -def _attn_bnsd_replacement_6(q, k, v, dropout_p): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=None - ) - - -# 7) is_causal=True, mask present, scale=None -def _attn_bnsd_replacement_7(q, k, v, attn_mask, dropout_p): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=None - ) +def make_attn_bnsd_pair( + *, + has_attn_mask: bool, + has_dropout: bool, + is_causal: bool, + has_scale: bool, +) -> Tuple[Callable, Callable, List[str], str, str]: + """ + Returns (pattern_fn, replacement_fn, argnames, pat_src, rep_src) + - pattern_fn(*args) matches torch_attention.default(..., layout="bnsd") + - replacement_fn(*args) transposes to BSND, runs torch_attention.default(..., layout="bsnd"), transposes back + - argnames is the ordered arg list: (q, k, v [, attn_mask] [, dropout_p] [, scale]) + - pat_src / rep_src are the exact function bodies (for debug logging) + """ + # signature in positional order + argnames: List[str] = ["q", "k", "v"] + if has_attn_mask: + argnames.append("attn_mask") + if has_dropout: + argnames.append("dropout_p") + if has_scale: + argnames.append("scale") + + # build kw parts (omit anything not present; always include is_causal; set layout explicitly + def _build_kw_src(layout_value: str) -> str: + parts = [] + if has_attn_mask: + parts.append("attn_mask=attn_mask") + if has_dropout: + parts.append("dropout_p=dropout_p") + if has_scale: + parts.append("scale=scale") + parts.append(f"is_causal={str(is_causal)}") + parts.append(f'layout="{layout_value}"') + return ", ".join(parts) + + bnsd_kw_src = 
_build_kw_src("bnsd") + bsnd_kw_src = _build_kw_src("bsnd") + + # factories: generate functions with explicit kwargs + def pattern_factory(argnames=tuple(argnames)): + args_sig = ", ".join(argnames) + fn_name = ( + f"attn_bnsd_pat_m{int(has_attn_mask)}_d{int(has_dropout)}_" + f"c{int(is_causal)}_s{int(has_scale)}" + ) + body = [f"def {fn_name}({args_sig}):"] + call = ( + " return torch.ops.auto_deploy.torch_attention.default(" + f"q, k, v{', ' if bnsd_kw_src else ''}{bnsd_kw_src})" + ) + body.append(call) + src = "\n".join(body) + scope = {"torch": torch} + exec(src, scope) + return scope[fn_name] + + def replacement_factory(argnames=tuple(argnames)): + args_sig = ", ".join(argnames) + fn_name = ( + f"attn_bnsd_rep_m{int(has_attn_mask)}_d{int(has_dropout)}_" + f"c{int(is_causal)}_s{int(has_scale)}" + ) + body = [f"def {fn_name}({args_sig}):"] + body.append(" q_bsnd = torch.ops.aten.transpose.int(q, 1, 2)") + body.append(" k_bsnd = torch.ops.aten.transpose.int(k, 1, 2)") + body.append(" v_bsnd = torch.ops.aten.transpose.int(v, 1, 2)") + call = ( + " out_bsnd = torch.ops.auto_deploy.torch_attention.default(" + f"q_bsnd, k_bsnd, v_bsnd{', ' if bsnd_kw_src else ''}{bsnd_kw_src})" + ) + body.append(call) + body.append(" return torch.ops.aten.transpose.int(out_bsnd, 1, 2)") + src = "\n".join(body) + scope = {"torch": torch} + exec(src, scope) + return scope[fn_name] + pat_fn = pattern_factory() + rep_fn = replacement_factory() -# 8) is_causal=True, mask None, scale=None -def _attn_bnsd_replacement_8(q, k, v, dropout_p): - return _attn_bnsd_to_bnsd_via_bsnd( - q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=None - ) + return pat_fn, rep_fn, argnames -def register_match_attn_layout(patterns: ADPatternMatcherPass): - # Dummy tensors in BNSD (we match bnsd calls) +def generate_and_register_attn_layout_patterns(patterns, register_ad_pattern: Callable): + """ + Enumerate all combinations across: + - has_attn_mask in {False, True} + - has_dropout in {False, True} + - is_causal in {False, True} + - has_scale in {False, True} + Register each pattern/replacement with appropriate dummy args and scalar workarounds. 
+ """ + # Dummy tensors in BNSD bs, n_heads, s_q, head_dim = 8, 8, 16, 64 q = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) k = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16) @@ -819,65 +668,58 @@ def register_match_attn_layout(patterns: ADPatternMatcherPass): dropout_p = 0.12345 scale_val = 0.56789 - # 1..4 (scale present) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_1, - replace_fn=_attn_bnsd_replacement_1, - patterns=patterns, - dummy_args=[q, k, v, attn_mask, dropout_p, scale_val], - scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, - ) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_2, - replace_fn=_attn_bnsd_replacement_2, - patterns=patterns, - dummy_args=[q, k, v, dropout_p, scale_val], - scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, - ) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_3, - replace_fn=_attn_bnsd_replacement_3, - patterns=patterns, - dummy_args=[q, k, v, attn_mask, dropout_p, scale_val], - scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, - ) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_4, - replace_fn=_attn_bnsd_replacement_4, - patterns=patterns, - dummy_args=[q, k, v, dropout_p, scale_val], - scalar_workaround={"dropout_p": dropout_p, "scale": scale_val}, - ) + total = 0 + for has_attn_mask in (False, True): + for has_dropout in (False, True): + for is_causal in (False, True): + for has_scale in (False, True): + pat_fn, rep_fn, argnames = make_attn_bnsd_pair( + has_attn_mask=has_attn_mask, + has_dropout=has_dropout, + is_causal=is_causal, + has_scale=has_scale, + ) + + # Build dummy args following positional signature + dummy_args: List[object] = [] + for name in argnames: + if name == "q": + dummy_args.append(q) + elif name == "k": + dummy_args.append(k) + elif name == "v": + dummy_args.append(v) + elif name == "attn_mask": + dummy_args.append(attn_mask) + elif name == "dropout_p": + dummy_args.append(dropout_p) + elif name == "scale": + dummy_args.append(scale_val) + else: + raise RuntimeError(f"Unexpected arg name: {name}") + + # Scalar workaround for present scalars only + scalar_workaround = {} + if has_dropout: + scalar_workaround["dropout_p"] = dropout_p + if has_scale: + scalar_workaround["scale"] = scale_val + if not scalar_workaround: + scalar_workaround = None + + register_ad_pattern( + search_fn=pat_fn, + replace_fn=rep_fn, + patterns=patterns, + dummy_args=dummy_args, + scalar_workaround=scalar_workaround, + ) + total += 1 + return total - # 5..8 (scale None) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_5, - replace_fn=_attn_bnsd_replacement_5, - patterns=patterns, - dummy_args=[q, k, v, attn_mask, dropout_p], - scalar_workaround={"dropout_p": dropout_p}, - ) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_6, - replace_fn=_attn_bnsd_replacement_6, - patterns=patterns, - dummy_args=[q, k, v, dropout_p], - scalar_workaround={"dropout_p": dropout_p}, - ) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_7, - replace_fn=_attn_bnsd_replacement_7, - patterns=patterns, - dummy_args=[q, k, v, attn_mask, dropout_p], - scalar_workaround={"dropout_p": dropout_p}, - ) - register_ad_pattern( - search_fn=_attn_bnsd_pattern_8, - replace_fn=_attn_bnsd_replacement_8, - patterns=patterns, - dummy_args=[q, k, v, dropout_p], - scalar_workaround={"dropout_p": dropout_p}, - ) + +def register_match_attn_layout(patterns: ADPatternMatcherPass): + return generate_and_register_attn_layout_patterns(patterns, 
register_ad_pattern) class MatchAttentionLayoutConfig(TransformConfig): From 80ef42c12695884686d34f5dc9e7aeebe75ac212 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 3 Oct 2025 05:42:10 +0000 Subject: [PATCH 4/5] add some feedback Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- .../custom_ops/flashinfer_attention.py | 15 +++++++++++++++ .../custom_ops/torch_backend_attention.py | 15 +++++++++++++++ .../auto_deploy/custom_ops/triton_attention.py | 15 +++++++++++++++ .../transform/library/kvcache_transformers.py | 2 +- 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 78535008092..3200a21937d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -399,6 +399,21 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor: @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: + # Sanity check: layout == "bsnd" + # Prefer kwargs; fall back to the final positional arg if it's a string. + layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + # Double check other arguments attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 949817da963..6eadb4b4466 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -460,6 +460,21 @@ def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitial @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: + # Sanity check: layout == "bsnd" + # Prefer kwargs; fall back to the final positional arg if it's a string. + layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + # Check other arguments attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index a0d8693fe39..56aad993a3c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -390,6 +390,21 @@ def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitial @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: + # Sanity check: layout == "bsnd" + # Prefer kwargs; fall back to the final positional arg if it's a string. 
+ layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + # retrieve head_dim from k_fake attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index 8c47fb4002a..ed218061ef4 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -46,7 +46,6 @@ def fake_profiler_mha( "sinks": "sinks", "sliding_window": "sliding_window", "logit_cap": "logit_cap", - "layout": "bsnd", } for k_kwargs, k_op_kwargs in kwargs_to_op.items(): if k_kwargs in kwargs: @@ -61,6 +60,7 @@ def fake_profiler_mha( v_fake = graph.placeholder(name="v_fake") v_fake.meta["val"] = torch.empty_like(value.transpose(1, 2), device="meta", dtype=value.dtype) + node_kwargs["layout"] = "bsnd" module._node_ref = graph.call_function( torch.ops.auto_deploy.torch_attention, args=(q_fake, k_fake, v_fake), From 8795861c1a8797a396a8516ae7b1ca440044f3af Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 3 Oct 2025 16:01:22 +0000 Subject: [PATCH 5/5] get rid of exec() Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- .../transform/library/attention.py | 458 +++++++++--------- 1 file changed, 217 insertions(+), 241 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index d3e50d2d447..9468220c59e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -1,5 +1,7 @@ """Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models.""" +from inspect import Parameter, Signature +from itertools import product from typing import Any, Callable, Dict, List, Tuple, Type import torch @@ -343,6 +345,39 @@ def register_eager_attention(patterns: ADPatternMatcherPass): return gm, info +def _attach_signature(fn: Callable, argnames: List[str]) -> Callable: + # Make FX "see" q,k,v[,attn_mask][,dropout_p][,scale] even though fn(*args) internally + params = [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD) for n in argnames] + fn.__signature__ = Signature(parameters=params) + return fn + + +def _call_sdpa( + q, k, v, *, is_causal: bool, enable_gqa: bool, attn_mask=None, dropout_p=None, scale=None +): + kwargs = {"is_causal": is_causal} + if attn_mask is not None: + kwargs["attn_mask"] = attn_mask + if dropout_p is not None: + kwargs["dropout_p"] = dropout_p + if scale is not None: + kwargs["scale"] = scale + if enable_gqa: + kwargs["enable_gqa"] = True + return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, **kwargs) + + +def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scale=None): + kwargs = {"is_causal": is_causal} + if attn_mask is not None: + kwargs["attn_mask"] = attn_mask + if dropout_p is not None: + kwargs["dropout_p"] = dropout_p + if scale is not None: + kwargs["scale"] = scale + 
return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs) + + def make_grouped_attn_pair( *, repeat_kv: bool, @@ -351,23 +386,18 @@ def make_grouped_attn_pair( enable_gqa: bool, has_attn_mask: bool, has_dropout: bool, -): +) -> Tuple[Callable, Callable, List[str]]: """ - Returns (pattern_fn, replacement_fn, argnames) such that: - - pattern_fn(*args) calls torch_attention_sdpa.default with the specified knobs - and optional pre-repeat kv - - replacement_fn(*args) calls torch_attention.default mirroring signature exactly - - argnames is the ordered arg list for constructing dummy args + Returns (pattern_fn, replacement_fn, argnames) with exact positional parity. Arg order rules: - - Base: (q, k, v) - - +repeat_kv -> insert n_rep after (q, k, v) - - +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v) - - +dropout -> include dropout_p after attn_mask or after n_rep / base if no attn_mask - - +scale -> include scale last (after dropout_p if present, else after attn_mask/n_rep/base) + Base: (q, k, v) + +repeat_kv -> insert n_rep after (q, k, v) + +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v) + +dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask + +scale -> include scale last """ - # build signature - argnames: list[str] = ["q", "k", "v"] + argnames: List[str] = ["q", "k", "v"] if repeat_kv: argnames.append("n_rep") if has_attn_mask: @@ -377,76 +407,53 @@ def make_grouped_attn_pair( if has_scale: argnames.append("scale") - # helper to build call kwargs source strings - def _build_sdpa_kw_src(): - parts = [] - if has_attn_mask: - parts.append("attn_mask=attn_mask") - if has_dropout: - parts.append("dropout_p=dropout_p") - if has_scale: - parts.append("scale=scale") - parts.append(f"is_causal={str(is_causal)}") - if enable_gqa: - parts.append("enable_gqa=True") - return ", ".join(parts) - - def _build_attn_kw_src(): - parts = [] - if has_attn_mask: - parts.append("attn_mask=attn_mask") - if has_dropout: - parts.append("dropout_p=dropout_p") - if has_scale: - parts.append("scale=scale") - parts.append(f"is_causal={str(is_causal)}") - return ", ".join(parts) - - sdpa_kw_src = _build_sdpa_kw_src() - attn_kw_src = _build_attn_kw_src() - - # factories that also return the source we exec - def pattern_factory(argnames=tuple(argnames), repeat_kv=repeat_kv): - args_sig = ", ".join(argnames) - fn_name = ( - f"ga_pat_r{int(repeat_kv)}_c{int(is_causal)}_s{int(has_scale)}_" - f"g{int(enable_gqa)}_m{int(has_attn_mask)}_d{int(has_dropout)}" - ) - body_lines = [f"def {fn_name}({args_sig}):"] + def pattern_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + + q = m["q"] + k = m["k"] + v = m["v"] + if repeat_kv: - body_lines.append(" k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)") - body_lines.append(" v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)") - call_line = ( - " return torch.ops.auto_deploy.torch_attention_sdpa.default(" - f"q, k, v{', ' if sdpa_kw_src else ''}{sdpa_kw_src})" + n_rep = m["n_rep"] + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) + + return _call_sdpa( + q, + k, + v, + is_causal=is_causal, + enable_gqa=enable_gqa, + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), ) - body_lines.append(call_line) - src = "\n".join(body_lines) - scope = 
{"torch": torch} - exec(src, scope) - return scope[fn_name] - - def replacement_factory(argnames=tuple(argnames)): - args_sig = ", ".join(argnames) - fn_name = ( - f"ga_rep_r{int(repeat_kv)}_c{int(is_causal)}_s{int(has_scale)}_" - f"g{int(enable_gqa)}_m{int(has_attn_mask)}_d{int(has_dropout)}" - ) - body = [f"def {fn_name}({args_sig}):"] - call_line = ( - " return torch.ops.auto_deploy.torch_attention.default(" - f"q, k, v{', ' if attn_kw_src else ''}{attn_kw_src})" + + # Replacement: torch_attention.default mirroring the positional signature exactly. + # We do NOT pass enable_gqa here (it’s SDPA-only). We accept n_rep to mirror signature, + # but we don’t need to use it in the replacement graph. + def replacement_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + return _call_attn( + m["q"], + m["k"], + m["v"], + is_causal=is_causal, + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), ) - body.append(call_line) - src = "\n".join(body) - scope = {"torch": torch} - exec(src, scope) - return scope[fn_name] - pat_fn = pattern_factory() - rep_fn = replacement_factory() + # Pattern matcher needs to see explicit arg names + _attach_signature(pattern_fn, argnames) + _attach_signature(replacement_fn, argnames) - return pat_fn, rep_fn, argnames + return pattern_fn, replacement_fn, argnames def generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern: Callable): @@ -475,60 +482,49 @@ def generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern: C n_rep_val = 7 total = 0 - for repeat_kv in (False, True): - for is_causal in (False, True): - for has_scale in (False, True): - for enable_gqa in (False, True): - for has_attn_mask in (False, True): - for has_dropout in (False, True): - # Build functions - pat_fn, rep_fn, argnames = make_grouped_attn_pair( - repeat_kv=repeat_kv, - is_causal=is_causal, - has_scale=has_scale, - enable_gqa=enable_gqa, - has_attn_mask=has_attn_mask, - has_dropout=has_dropout, - ) - - # Build dummy args in the same positional order - dummy_args: List[object] = [] - for name in argnames: - if name == "q": - dummy_args.append(q) - elif name == "k": - dummy_args.append(k1) - elif name == "v": - dummy_args.append(v1) - elif name == "n_rep": - dummy_args.append(n_rep_val) - elif name == "attn_mask": - dummy_args.append(attn_mask_tensor) - elif name == "dropout_p": - dummy_args.append(dropout_val) - elif name == "scale": - dummy_args.append(scale_val) - else: - raise RuntimeError(f"Unexpected arg name: {name}") - - # scalar_workaround mirrors only the scalar args present by name - scalar_workaround: Dict[str, object] = {} - if "n_rep" in argnames: - scalar_workaround["n_rep"] = n_rep_val - if "dropout_p" in argnames: - scalar_workaround["dropout_p"] = dropout_val - if "scale" in argnames: - scalar_workaround["scale"] = scale_val - - # Register - register_ad_pattern( - search_fn=pat_fn, - replace_fn=rep_fn, - patterns=patterns, - dummy_args=dummy_args, - scalar_workaround=scalar_workaround if scalar_workaround else None, - ) - total += 1 + axes = ((False, True),) * 6 + for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes): + pat_fn, rep_fn, argnames = make_grouped_attn_pair( + repeat_kv=repeat_kv, + is_causal=is_causal, + has_scale=has_scale, + enable_gqa=enable_gqa, + has_attn_mask=has_attn_mask, + has_dropout=has_dropout, + ) + + # Build dummy args in the same 
positional order + value_map = { + "q": q, + "k": k1, + "v": v1, + "n_rep": n_rep_val, + "attn_mask": attn_mask_tensor, + "dropout_p": dropout_val, + "scale": scale_val, + } + dummy_args: List[object] = [] + for name in argnames: + try: + dummy_args.append(value_map[name]) + except KeyError: + raise RuntimeError(f"Unexpected arg name: {name}") + + scalar_names = {"n_rep", "dropout_p", "scale"} + scalar_workaround: Dict[str, object] = { + n: value_map[n] for n in argnames if n in scalar_names + } + if not scalar_workaround: + scalar_workaround = None + + register_ad_pattern( + search_fn=pat_fn, + replace_fn=rep_fn, + patterns=patterns, + dummy_args=dummy_args, + scalar_workaround=scalar_workaround, + ) + total += 1 return total @@ -565,6 +561,19 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): return gm, info +def _call_torch_attention( + q, k, v, *, is_causal, layout, attn_mask=None, dropout_p=None, scale=None +): + kwargs = {"is_causal": is_causal, "layout": layout} + if attn_mask is not None: + kwargs["attn_mask"] = attn_mask + if dropout_p is not None: + kwargs["dropout_p"] = dropout_p + if scale is not None: + kwargs["scale"] = scale + return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs) + + def make_attn_bnsd_pair( *, has_attn_mask: bool, @@ -572,14 +581,6 @@ def make_attn_bnsd_pair( is_causal: bool, has_scale: bool, ) -> Tuple[Callable, Callable, List[str], str, str]: - """ - Returns (pattern_fn, replacement_fn, argnames, pat_src, rep_src) - - pattern_fn(*args) matches torch_attention.default(..., layout="bnsd") - - replacement_fn(*args) transposes to BSND, runs torch_attention.default(..., layout="bsnd"), transposes back - - argnames is the ordered arg list: (q, k, v [, attn_mask] [, dropout_p] [, scale]) - - pat_src / rep_src are the exact function bodies (for debug logging) - """ - # signature in positional order argnames: List[str] = ["q", "k", "v"] if has_attn_mask: argnames.append("attn_mask") @@ -588,65 +589,45 @@ def make_attn_bnsd_pair( if has_scale: argnames.append("scale") - # build kw parts (omit anything not present; always include is_causal; set layout explicitly - def _build_kw_src(layout_value: str) -> str: - parts = [] - if has_attn_mask: - parts.append("attn_mask=attn_mask") - if has_dropout: - parts.append("dropout_p=dropout_p") - if has_scale: - parts.append("scale=scale") - parts.append(f"is_causal={str(is_causal)}") - parts.append(f'layout="{layout_value}"') - return ", ".join(parts) - - bnsd_kw_src = _build_kw_src("bnsd") - bsnd_kw_src = _build_kw_src("bsnd") - - # factories: generate functions with explicit kwargs - def pattern_factory(argnames=tuple(argnames)): - args_sig = ", ".join(argnames) - fn_name = ( - f"attn_bnsd_pat_m{int(has_attn_mask)}_d{int(has_dropout)}_" - f"c{int(is_causal)}_s{int(has_scale)}" + def pattern_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + return _call_torch_attention( + m["q"], + m["k"], + m["v"], + is_causal=is_causal, + layout="bnsd", + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), ) - body = [f"def {fn_name}({args_sig}):"] - call = ( - " return torch.ops.auto_deploy.torch_attention.default(" - f"q, k, v{', ' if bnsd_kw_src else ''}{bnsd_kw_src})" - ) - body.append(call) - src = "\n".join(body) - scope = {"torch": torch} - exec(src, scope) - return scope[fn_name] - - def replacement_factory(argnames=tuple(argnames)): - args_sig = ", 
".join(argnames) - fn_name = ( - f"attn_bnsd_rep_m{int(has_attn_mask)}_d{int(has_dropout)}_" - f"c{int(is_causal)}_s{int(has_scale)}" - ) - body = [f"def {fn_name}({args_sig}):"] - body.append(" q_bsnd = torch.ops.aten.transpose.int(q, 1, 2)") - body.append(" k_bsnd = torch.ops.aten.transpose.int(k, 1, 2)") - body.append(" v_bsnd = torch.ops.aten.transpose.int(v, 1, 2)") - call = ( - " out_bsnd = torch.ops.auto_deploy.torch_attention.default(" - f"q_bsnd, k_bsnd, v_bsnd{', ' if bsnd_kw_src else ''}{bsnd_kw_src})" + + def replacement_fn(*args): + if len(args) != len(argnames): + raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}") + m = dict(zip(argnames, args)) + q_b = torch.ops.aten.transpose.int(m["q"], 1, 2) + k_b = torch.ops.aten.transpose.int(m["k"], 1, 2) + v_b = torch.ops.aten.transpose.int(m["v"], 1, 2) + out_b = _call_torch_attention( + q_b, + k_b, + v_b, + is_causal=is_causal, + layout="bsnd", + attn_mask=m.get("attn_mask"), + dropout_p=m.get("dropout_p"), + scale=m.get("scale"), ) - body.append(call) - body.append(" return torch.ops.aten.transpose.int(out_bsnd, 1, 2)") - src = "\n".join(body) - scope = {"torch": torch} - exec(src, scope) - return scope[fn_name] + return torch.ops.aten.transpose.int(out_b, 1, 2) - pat_fn = pattern_factory() - rep_fn = replacement_factory() + # Pattern matcher needs to see explicit arg names + _attach_signature(pattern_fn, argnames) + _attach_signature(replacement_fn, argnames) - return pat_fn, rep_fn, argnames + return pattern_fn, replacement_fn, argnames def generate_and_register_attn_layout_patterns(patterns, register_ad_pattern: Callable): @@ -669,52 +650,47 @@ def generate_and_register_attn_layout_patterns(patterns, register_ad_pattern: Ca scale_val = 0.56789 total = 0 - for has_attn_mask in (False, True): - for has_dropout in (False, True): - for is_causal in (False, True): - for has_scale in (False, True): - pat_fn, rep_fn, argnames = make_attn_bnsd_pair( - has_attn_mask=has_attn_mask, - has_dropout=has_dropout, - is_causal=is_causal, - has_scale=has_scale, - ) - - # Build dummy args following positional signature - dummy_args: List[object] = [] - for name in argnames: - if name == "q": - dummy_args.append(q) - elif name == "k": - dummy_args.append(k) - elif name == "v": - dummy_args.append(v) - elif name == "attn_mask": - dummy_args.append(attn_mask) - elif name == "dropout_p": - dummy_args.append(dropout_p) - elif name == "scale": - dummy_args.append(scale_val) - else: - raise RuntimeError(f"Unexpected arg name: {name}") - - # Scalar workaround for present scalars only - scalar_workaround = {} - if has_dropout: - scalar_workaround["dropout_p"] = dropout_p - if has_scale: - scalar_workaround["scale"] = scale_val - if not scalar_workaround: - scalar_workaround = None - - register_ad_pattern( - search_fn=pat_fn, - replace_fn=rep_fn, - patterns=patterns, - dummy_args=dummy_args, - scalar_workaround=scalar_workaround, - ) - total += 1 + axes = ((False, True),) * 4 + for has_attn_mask, has_dropout, is_causal, has_scale in product(*axes): + pat_fn, rep_fn, argnames = make_attn_bnsd_pair( + has_attn_mask=has_attn_mask, + has_dropout=has_dropout, + is_causal=is_causal, + has_scale=has_scale, + ) + + # Build dummy args following positional signature + value_map = { + "q": q, + "k": k, + "v": v, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "scale": scale_val, + } + dummy_args: List[object] = [] + for name in argnames: + try: + dummy_args.append(value_map[name]) + except KeyError: + raise 
RuntimeError(f"Unexpected arg name: {name}") + + # Scalar workaround for present scalars only + scalar_names = {"dropout_p", "scale"} + scalar_workaround: Dict[str, object] = { + n: value_map[n] for n in argnames if n in scalar_names + } + if not scalar_workaround: + scalar_workaround = None + + register_ad_pattern( + search_fn=pat_fn, + replace_fn=rep_fn, + patterns=patterns, + dummy_args=dummy_args, + scalar_workaround=scalar_workaround, + ) + total += 1 return total