@@ -257,6 +257,144 @@ def bsnd_grouped_sdpa_fake(
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


+# Unified attention op
+@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=())
+def torch_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
+    logit_cap: Optional[float] = None,
+    layout: str = "bnsd",  # "bnsd" or "bsnd"
+) -> torch.Tensor:
+    """
+    SDPA attention (with optional GQA) that supports two memory layouts via `layout`:
+    - "bnsd": [batch, num_heads, seq_len, head_dim]
+    - "bsnd": [batch, seq_len, num_heads, head_dim]
+
+    The `attn_mask` is always interpreted as [b, n, s_q, s_k].
+
+    Returns a tensor in the same layout as the inputs, as specified by `layout`.
+    """
+    if layout not in ("bnsd", "bsnd"):
+        raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")
+
+    if layout == "bsnd":
+        query = query.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+
+    b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
+    _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
+
+    # Inputs are already in bnsd format, no need to transpose
+    query_t = query  # [b, n_heads, s_q, head_dim]
+    key_t = key  # [b, n_kv_heads, s_k, head_dim]
+    value_t = value  # [b, n_kv_heads, s_k, v_head_dim]
+
+    # Handle GQA by repeating KV if needed
+    if n_heads != n_kv_heads:
+        n_rep = n_heads // n_kv_heads
+        key_t = repeat_kv(key_t, n_rep)
+        value_t = repeat_kv(value_t, n_rep)
+
+    # Set scale
+    if scale is None:
+        scale = 1.0 / math.sqrt(head_dim)
+
+    # Compute attention scores: Q @ K^T
+    attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]
+
+    # Apply attention mask if provided
+    if attn_mask is not None:
+        # Convert boolean mask to float if needed
+        attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
+        attn_scores = attn_scores + attn_mask
+
+    # Apply causal mask if specified and only during the context phase
+    if is_causal and s_q == s_k:  # Only apply causal mask during context processing
+        causal_mask = torch.triu(
+            torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
+            diagonal=1,  # Use diagonal=1 for standard causal masking
+        )
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+    # Apply sliding window mask if specified
+    if sliding_window is not None and sliding_window > 0:
+        # Handle position calculation for both context and generation phases
+        if s_q == s_k:
+            # Context phase: standard position calculation
+            query_positions = torch.arange(s_q, device=query.device)
+            key_positions = torch.arange(s_k, device=query.device)
+        else:
+            # Generation phase: queries start at position s_k (right after the cache)
+            query_positions = torch.arange(s_k, s_k + s_q, device=query.device)  # [s_k] for s_q=1
+            key_positions = torch.arange(s_k, device=query.device)  # [0, 1, 2, ..., s_k-1]
+
+        # Create position difference matrix: query_pos - key_pos
+        pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]
+
+        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
+        sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
+        attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+    # Apply logit softcapping if enabled
+    attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+    # Apply sinks if provided
+    if sinks is not None:
+        # Following the reference implementation, add a per-head sink term to the softmax:
+        # sinks should have n_heads elements, each head gets its own sink value
+        # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
+        sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
+            b, n_heads, s_q, 1
+        )  # [b, n_heads, s_q, 1]
+
+        # Numerically stable softmax with the sink logits folded into the normalizer
+        logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
+        sinks = torch.exp(sinks_expanded - logits_max)
+        unnormalized_scores = torch.exp(attn_scores - logits_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
+        scores = unnormalized_scores / normalizer
+        # Only the non-sink scores multiply the values; the sink term
+        # merely rescales them through the normalizer
+        attn_out = torch.matmul(scores, value_t)  # [b, n_heads, s_q, v_head_dim]
+    else:
+        attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
+        attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]
+
+    # Apply dropout if specified
+    if dropout_p > 0.0:
+        attn_out = F.dropout(attn_out, p=dropout_p, training=False)
+
+    if layout == "bsnd":
+        return attn_out.transpose(1, 2).contiguous()
+    else:
+        return attn_out.contiguous()
+
+
+@torch_attention.register_fake
+def torch_attention_fake(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale=None,
+    sinks=None,
+    sliding_window=None,
+    logit_cap=None,
+    layout: str = "bnsd",
+):
+    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
+
+
def update_kv_cache(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
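For reference, a minimal usage sketch of the new op: the custom op is exposed as torch.ops.auto_deploy.torch_attention once the module registering it is imported, and it returns output in the same layout as its inputs. The shapes and argument values below are illustrative only, not taken from the diff.

import torch

# Illustrative shapes: 2 sequences of length 16, 8 query heads sharing 2 KV heads (GQA).
b, s, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
q = torch.randn(b, s, n_heads, head_dim)
k = torch.randn(b, s, n_kv_heads, head_dim)
v = torch.randn(b, s, n_kv_heads, head_dim)

# Causal attention with a sliding window, inputs and output in "bsnd" layout.
out = torch.ops.auto_deploy.torch_attention(
    q, k, v, is_causal=True, sliding_window=8, layout="bsnd"
)
assert out.shape == (b, s, n_heads, head_dim)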