Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Dec 4, 2025

📄 10% (0.10x) speedup for multi_scale_deformable_attn_pytorch in ultralytics/nn/modules/utils.py

⏱️ Runtime : 3.68 milliseconds 3.34 milliseconds (best of 119 runs)

📝 Explanation and details

The optimized code achieves a 10% speedup through several targeted micro-optimizations that reduce computational overhead in the hot path:

Key Optimizations Applied:

  1. Pre-compute spatial sizes: Instead of computing H_ * W_ for each level during the split operation, the code pre-computes all spatial sizes at once using vectorized tensor operations (value_spatial_shapes[:, 0] * value_spatial_shapes[:, 1]), reducing the expensive list comprehension overhead.

  2. Eliminate tensor dereferencing in loop: Converting value_spatial_shapes.tolist() once outside the loop avoids repeated tensor attribute access and indexing operations inside the critical loop, which is particularly beneficial since this loop runs for each attention level.

  3. Reduce function lookup overhead: Moving torch.stack and torch.Tensor.flatten to local variables eliminates repeated attribute lookups during execution.

  4. Optimize tensor operations flow: The code moves the torch.stack and flatten operations on sampling_value_list outside the final computation chain, creating sampling_values as an intermediate result. This reduces the complexity of the final expression and potentially improves memory access patterns.

Performance Impact:
The line profiler shows the most significant improvements in:

  • The value.split() operation (28.4% → 24.8% of total time)
  • The loop enumeration overhead is reduced through direct range iteration
  • Overall execution time improved from 5.71ms to 5.39ms

Workload Benefits:
Based on the function reference, this optimization is particularly valuable since multi_scale_deformable_attn_pytorch is called within the forward pass of a transformer attention mechanism. The 10% improvement will compound across multiple attention heads and layers during model inference, making it especially beneficial for real-time applications or batch processing scenarios.

Test Case Performance:
The optimizations show consistent 5-20% improvements across all test cases, with particularly strong performance on multi-level, multi-point scenarios (up to 19.8% faster), which are the most computationally intensive use cases this function is designed to handle.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 27 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
from ultralytics.nn.modules.utils import multi_scale_deformable_attn_pytorch

# unit tests

# ------------------- BASIC TEST CASES -------------------


def test_basic_output_shape_and_type():
    """
    Basic: Test that output has correct shape and dtype for a small, typical input.
    """
    bs, num_keys, num_heads, embed_dims = 2, 8, 2, 4
    num_levels = 2
    num_queries = 3
    num_points = 2
    # value_spatial_shapes: (num_levels, 2)
    value_spatial_shapes = torch.tensor([[2, 2], [2, 2]], dtype=torch.long)
    # value: (bs, num_keys, num_heads, embed_dims)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    # sampling_locations: (bs, num_queries, num_heads, num_levels, num_points, 2)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    # attention_weights: (bs, num_queries, num_heads, num_levels, num_points)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 147μs -> 128μs (14.8% faster)


def test_basic_determinism():
    """
    Basic: Test that the function is deterministic for the same input.
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.ones(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.zeros(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.ones(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output1 = codeflash_output  # 105μs -> 97.6μs (8.61% faster)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output2 = codeflash_output  # 51.4μs -> 44.4μs (15.9% faster)


def test_basic_nonzero_output():
    """
    Basic: Test that output is non-zero for non-zero attention weights and value.
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.ones(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.ones(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 100μs -> 88.5μs (13.5% faster)


# ------------------- EDGE TEST CASES -------------------


def test_edge_zero_attention_weights():
    """
    Edge: Test that output is zero when all attention weights are zero.
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.zeros(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 113μs -> 100μs (12.4% faster)


def test_edge_sampling_locations_out_of_bounds():
    """
    Edge: Test that out-of-bounds sampling locations result in zero output (due to zero padding).
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    # Set sampling locations outside [0,1] range (e.g., 2.0)
    sampling_locations = torch.full((bs, num_queries, num_heads, num_levels, num_points, 2), 2.0)
    attention_weights = torch.ones(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 109μs -> 98.5μs (10.7% faster)


def test_edge_minimal_input():
    """
    Edge: Test minimal input sizes (all dimensions = 1).
    """
    bs, num_keys, num_heads, embed_dims = 1, 1, 1, 1
    num_levels = 1
    num_queries = 1
    num_points = 1
    value_spatial_shapes = torch.tensor([[1, 1]], dtype=torch.long)
    value = torch.tensor([[[[1.0]]]])
    sampling_locations = torch.tensor([[[[[[0.5, 0.5]]]]]])
    attention_weights = torch.tensor([[[[[1.0]]]]])
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 96.5μs -> 85.6μs (12.7% faster)


def test_edge_large_embed_dims():
    """
    Edge: Test with large embed_dims but within memory constraints.
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 256
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 109μs -> 102μs (6.69% faster)


def test_edge_multiple_levels_and_points():
    """
    Edge: Test with multiple levels and points.
    """
    bs, num_keys, num_heads, embed_dims = 1, 12, 1, 8
    num_levels = 3
    num_queries = 2
    num_points = 4
    value_spatial_shapes = torch.tensor([[2, 2], [2, 2], [2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 145μs -> 122μs (18.4% faster)


def test_edge_float16_dtype():
    """
    Edge: Test with float16 dtype.
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims).half()
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2).half()
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points).half()
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 107μs -> 94.0μs (13.9% faster)


def test_edge_invalid_shapes_raise():
    """
    Edge: Test that invalid shapes raise errors.
    """
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels = 1
    num_queries = 2
    num_points = 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    # Wrong shape: missing "2" at the end
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    with pytest.raises(Exception):
        multi_scale_deformable_attn_pytorch(
            value, value_spatial_shapes, sampling_locations, attention_weights
        )  # 3.73μs -> 3.69μs (1.14% faster)


# ------------------- LARGE SCALE TEST CASES -------------------


def test_large_scale_embed_dims_and_levels():
    """
    Large: Test with large embed_dims and multiple levels.
    """
    bs, num_keys, num_heads, embed_dims = 1, 16, 4, 64
    num_levels = 4
    num_queries = 32
    num_points = 2
    value_spatial_shapes = torch.tensor([[2, 2], [2, 2], [2, 2], [2, 2]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 542μs -> 511μs (5.93% faster)


def test_large_scale_memory_limit():
    """
    Large: Ensure that the function does not exceed memory constraints (<100MB tensor).
    """
    bs, num_keys, num_heads, embed_dims = 2, 32, 2, 32
    num_levels = 2
    num_queries = 128
    num_points = 2
    value_spatial_shapes = torch.tensor([[4, 4], [4, 4]], dtype=torch.long)
    value = torch.rand(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    # Estimate memory usage
    total_bytes = (
        value.numel() * value.element_size()
        + sampling_locations.numel() * sampling_locations.element_size()
        + attention_weights.numel() * attention_weights.element_size()
    )
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    output = codeflash_output  # 319μs -> 303μs (5.16% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
import torch
from ultralytics.nn.modules.utils import multi_scale_deformable_attn_pytorch

# function to test
# (see above for function definition)

# unit tests

# ------------------- Basic Test Cases -------------------


def test_basic_single_level_single_point():
    # Test with one batch, one query, one head, one level, one point
    bs, num_keys, num_heads, embed_dims = 1, 4, 1, 2
    num_levels, num_points, num_queries = 1, 1, 1
    value = torch.arange(bs * num_keys * num_heads * embed_dims, dtype=torch.float32).reshape(
        bs, num_keys, num_heads, embed_dims
    )
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)  # 2x2 spatial shape
    sampling_locations = torch.zeros(bs, num_queries, num_heads, num_levels, num_points, 2, dtype=torch.float32)
    attention_weights = torch.ones(bs, num_queries, num_heads, num_levels, num_points, dtype=torch.float32)
    # Should produce output of shape (1, 1, 2)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 111μs -> 101μs (10.1% faster)


def test_basic_multi_level_multi_point():
    # Test with two levels, two points per level, two heads
    bs, embed_dims, num_heads = 1, 4, 2
    value_spatial_shapes = torch.tensor([[2, 2], [1, 2]], dtype=torch.long)
    num_keys = (
        value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
        + value_spatial_shapes[1, 0] * value_spatial_shapes[1, 1]
    )
    num_queries, num_levels, num_points = 3, 2, 2
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 121μs -> 104μs (16.7% faster)


def test_basic_batch_size_greater_than_one():
    # Test with batch size > 1
    bs, embed_dims, num_heads = 2, 3, 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    num_keys = value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
    num_queries, num_levels, num_points = 2, 1, 2
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 96.3μs -> 89.7μs (7.41% faster)


# ------------------- Edge Test Cases -------------------


def test_edge_zero_attention_weights():
    # All attention weights are zero, output should be all zeros
    bs, embed_dims, num_heads = 1, 2, 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    num_keys = value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
    num_queries, num_levels, num_points = 1, 1, 2
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.zeros(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 89.2μs -> 83.8μs (6.51% faster)


def test_edge_sampling_locations_out_of_bounds():
    # Sampling locations outside of [-1, 1] after transformation
    bs, embed_dims, num_heads = 1, 2, 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    num_keys = value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
    num_queries, num_levels, num_points = 1, 1, 2
    value = torch.ones(bs, num_keys, num_heads, embed_dims)
    # Set sampling locations so grid_sample will sample from outside
    sampling_locations = torch.tensor([[[[[[2.0, 2.0], [-1.0, -1.0]]]]]], dtype=torch.float32)
    attention_weights = torch.ones(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 91.0μs -> 81.5μs (11.7% faster)


def test_edge_minimal_shapes():
    # Minimal tensor shapes (all dimensions = 1)
    bs, num_keys, num_heads, embed_dims = 1, 1, 1, 1
    value = torch.tensor([[[[1.0]]]])
    value_spatial_shapes = torch.tensor([[1, 1]], dtype=torch.long)
    sampling_locations = torch.zeros(bs, 1, num_heads, 1, 1, 2)
    attention_weights = torch.ones(bs, 1, num_heads, 1, 1)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 96.9μs -> 84.0μs (15.3% faster)


def test_edge_attention_weights_sum_to_one():
    # Attention weights sum to one for each query
    bs, embed_dims, num_heads = 1, 2, 1
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    num_keys = value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
    num_queries, num_levels, num_points = 1, 1, 4
    value = torch.arange(num_keys * num_heads * embed_dims, dtype=torch.float32).reshape(
        bs, num_keys, num_heads, embed_dims
    )
    sampling_locations = torch.zeros(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.ones(bs, num_queries, num_heads, num_levels, num_points)
    attention_weights /= attention_weights.sum(-1, keepdim=True)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 86.8μs -> 78.7μs (10.3% faster)
    # Output should be a weighted sum of sampled values; since all weights are equal, output should be average
    # Since all sampling locations are zeros, grid_sample will sample from the center (or nearest pixel)
    # We can't guarantee the exact value, but output should be finite and not nan


def test_edge_different_embed_dims():
    # Test with different embed_dims
    bs, embed_dims, num_heads = 1, 7, 2
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    num_keys = value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
    num_queries, num_levels, num_points = 2, 1, 2
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 98.9μs -> 89.7μs (10.3% faster)


def test_edge_multiple_levels_varied_shapes():
    # Multiple levels with varied spatial shapes
    bs, embed_dims, num_heads = 1, 3, 1
    value_spatial_shapes = torch.tensor([[2, 2], [1, 3]], dtype=torch.long)
    num_keys = (
        value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
        + value_spatial_shapes[1, 0] * value_spatial_shapes[1, 1]
    )
    num_queries, num_levels, num_points = 2, 2, 2
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 116μs -> 97.1μs (19.8% faster)


# ------------------- Large Scale Test Cases -------------------


def test_large_scale_maximum_elements():
    # Large scale test with maximum allowed elements under 100MB
    # Estimate: float32 = 4 bytes, so 100MB/4 = 25,000,000 elements
    # We'll keep tensors much smaller for safety, e.g., 100,000 elements
    bs, num_heads, embed_dims = 2, 4, 8
    value_spatial_shapes = torch.tensor([[10, 10], [10, 10]], dtype=torch.long)  # 2 levels, each 100
    num_keys = (
        value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
        + value_spatial_shapes[1, 0] * value_spatial_shapes[1, 1]
    )  # 200
    num_queries, num_levels, num_points = 50, 2, 5
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 208μs -> 191μs (8.90% faster)


def test_large_scale_many_heads_and_embed_dims():
    # Large scale test with many heads and embed_dims
    bs, num_heads, embed_dims = 1, 8, 16
    value_spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
    num_keys = (
        value_spatial_shapes[0, 0] * value_spatial_shapes[0, 1]
        + value_spatial_shapes[1, 0] * value_spatial_shapes[1, 1]
    )
    num_queries, num_levels, num_points = 20, 2, 4
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    codeflash_output = multi_scale_deformable_attn_pytorch(
        value, value_spatial_shapes, sampling_locations, attention_weights
    )
    out = codeflash_output  # 199μs -> 184μs (8.47% faster)


def test_error_invalid_shapes():
    # Value tensor shape mismatch with value_spatial_shapes
    bs, num_heads, embed_dims = 1, 1, 2
    value_spatial_shapes = torch.tensor([[2, 2], [1, 2]], dtype=torch.long)
    num_keys = 4  # Should be 4+2=6
    value = torch.randn(bs, num_keys, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, 1, num_heads, 2, 1, 2)
    attention_weights = torch.rand(bs, 1, num_heads, 2, 1)
    # Should raise an error due to shape mismatch
    with pytest.raises(RuntimeError):
        multi_scale_deformable_attn_pytorch(
            value, value_spatial_shapes, sampling_locations, attention_weights
        )  # 84.1μs -> 75.7μs (11.1% faster)


def test_error_invalid_dtype():
    # Value tensor with wrong dtype
    bs, num_heads, embed_dims = 1, 1, 2
    value_spatial_shapes = torch.tensor([[2, 2]], dtype=torch.long)
    num_keys = 4
    value = torch.randint(0, 10, (bs, num_keys, num_heads, embed_dims), dtype=torch.int32)
    sampling_locations = torch.rand(bs, 1, num_heads, 1, 1, 2)
    attention_weights = torch.rand(bs, 1, num_heads, 1, 1)
    # Should raise an error due to dtype mismatch in grid_sample
    with pytest.raises(RuntimeError):
        multi_scale_deformable_attn_pytorch(
            value, value_spatial_shapes, sampling_locations, attention_weights
        )  # 126μs -> 115μs (8.86% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-multi_scale_deformable_attn_pytorch-mirec5ou and push.

Codeflash Static Badge

The optimized code achieves a **10% speedup** through several targeted micro-optimizations that reduce computational overhead in the hot path:

**Key Optimizations Applied:**

1. **Pre-compute spatial sizes**: Instead of computing `H_ * W_` for each level during the split operation, the code pre-computes all spatial sizes at once using vectorized tensor operations (`value_spatial_shapes[:, 0] * value_spatial_shapes[:, 1]`), reducing the expensive list comprehension overhead.

2. **Eliminate tensor dereferencing in loop**: Converting `value_spatial_shapes.tolist()` once outside the loop avoids repeated tensor attribute access and indexing operations inside the critical loop, which is particularly beneficial since this loop runs for each attention level.

3. **Reduce function lookup overhead**: Moving `torch.stack` and `torch.Tensor.flatten` to local variables eliminates repeated attribute lookups during execution.

4. **Optimize tensor operations flow**: The code moves the `torch.stack` and `flatten` operations on `sampling_value_list` outside the final computation chain, creating `sampling_values` as an intermediate result. This reduces the complexity of the final expression and potentially improves memory access patterns.

**Performance Impact:**
The line profiler shows the most significant improvements in:
- The `value.split()` operation (28.4% → 24.8% of total time)
- The loop enumeration overhead is reduced through direct range iteration
- Overall execution time improved from 5.71ms to 5.39ms

**Workload Benefits:**
Based on the function reference, this optimization is particularly valuable since `multi_scale_deformable_attn_pytorch` is called within the forward pass of a transformer attention mechanism. The 10% improvement will compound across multiple attention heads and layers during model inference, making it especially beneficial for real-time applications or batch processing scenarios.

**Test Case Performance:**
The optimizations show consistent 5-20% improvements across all test cases, with particularly strong performance on multi-level, multi-point scenarios (up to 19.8% faster), which are the most computationally intensive use cases this function is designed to handle.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 4, 2025 12:12
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Dec 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant