
Commit 16d3071

WIP: unify bsnd_group_attention and group_attention, update match_attention_layout to use pattern matcher
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
1 parent 985fb44 commit 16d3071

File tree

12 files changed: +461, -113 lines changed


tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 1 addition & 1 deletion

@@ -354,7 +354,7 @@ def get_num_qkv_args(cls) -> int:
     @classmethod
     def get_source_attention_op(cls) -> OpOverloadPacket:
         """Get the source attention op that we target for replacement."""
-        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
+        return torch.ops.auto_deploy.torch_attention
 
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
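
This backend, like the torch and Triton backends further down, now reports the unified torch_attention op as its replacement target. For context only, here is a minimal sketch of what a graph transform can do with such a hook; the actual match_attention_layout transform uses torch's pattern matcher per the commit message, and the helper below is hypothetical.

import torch
from torch.fx import GraphModule

def count_source_attention_calls(gm: GraphModule, source_op) -> int:
    """Count graph nodes that call the backend's source attention op.

    `source_op` is an OpOverloadPacket such as torch.ops.auto_deploy.torch_attention;
    exported graphs usually carry the resolved `.default` overload as the node target.
    """
    targets = {source_op, getattr(source_op, "default", None)}
    return sum(
        1 for n in gm.graph.nodes if n.op == "call_function" and n.target in targets
    )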

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 138 additions & 0 deletions

@@ -257,6 +257,144 @@ def bsnd_grouped_sdpa_fake(
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
 
 
+# Unified attention op
+@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=())
+def torch_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
+    logit_cap: Optional[float] = None,
+    layout: str = "bnsd",  # "bnsd" or "bsnd"
+) -> torch.Tensor:
+    """SDPA attention (with optional GQA) that supports two memory layouts via `layout`:
+
+    - "bnsd": [batch, num_heads, seq_len, head_dim]
+    - "bsnd": [batch, seq_len, num_heads, head_dim]
+
+    The `attn_mask` is always interpreted as [b, n, s_q, s_k].
+
+    Returns a tensor in the same layout as the inputs, as specified by `layout`.
+    """
+    if layout not in ("bnsd", "bsnd"):
+        raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")
+
+    if layout == "bsnd":
+        query = query.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+
+    # From here on, inputs are in bnsd format
+    b, n_heads, s_q, head_dim = query.shape  # [batch, num_heads, seq_len, head_dim]
+    _, n_kv_heads, s_k, _ = key.shape  # [batch, num_kv_heads, seq_len, head_dim]
+
+    query_t = query  # [b, n_heads, s_q, head_dim]
+    key_t = key  # [b, n_kv_heads, s_k, head_dim]
+    value_t = value  # [b, n_kv_heads, s_k, v_head_dim]
+
+    # Handle GQA by repeating KV heads if needed
+    if n_heads != n_kv_heads:
+        n_rep = n_heads // n_kv_heads
+        key_t = repeat_kv(key_t, n_rep)
+        value_t = repeat_kv(value_t, n_rep)
+
+    # Set default scale
+    if scale is None:
+        scale = 1.0 / math.sqrt(head_dim)
+
+    # Compute attention scores: Q @ K^T
+    attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]
+
+    # Apply attention mask if provided
+    if attn_mask is not None:
+        # Convert boolean mask to float if needed
+        attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
+        attn_scores = attn_scores + attn_mask
+
+    # Apply causal mask, but only during the context phase (s_q == s_k)
+    if is_causal and s_q == s_k:
+        causal_mask = torch.triu(
+            torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
+            diagonal=1,  # diagonal=1 for standard causal masking
+        )
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+    # Apply sliding window mask if specified
+    if sliding_window is not None and sliding_window > 0:
+        # Handle position calculation for both context and generation phases
+        if s_q == s_k:
+            # Context phase: standard position calculation
+            query_positions = torch.arange(s_q, device=query.device)
+            key_positions = torch.arange(s_k, device=query.device)
+        else:
+            # Generation phase: queries sit at positions s_k .. s_k + s_q - 1 (after the cache)
+            query_positions = torch.arange(s_k, s_k + s_q, device=query.device)
+            key_positions = torch.arange(s_k, device=query.device)  # [0, 1, ..., s_k - 1]
+
+        # Position difference matrix: query_pos - key_pos
+        pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]
+
+        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
+        sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
+        attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+    # Apply logit softcapping if enabled
+    attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+    # Apply attention sinks if provided
+    if sinks is not None:
+        # Following the reference implementation, each head contributes one sink logit
+        # that enters the softmax normalizer but not the weighted sum over values.
+        # Expand sinks to [b, n_heads, s_q, 1]: one sink column per head.
+        sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
+            b, n_heads, s_q, 1
+        )  # [b, n_heads, s_q, 1]
+
+        # Numerically stable softmax with the sink term added to the normalizer
+        logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
+        sinks = torch.exp(sinks_expanded - logits_max)
+        unnormalized_scores = torch.exp(attn_scores - logits_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
+        scores = unnormalized_scores / normalizer
+        # Only the non-sink scores are used to weight the values
+        attn_out = torch.matmul(scores, value_t)  # [b, n_heads, s_q, v_head_dim]
+    else:
+        attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
+        attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]
+
+    # Apply dropout if specified
+    if dropout_p > 0.0:
+        attn_out = F.dropout(attn_out, p=dropout_p, training=False)
+
+    # Return in the same layout as the inputs
+    if layout == "bsnd":
+        return attn_out.transpose(1, 2).contiguous()
+    else:
+        return attn_out.contiguous()
+
+
+@torch_attention.register_fake
+def torch_attention_fake(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale=None,
+    sinks=None,
+    sliding_window=None,
+    logit_cap=None,
+    layout: str = "bnsd",
+):
+    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
+
+
 def update_kv_cache(
     key_states: torch.Tensor,
     value_states: torch.Tensor,
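
A quick smoke test of the unified op, not part of the commit, is sketched below. It assumes the auto_deploy custom ops are already imported and registered, uses arbitrary shapes, and checks that the "bnsd" and "bsnd" layouts agree for a GQA configuration with a causal mask, sinks, and a sliding window.

import torch

# Hypothetical smoke test; the auto_deploy custom ops must already be registered.
b, n_heads, n_kv_heads, s, head_dim = 2, 8, 2, 16, 64
q = torch.randn(b, n_heads, s, head_dim)
k = torch.randn(b, n_kv_heads, s, head_dim)
v = torch.randn(b, n_kv_heads, s, head_dim)
sinks = torch.randn(n_heads)  # one sink logit per head

out_bnsd = torch.ops.auto_deploy.torch_attention(
    q, k, v, is_causal=True, sinks=sinks, sliding_window=8, layout="bnsd"
)
assert out_bnsd.shape == (b, n_heads, s, head_dim)

# The same inputs in bsnd layout should give the same result, transposed.
out_bsnd = torch.ops.auto_deploy.torch_attention(
    q.transpose(1, 2).contiguous(),
    k.transpose(1, 2).contiguous(),
    v.transpose(1, 2).contiguous(),
    is_causal=True,
    sinks=sinks,
    sliding_window=8,
    layout="bsnd",
)
assert torch.allclose(out_bsnd.transpose(1, 2), out_bnsd, atol=1e-6)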

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py

Lines changed: 1 addition & 1 deletion

@@ -408,7 +408,7 @@ def get_num_qkv_args(cls) -> int:
 
     @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
-        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
+        return torch.ops.auto_deploy.torch_attention
 
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py

Lines changed: 1 addition & 1 deletion

@@ -337,7 +337,7 @@ def get_num_qkv_args(cls) -> int:
 
     @classmethod
     def get_source_attention_op(cls) -> OpOverloadPacket:
-        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
+        return torch.ops.auto_deploy.torch_attention
 
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:

tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py

Lines changed: 2 additions & 1 deletion

@@ -59,7 +59,7 @@ def gpt_oss_attention(
     sinks = self.sinks
 
     # Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim)
-    attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
+    attn_output = torch.ops.auto_deploy.torch_attention(
         query_states,
         key_states,
         value_states,
@@ -69,6 +69,7 @@
         scale=self.scaling,
         sinks=sinks,
         sliding_window=sliding_window,
+        layout="bsnd",
     )
 
     # Reshape back to original input shape
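
For illustration, a simplified, hypothetical attention module (not the HF GPT-OSS code and not part of this commit) shows why layout="bsnd" is convenient for model patches: the QKV projections naturally produce [batch, seq, num_heads, head_dim], so no transposes are needed around the unified op.

import torch
import torch.nn as nn

class TinyBsndAttention(nn.Module):
    # Minimal GQA attention block that feeds bsnd-shaped projections straight
    # into the unified op; assumes the auto_deploy custom ops are registered.
    def __init__(self, hidden=256, n_heads=8, n_kv_heads=2):
        super().__init__()
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = hidden // n_heads
        self.q_proj = nn.Linear(hidden, n_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden, n_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden, n_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(n_heads * self.head_dim, hidden)

    def forward(self, x):  # x: [batch, seq, hidden]
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(b, s, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(b, s, self.n_kv_heads, self.head_dim)
        out = torch.ops.auto_deploy.torch_attention(
            q, k, v, is_causal=True, layout="bsnd"
        )  # output stays [batch, seq, num_heads, head_dim]
        return self.o_proj(out.reshape(b, s, -1))

Usage: TinyBsndAttention()(torch.randn(2, 16, 256)) returns a [2, 16, 256] tensor, mirroring the bsnd call pattern used in the gptoss patch above.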
