
Conversation


@Fridah-nv Fridah-nv commented Sep 30, 2025

This PR does the following:

  • Unify bsnd_grouped_sdpa and grouped_sdpa into a single attention op
  • Match attention layout with pattern matching
  • Auto-generate attention pattern matches

Summary by CodeRabbit

  • New Features

    • Unified attention operator with configurable layouts (supports bnsd and bsnd) and automatic input/output formatting.
    • Expanded available custom operators covering distributed reduce, linear variants, MoE (including fused), quantization paths (FP8/NVFP4), RoPE variants, and Triton-based attention.
  • Refactor

    • Replaced legacy grouped attention variants with the unified attention operator across backends and models.
    • Introduced scalable, generator-based attention transformation patterns and layout handling.
  • Documentation

    • Updated public operator list and descriptions.
  • Tests

    • Updated test suites to validate the unified attention operator, layout detection, and sharding recognition.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions).

  • Any new dependencies have been scanned for license and vulnerabilities.

  • CODEOWNERS updated if ownership changes.

  • Documentation updated as needed.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. This ensures that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can break the top of tree.

…ention_layout to use pattern matcher

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv force-pushed the user/fridah/attn-layout branch from 16d3071 to eee5637 Compare September 30, 2025 17:57
@Fridah-nv Fridah-nv changed the title [None][autodeploy]: small refactors on attention matching [None][autodeploy] small refactors on attention matching Sep 30, 2025
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv marked this pull request as ready for review September 30, 2025 21:14
@Fridah-nv Fridah-nv requested a review from a team as a code owner September 30, 2025 21:14
@Fridah-nv Fridah-nv requested review from MrGeva and lucaslie and removed request for MrGeva September 30, 2025 21:14

coderabbitai bot commented Sep 30, 2025

📝 Walkthrough


Unified attention custom op to auto_deploy::torch_attention with explicit layout handling ("bnsd" or "bsnd"), removed legacy grouped SDPA variants, updated backends and model patches to reference the new op, introduced dynamic generation/registration of attention transform patterns, adjusted sharding/KV-cache transformers, and updated tests and README accordingly.
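
A rough sketch of the layout handling described above (illustrative standalone function, not the actual torch_attention implementation):

import torch
import torch.nn.functional as F

def attention_with_layout(q, k, v, *, is_causal=False, scale=None, layout="bnsd"):
    # Normalize bsnd inputs to bnsd, run SDPA, then restore the caller's layout.
    if layout not in ("bnsd", "bsnd"):
        raise ValueError(f"Unsupported layout: {layout}")
    if layout == "bsnd":
        # [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)
    if layout == "bsnd":
        out = out.transpose(1, 2).contiguous()  # back to [batch, seq, heads, head_dim]
    return out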

Changes

Cohort / File(s) Change Summary
Unified torch_attention op and callers
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
Replace grouped SDPA ops with unified auto_deploy::torch_attention; add layout parameter with validation and conditional transposes; update source-op mappings in FlashInfer, TorchBackend, and Triton to use torch_attention.
Transform: attention patterns and layout handling
tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
Add generators for grouped-attention and layout patterns; register patterns programmatically; enforce supported layouts; update matchers to use new generators and layout logic.
Model patch
tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py
Switch to torch_attention with layout="bsnd" in place of bsnd_grouped_sdpa.
Transform: KV cache and sharding
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py, tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Pass through layout in kwargs; retarget fake/translate ops to torch_attention; update shardable op detection to include torch_attention.
Operator list documentation
tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
Remove grouped SDPA entries; add unified torch_attention with layout note; enumerate additional public operators.
Tests: layout and matcher updates
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py, tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
Update expectations and calls from grouped SDPA variants to torch_attention; add explicit layout="bsnd"; adjust node detection/verification to new op and arguments.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller as Model/Module
  participant Op as torch.ops.auto_deploy.torch_attention
  participant Impl as Attention Impl
  participant Kernel as Grouped SDPA Kernel

  Caller->>Op: torch_attention(Q,K,V, ..., layout)
  Note over Op: Validate layout in {"bnsd","bsnd"}
  alt layout == "bsnd"
    Op->>Impl: Transpose inputs to internal bnsd
  else layout == "bnsd"
    Op->>Impl: Use inputs as-is
  end
  Impl->>Kernel: Compute attention (grouped SDPA)
  alt layout == "bsnd"
    Impl->>Op: Transpose output back to bsnd
  else
    Impl-->>Op: Return bnsd output
  end
  Op-->>Caller: Output tensor
sequenceDiagram
  autonumber
  participant Gen as Pattern Generators
  participant Reg as ADPattern Registry
  participant GM as GraphModule
  participant Pass as Transform Pass

  Gen->>Reg: Generate grouped-attn patterns
  Gen->>Reg: Generate layout patterns (bnsd/bsnd)
  Pass->>GM: Apply registered patterns
  alt Matches found
    GM-->>Pass: Replacements applied
  else No matches
    GM-->>Pass: No-op
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description includes the repository template but leaves the Description and Test Coverage sections empty and still displays the @coderabbitai summary placeholder, providing no actual explanation of the issue, the implemented solution, or which tests validate the changes. This omission deprives reviewers of essential context and undermines confidence that the new functionality is adequately tested. Please complete the Description section with a clear summary of the problem and solution, list the relevant test cases under Test Coverage, and remove or replace the @coderabbitai placeholder so the PR fully conforms to the repository’s required template.
Docstring Coverage ⚠️ Warning Docstring coverage is 30.95% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title Check ❓ Inconclusive The title “[None][autodeploy] small refactors on attention matching” is overly generic and fails to highlight the PR’s substantial functional changes, including the consolidation of grouped SDPA operators into a unified torch_attention operation, the introduction of layout support, and the auto-generation of attention pattern matching. Teammates reviewing the project history may not recognize these key enhancements under the description of a “small refactor.” Please update the title to clearly reflect the primary changes, for example “[None][feat] Unify grouped SDPA ops into a single torch_attention with layout support and auto-generated attention patterns,” so that colleagues can quickly understand the core improvements at a glance.

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c510b67 and b5ddb47.

📒 Files selected for processing (13)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/README.md (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (8 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
🧬 Code graph analysis (10)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (179-202)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (1)
  • verify_matcher (85-120)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (179-202)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • torch_attention (96-212)
tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (1)
  • repeat_kv (35-47)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
  • register_ad_pattern (99-182)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

119-119: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Unused function argument: key

(ARG001)


220-220: Unused function argument: attn_mask

(ARG001)


221-221: Unused function argument: dropout_p

(ARG001)


222-222: Unused function argument: is_causal

(ARG001)


223-223: Unused function argument: scale

(ARG001)


224-224: Unused function argument: sinks

(ARG001)


225-225: Unused function argument: sliding_window

(ARG001)


226-226: Unused function argument: logit_cap

(ARG001)


227-227: Unused function argument: layout

(ARG001)

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

389-389: Use explicit conversion flag

Replace with conversion flag

(RUF010)


402-402: Use explicit conversion flag

Replace with conversion flag

(RUF010)


426-426: Use of exec detected

(S102)


443-443: Use of exec detected

(S102)


512-512: Avoid specifying long messages outside the exception class

(TRY003)


600-600: Use explicit conversion flag

Replace with conversion flag

(RUF010)


622-622: Use of exec detected

(S102)


643-643: Use of exec detected

(S102)


699-699: Avoid specifying long messages outside the exception class

(TRY003)


754-754: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (23)
tensorrt_llm/_torch/auto_deploy/custom_ops/README.md (1)

17-17: LGTM!

The documentation accurately reflects the API consolidation of grouped SDPA operators into a single torch_attention operator with explicit layout support.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

93-93: LGTM!

The migration to the unified torch_attention operator with explicit layout="bsnd" is correct and aligns with the PR's refactoring objectives. The layout parameter clarifies the expected tensor shape.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)

412-412: LGTM!

The update to return torch.ops.auto_deploy.torch_attention correctly reflects the unified attention operator. The existing get_attention_layout() method already returns "bsnd", ensuring consistency with the layout-aware implementation.

tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py (1)

62-73: LGTM!

The migration to torch.ops.auto_deploy.torch_attention with explicit layout="bsnd" is correct. All existing parameters are preserved, and the explicit layout parameter formalizes what was previously documented only in the comment.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

480-483: LGTM!

The update to shardable_attention_nodes correctly reflects the unified attention API. The set now includes torch_attention (the new unified grouped SDPA operator) alongside torch_attention_sdpa, replacing the legacy grouped SDPA variants.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (8)

744-745: LGTM! Comment correctly updated.

The comment now references the unified torch_attention operator, consistent with the refactoring.


749-749: LGTM! Filter correctly updated to the unified attention operator.

The verification logic now checks for torch_attention instead of the legacy torch_attention_grouped_sdpa.


880-888: LGTM! Call site correctly migrated to unified attention operator.

The call now uses torch.ops.auto_deploy.torch_attention with the same parameters. The default layout of "bnsd" is used implicitly.


986-994: LGTM! Call site correctly migrated.

Consistent migration to the unified torch_attention operator.


1090-1098: LGTM! Call site correctly updated.

The migration to torch_attention is consistent with the refactoring.


1137-1146: LGTM! Correctly uses explicit layout parameter.

The call now uses torch_attention.default with the explicit layout="bsnd" parameter, correctly representing the bsnd memory layout for this model.


1173-1180: LGTM! Source operator assignment correctly updated.

The logic properly assigns torch_attention for both grouped SDPA (bnsd) and bsnd layouts, falling back to torch_attention_sdpa for standard bnsd cases.


1209-1232: LGTM! Verification logic correctly detects layout-aware attention.

The updated filter logic properly identifies torch_attention nodes with the layout="bsnd" parameter by checking if the last argument is the string "bsnd". This aligns with the new unified attention operator design.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (4)

94-117: LGTM! Well-documented unified attention operator.

The function signature clearly documents the layout parameter and its behavior. The docstring explicitly states the supported layouts and the output layout guarantee, which is excellent for API clarity.


118-124: LGTM! Input validation and layout transformation are correct.

The layout validation and conditional transpose logic correctly convert bsnd inputs to bnsd format for internal processing. The use of .contiguous() ensures optimal memory layout after transposition.

Note: Static analysis flags the error message format (TRY003), but this is a minor style issue that can be addressed later if needed.


209-212: LGTM! Output layout transformation correctly preserves input layout.

The output transformation is symmetric with the input transformation, correctly returning results in the same layout as the inputs. The use of .contiguous() in both branches ensures consistent memory layout.


215-229: LGTM! Fake implementation signature correctly matches the real implementation.

The fake implementation's unused parameters are expected—fake implementations only need to return correctly shaped tensors for shape inference and don't execute the actual computation. The static analysis warnings (ARG001) are false positives in this context.
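
For reference, a fake implementation along these lines is essentially a shape stub (hypothetical sketch; the registration mechanism used in torch_attention.py may differ, and it assumes the auto_deploy::torch_attention op is already defined):

import torch

@torch.library.register_fake("auto_deploy::torch_attention")
def _torch_attention_fake(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
                          scale=None, sinks=None, sliding_window=None,
                          logit_cap=None, layout="bnsd"):
    # Shape inference only: the output mirrors the query tensor; the remaining
    # arguments are intentionally unused, hence the ARG001 false positives.
    return torch.empty_like(q)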

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (6)

346-449: Dynamic function generation is well-structured.

The use of exec() for generating pattern and replacement functions is acceptable here because:

  • The code is generated at registration time from controlled boolean flags, not user input
  • The scope is restricted (only torch in the namespace)
  • This approach avoids manually defining 64 pattern/replacement pairs

The generated function names are descriptive and include all flag states for debuggability.

Note: Static analysis flags exec() usage (S102), but this is a false positive given the controlled context.


452-532: LGTM! Comprehensive pattern generation with correct argument handling.

The function systematically generates all 64 pattern variants across the boolean flag combinations. The dummy argument construction correctly follows the signature order, and scalar workarounds are properly mapped. The use of fixed scalar values (dropout_val, scale_val, n_rep_val) is appropriate for pattern matching.


535-565: LGTM! Clean refactoring to use dynamic pattern generation.

The implementation now delegates to generate_and_register_grouped_attn_patterns, significantly reducing code duplication while maintaining the same functionality. The docstring correctly references the unified torch_attention operator.


568-649: LGTM! Layout transformation pattern generation is correct.

The function generates patterns that match torch_attention with layout="bnsd" and replace them with transposed calls using layout="bsnd". The replacement correctly:

  • Transposes Q, K, V from bnsd to bsnd using aten.transpose.int
  • Calls attention with layout="bsnd"
  • Transposes output back to bnsd

The explicit use of torch.ops.aten.transpose.int ensures correct graph-level pattern matching.
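
As a hedged sketch, the core of such a replacement can be pictured as follows (illustrative names; the real functions are generated per flag combination in attention.py):

import torch

def bnsd_to_bsnd_replacement(q, k, v, is_causal, scale):
    # Inputs arrive in bnsd; transpose to bsnd for the unified op.
    q_bsnd = torch.ops.aten.transpose.int(q, 1, 2)
    k_bsnd = torch.ops.aten.transpose.int(k, 1, 2)
    v_bsnd = torch.ops.aten.transpose.int(v, 1, 2)
    out = torch.ops.auto_deploy.torch_attention(
        q_bsnd, k_bsnd, v_bsnd, is_causal=is_causal, scale=scale, layout="bsnd"
    )
    # Transpose back so downstream nodes still see bnsd.
    return torch.ops.aten.transpose.int(out, 1, 2)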


652-723: LGTM! Layout pattern enumeration is comprehensive and correct.

The function systematically generates all 16 layout transformation patterns. The dummy tensors correctly use BNSD layout (bs, n_heads, s_q, head_dim), and scalar workarounds are properly constructed for present parameters only.


751-773: LGTM! Layout transformation logic is correct and efficient.

The implementation correctly:

  • Validates that the layout is either "bnsd" or "bsnd"
  • Short-circuits when the backend expects "bnsd" (no transformation needed)
  • Applies layout transformation patterns only when the backend expects "bsnd"

The logic is clear and efficient, avoiding unnecessary pattern matching when no transformation is required.

Note: Static analysis flags the error message format (TRY003), but this is a minor style issue.



Collaborator Author

I migrated MatchGroupedAttention and MatchAttentionLayout to auto-generate and register patterns, but I left MatchEagerAttention as is. Because MatchGroupedAttention and MatchAttentionLayout only match between torch_attention_sdpa and torch_attention (the new unified grouped op), I can iterate through all the parameter combinations (64 patterns for MatchGroupedAttention and 16 for MatchAttentionLayout).
Eager attention's pattern is much more flexible, though; for example, when or whether to cast the softmax output can become a new feature that composes into the pattern, and iterating through all of these features can blow up the number of patterns quickly. In this case, I think it's okay to add new patterns manually when we discover a new way to write eager attention. Let me know your suggestions!

Member

sounds good

"sinks": "sinks",
"sliding_window": "sliding_window",
"logit_cap": "logit_cap",
"layout": "bsnd",
Member

This is the arg-name mapping between the **kwargs names that Hugging Face may insert and the **kwargs names we expect.

Adding "layout": "bsnd" doesn't make sense in this context.

torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
torch.ops.auto_deploy.torch_attention,
args=(q_fake, k_fake, v_fake),
kwargs=node_kwargs,
Member

Requires hard-coding layout="bsnd"

def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention
Member

We need a sanity check now that torch_attention is called with layout "bsnd".

It is probably best to put that into the get_constants utility and raise an error if there is a mismatch.
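
A minimal sketch of such a check on a matched graph node (hypothetical helper; the actual hook point, e.g. inside the get_constants utility, may differ, and a positionally passed layout would need args-based handling instead):

from torch.fx import Node

def check_bsnd_layout(node: Node) -> None:
    # Hypothetical sanity check: the matched torch_attention node must use layout="bsnd".
    layout = node.kwargs.get("layout", "bnsd")
    if layout != "bsnd":
        raise ValueError(f"Expected layout='bsnd' on {node.target}, got {layout!r}")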

Member

This is also needed for the other backends.

sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
layout: str = "bnsd", # "bnsd" or "bsnd"
Member

nit: you could type hint as

Suggested change
layout: str = "bnsd", # "bnsd" or "bsnd"
layout: Literal["bnsd", "bsnd"] = "bnsd",

Comment on lines +478 to +483
for repeat_kv in (False, True):
for is_causal in (False, True):
for has_scale in (False, True):
for enable_gqa in (False, True):
for has_attn_mask in (False, True):
for has_dropout in (False, True):
Member

This is a good use case for itertools.product.
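
For illustration, the six nested loops above could be flattened with itertools.product (same flag names; the registration call is hypothetical):

import itertools

FLAG_NAMES = ("repeat_kv", "is_causal", "has_scale", "enable_gqa", "has_attn_mask", "has_dropout")

for combo in itertools.product((False, True), repeat=len(FLAG_NAMES)):
    config = dict(zip(FLAG_NAMES, combo))  # one of the 2^6 = 64 flag combinations
    # _register_grouped_attn_pattern(**config)  # hypothetical per-combination registration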

body_lines.append(call_line)
src = "\n".join(body_lines)
scope = {"torch": torch}
exec(src, scope)
Member

There is no need for exec anywhere here; this can be solved with a "function factory", i.e., something along the lines of:

def make_multiplier(factor):
    def multiply(x):
        return x * factor
    return multiply
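
Applied to this case, a factory that closes over the flags could replace the exec-generated source; a hedged sketch, with plain SDPA standing in for the matched op:

import torch.nn.functional as F

def make_sdpa_pattern(has_scale: bool, is_causal: bool):
    # Each flag combination becomes a closure; no source-text generation needed.
    def pattern(q, k, v, scale=None):
        if has_scale:
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return pattern

pattern_fn = make_sdpa_pattern(has_scale=True, is_causal=False)  # example usage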

Member

sounds good

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>