Conversation

codeflash-ai bot commented on Dec 4, 2025

📄 16% (0.16x) speedup for v8SegmentationLoss.single_mask_loss in ultralytics/utils/loss.py

⏱️ Runtime : 8.77 milliseconds → 7.53 milliseconds (best of 82 runs)

📝 Explanation and details

The optimization achieves a 16% speedup through two key improvements:

1. Optimized Matrix Multiplication in `single_mask_loss`:

  • Replaced `torch.einsum("in,nhw->ihw", pred, proto)` with `torch.matmul(pred, proto_flat).view(-1, proto.shape[1], proto.shape[2])`
  • This eliminates einsum's parsing overhead and uses PyTorch's highly optimized BLAS-backed matmul operations
  • Line profiler shows the matrix operation improved from 2.44ms to 1.23ms (49% faster for this line)

2. More Efficient Mask Generation in `crop_mask`:

  • Replaced the chained arithmetic operations `((r >= x1) * (r < x2) * (c >= y1) * (c < y2))` with logical operations using the `&` operator
  • Split the mask computation into separate x and y components (`mask_x` and `mask_y`) before combining them
  • Used explicit `.view()` calls instead of advanced indexing for tensor reshaping (both rewrites are sketched in the code block after this list)
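
For orientation, here is a minimal, self-contained sketch of both rewrites. The function names (`predict_masks_einsum`, `predict_masks_matmul`, `crop_mask_logical`) and exact shapes are illustrative assumptions drawn from this description and the tests below, not the library source:

```python
import torch


def predict_masks_einsum(pred: torch.Tensor, proto: torch.Tensor) -> torch.Tensor:
    """Original formulation: einsum re-parses the subscript string on every call."""
    return torch.einsum("in,nhw->ihw", pred, proto)


def predict_masks_matmul(pred: torch.Tensor, proto: torch.Tensor) -> torch.Tensor:
    """Optimized formulation: flatten spatial dims, run a BLAS-backed matmul, restore shape."""
    c, h, w = proto.shape                        # proto: (32, H, W)
    proto_flat = proto.view(c, h * w)            # (32, H*W)
    return torch.matmul(pred, proto_flat).view(-1, h, w)  # (n, H, W)


def crop_mask_logical(masks: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
    """Zero out pixels outside each box using boolean & instead of chained float multiplies."""
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, dim=1)                   # each (n, 1, 1)
    r = torch.arange(w, device=masks.device, dtype=boxes.dtype).view(1, 1, w)   # column index
    c = torch.arange(h, device=masks.device, dtype=boxes.dtype).view(1, h, 1)   # row index
    mask_x = (r >= x1) & (r < x2)                # (n, 1, W)
    mask_y = (c >= y1) & (c < y2)                # (n, H, 1)
    return masks * (mask_x & mask_y)             # broadcasts to (n, H, W)


if __name__ == "__main__":
    pred, proto = torch.randn(4, 32), torch.randn(32, 16, 16)
    # The two formulations are numerically equivalent up to float rounding.
    assert torch.allclose(predict_masks_einsum(pred, proto),
                          predict_masks_matmul(pred, proto), atol=1e-5)
```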

Performance Impact:
The optimizations are particularly effective for:

  • Large batches: Test with 128 objects shows 20.7% speedup
  • Large mask sizes: 128x128 masks show 13.9% improvement
  • Multiple objects: All multi-object test cases show 13-20% improvements

The crop_mask function's total time reduced from 6.56ms to 4.94ms (25% faster), while single_mask_loss improved from 15.35ms to 12.35ms (20% faster). These functions are likely in the training hot path for YOLO segmentation models, making these optimizations valuable for training performance. The improvements are consistent across different input sizes and configurations, indicating robust performance gains.
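
To make the regression tests below easier to read, here is a hedged paraphrase of what `single_mask_loss` computes, reusing `crop_mask_logical` from the sketch above. It is an approximation for orientation, not the implementation in `ultralytics/utils/loss.py`, but it shows why a zero-area box yields inf/nan and why a box fully outside the grid contributes nothing:

```python
import torch
import torch.nn.functional as F


def single_mask_loss_sketch(gt_mask, pred, proto, xyxy, area):
    """Approximate shape of the computation exercised by the tests below (not the library source).

    gt_mask: (n, H, W)  ground-truth masks
    pred:    (n, 32)    per-object mask coefficients
    proto:   (32, H, W) prototype masks
    xyxy:    (n, 4)     boxes on the mask grid
    area:    (n,)       box areas used for normalization
    """
    c, h, w = proto.shape
    # Per-object mask logits as a linear combination of prototypes (the matmul rewrite above)
    pred_mask = torch.matmul(pred, proto.view(c, h * w)).view(-1, h, w)
    # Per-pixel BCE against the ground-truth mask
    loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
    # Keep only pixels inside each box, average per object, normalize by the box area.
    # A zero area therefore produces inf/nan, and a box outside the grid crops to zero.
    per_object = crop_mask_logical(loss, xyxy).mean(dim=(1, 2)) / area  # crop sketch defined above
    return per_object.sum()  # final reduction over objects; the library may reduce with mean() here
```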

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 32 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import torch
from ultralytics.utils.loss import v8SegmentationLoss

# ----------------- UNIT TESTS -----------------

# Basic Test Cases


def test_single_object_perfect_prediction():
    # One object, prediction matches mask exactly
    n, H, W = 1, 8, 8
    gt_mask = torch.zeros((n, H, W))
    gt_mask[0, 2:6, 2:6] = 1.0
    proto = torch.zeros((32, H, W))
    proto[0] = gt_mask[0]  # Only first proto "draws" the object
    pred = torch.zeros((n, 32))
    pred[0, 0] = 10.0  # Large positive value, so sigmoid(proto[0]) ~ 1
    xyxy = torch.tensor([[2, 2, 6, 6]], dtype=torch.float32)
    area = torch.tensor([16.0])
    # The predicted mask will be very close to the gt mask, so BCE loss should be near zero
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 123μs -> 106μs (16.3% faster)


def test_single_object_bad_prediction():
    # One object, prediction is all zeros (sigmoid(0) = 0.5)
    n, H, W = 1, 8, 8
    gt_mask = torch.zeros((n, H, W))
    gt_mask[0, 2:6, 2:6] = 1.0
    proto = torch.zeros((32, H, W))
    proto[0] = gt_mask[0]
    pred = torch.zeros((n, 32))  # All zeros, so mask is zeros
    xyxy = torch.tensor([[2, 2, 6, 6]], dtype=torch.float32)
    area = torch.tensor([16.0])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 123μs -> 103μs (18.5% faster)


def test_multiple_objects_different_areas():
    # Two objects, different sizes, different predictions
    n, H, W = 2, 8, 8
    gt_mask = torch.zeros((n, H, W))
    gt_mask[0, 1:4, 1:4] = 1.0  # 9 pixels
    gt_mask[1, 4:8, 4:8] = 1.0  # 16 pixels
    proto = torch.zeros((32, H, W))
    proto[0] = gt_mask[0]
    proto[1] = gt_mask[1]
    pred = torch.zeros((n, 32))
    pred[0, 0] = 10.0  # Good prediction for obj 0
    pred[1, 1] = -10.0  # Bad prediction for obj 1 (sigmoid(-10) ~ 0)
    xyxy = torch.tensor([[1, 1, 4, 4], [4, 4, 8, 8]], dtype=torch.float32)
    area = torch.tensor([9.0, 16.0])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 129μs -> 109μs (17.7% faster)


def test_zero_area_object():
    # Object with zero area should not cause division by zero
    n, H, W = 1, 8, 8
    gt_mask = torch.zeros((n, H, W))
    proto = torch.zeros((32, H, W))
    pred = torch.zeros((n, 32))
    xyxy = torch.tensor([[2, 2, 2, 2]], dtype=torch.float32)  # x1==x2, y1==y2
    area = torch.tensor([0.0])
    # Should not raise ZeroDivisionError, but may return inf or nan
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 158μs -> 138μs (14.0% faster)


# Edge Test Cases


def test_empty_masks():
    # No objects in the image
    n, H, W = 0, 8, 8
    gt_mask = torch.zeros((n, H, W))
    proto = torch.zeros((32, H, W))
    pred = torch.zeros((n, 32))
    xyxy = torch.zeros((n, 4))
    area = torch.ones((n,))
    # Should return zero loss
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 119μs -> 97.1μs (23.2% faster)


def test_masks_outside_image():
    # Object's box is outside the image bounds; crop_mask should produce all zeros
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    proto = torch.ones((32, H, W))
    pred = torch.ones((n, 32))
    xyxy = torch.tensor([[100, 100, 110, 110]], dtype=torch.float32)
    area = torch.tensor([100.0])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 130μs -> 109μs (18.9% faster)


def test_partial_overlap_box():
    # Object's box partially outside image
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    proto = torch.ones((32, H, W))
    pred = torch.zeros((n, 32))
    xyxy = torch.tensor([[-2, -2, 4, 4]], dtype=torch.float32)
    area = torch.tensor([36.0])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 127μs -> 111μs (14.4% faster)


def test_non_square_masks_and_boxes():
    # Non-square masks and boxes
    n, H, W = 2, 8, 16
    gt_mask = torch.zeros((n, H, W))
    gt_mask[0, 2:4, 3:10] = 1.0
    gt_mask[1, 5:8, 8:16] = 1.0
    proto = torch.zeros((32, H, W))
    proto[0] = gt_mask[0]
    proto[1] = gt_mask[1]
    pred = torch.zeros((n, 32))
    pred[0, 0] = 10.0
    pred[1, 1] = 10.0
    xyxy = torch.tensor([[3, 2, 10, 4], [8, 5, 16, 8]], dtype=torch.float32)
    area = torch.tensor([14.0, 24.0])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 130μs -> 113μs (15.4% faster)


def test_area_broadcasting():
    # area is shape (n,), but crop_mask returns shape (n, ...)
    n, H, W = 3, 8, 8
    gt_mask = torch.rand((n, H, W))
    proto = torch.rand((32, H, W))
    pred = torch.rand((n, 32))
    xyxy = torch.tensor([[0, 0, 8, 8], [2, 2, 6, 6], [1, 1, 7, 7]], dtype=torch.float32)
    area = torch.tensor([64.0, 16.0, 36.0])
    # Should not raise any error, just check output shape and type
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 137μs -> 118μs (16.3% faster)


# Large Scale Test Cases


def test_large_number_of_objects():
    # Test with 128 objects, 32x32 masks
    n, H, W = 128, 32, 32
    torch.manual_seed(0)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    proto = torch.randn((32, H, W))
    pred = torch.randn((n, 32))
    xyxy = torch.randint(0, H, (n, 4)).float()
    # Ensure x2 > x1 and y2 > y1
    xyxy[:, 2] = torch.max(xyxy[:, 0] + 1, xyxy[:, 2])
    xyxy[:, 3] = torch.max(xyxy[:, 1] + 1, xyxy[:, 3])
    area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 683μs -> 566μs (20.7% faster)


def test_large_mask_size():
    # Test with 16 objects, 128x128 masks, but keep under 100MB
    n, H, W = 16, 128, 128
    torch.manual_seed(42)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    proto = torch.randn((32, H, W))
    pred = torch.randn((n, 32))
    xyxy = torch.randint(0, H, (n, 4)).float()
    xyxy[:, 2] = torch.max(xyxy[:, 0] + 1, xyxy[:, 2])
    xyxy[:, 3] = torch.max(xyxy[:, 1] + 1, xyxy[:, 3])
    area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 1.39ms -> 1.22ms (13.9% faster)


def test_highly_overlapping_boxes():
    # Many objects with overlapping boxes
    n, H, W = 32, 32, 32
    gt_mask = torch.zeros((n, H, W))
    for i in range(n):
        gt_mask[i, 8:24, 8:24] = 1.0  # All objects overlap in the center
    proto = torch.zeros((32, H, W))
    proto[0] = gt_mask[0]
    pred = torch.zeros((n, 32))
    for i in range(n):
        pred[i, 0] = 10.0  # All objects use proto[0]
    xyxy = torch.tensor([[8, 8, 24, 24]] * n, dtype=torch.float32)
    area = torch.tensor([256.0] * n)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 267μs -> 230μs (16.0% faster)


def test_randomized_large_batch():
    # Randomized stress test with 64 objects, 64x64 masks
    n, H, W = 64, 64, 64
    torch.manual_seed(123)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    proto = torch.randn((32, H, W))
    pred = torch.randn((n, 32))
    xyxy = torch.randint(0, H, (n, 4)).float()
    xyxy[:, 2] = torch.max(xyxy[:, 0] + 1, xyxy[:, 2])
    xyxy[:, 3] = torch.max(xyxy[:, 1] + 1, xyxy[:, 3])
    area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 1.30ms -> 1.12ms (15.4% faster)


# Extra: Determinism test


def test_determinism():
    # The function should be deterministic given the same inputs
    n, H, W = 4, 16, 16
    torch.manual_seed(2024)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    proto = torch.randn((32, H, W))
    pred = torch.randn((n, 32))
    xyxy = torch.randint(0, H, (n, 4)).float()
    xyxy[:, 2] = torch.max(xyxy[:, 0] + 1, xyxy[:, 2])
    xyxy[:, 3] = torch.max(xyxy[:, 1] + 1, xyxy[:, 3])
    area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss1 = codeflash_output  # 131μs -> 113μs (16.3% faster)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss2 = codeflash_output  # 68.6μs -> 54.5μs (26.0% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest  # used for our unit tests
import torch
from ultralytics.utils.loss import v8SegmentationLoss

# unit tests

# ------------------------ Basic Test Cases ------------------------


def test_single_object_perfect_prediction():
    """Test single object, perfect prediction: loss should be ~0"""
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred = torch.ones((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W], dtype=torch.float)
    # The predicted mask will be all 32 (sum of ones), so sigmoid(32) ~ 1, so BCE with gt=1 is very small
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 130μs -> 109μs (19.5% faster)


def test_single_object_all_zero_prediction():
    """Test single object, all-zero prediction: loss should be high (gt=1, pred=0)"""
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred = torch.zeros((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W], dtype=torch.float)
    # The predicted mask will be all zeros, so BCE with gt=1 is -log(sigmoid(0)) = 0.693...
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 127μs -> 105μs (20.7% faster)


def test_single_object_all_zero_gt():
    """Test single object, gt mask all zeros, pred all zeros: loss should be ~0"""
    n, H, W = 1, 8, 8
    gt_mask = torch.zeros((n, H, W))
    pred = torch.zeros((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W], dtype=torch.float)
    # BCE with gt=0, pred=0: -log(1-sigmoid(0)) = 0.693...
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 128μs -> 113μs (12.9% faster)


def test_two_objects_different_boxes():
    """Test two objects, different boxes, different predictions."""
    n, H, W = 2, 8, 8
    gt_mask = torch.stack([torch.ones((H, W)), torch.zeros((H, W))])
    pred = torch.stack([torch.ones(32), torch.zeros(32)])
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W], [0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W, H * W], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 131μs -> 109μs (20.5% faster)


def test_empty_mask():
    """Test with zero objects (n=0): should return 0 loss."""
    n, H, W = 0, 8, 8
    gt_mask = torch.empty((n, H, W))
    pred = torch.empty((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.empty((n, 4))
    area = torch.empty((n,))
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 141μs -> 122μs (15.7% faster)


def test_zero_area_box():
    """Test with a box of zero area: should not divide by zero (should return nan or inf)."""
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred = torch.ones((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[1, 1, 1, 1]], dtype=torch.float)  # x1==x2, y1==y2
    area = torch.tensor([0.0], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 140μs -> 122μs (15.1% faster)


def test_box_out_of_bounds():
    """Test with boxes outside mask bounds: should not crash, but mask is all zeros so mean is zero."""
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred = torch.ones((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[H + 1, W + 1, H + 2, W + 2]], dtype=torch.float)
    area = torch.tensor([1.0], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 128μs -> 108μs (18.5% faster)


def test_non_integer_boxes():
    """Test with non-integer (float) box coordinates, should work as expected."""
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred = torch.ones((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[1.2, 2.8, 6.5, 7.1]], dtype=torch.float)
    area = torch.tensor([(6.5 - 1.2) * (7.1 - 2.8)], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 131μs -> 108μs (20.6% faster)


def test_area_broadcasting():
    """Test that area shape (n,) is correctly broadcasted."""
    n, H, W = 3, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred = torch.ones((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W], [0, 0, H, W], [0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W, H * W, H * W], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 132μs -> 117μs (13.0% faster)


# ------------------------ Large Scale Test Cases ------------------------


def test_large_batch():
    """Test with a large batch of objects (n=200, H=32, W=32)."""
    n, H, W = 200, 32, 32
    torch.manual_seed(0)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    pred = torch.randn((n, 32))
    proto = torch.randn((32, H, W))
    xyxy = torch.cat([torch.randint(0, H // 2, (n, 2)).float(), torch.randint(H // 2, H, (n, 2)).float()], dim=1)
    area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    # Ensure no zero area to avoid inf/nan
    area[area == 0] = 1.0
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 1.06ms -> 932μs (13.4% faster)


def test_large_mask_size():
    """Test with a single object but large mask size (n=1, H=128, W=128)."""
    n, H, W = 1, 128, 128
    torch.manual_seed(42)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    pred = torch.randn((n, 32))
    proto = torch.randn((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 270μs -> 239μs (13.0% faster)


def test_maximum_allowed_tensor_size():
    """Test with largest tensor size under 100MB (n=32, H=64, W=64)."""
    n, H, W = 32, 64, 64  # 32*64*64*4 = 524288 bytes per tensor
    torch.manual_seed(123)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    pred = torch.randn((n, 32))
    proto = torch.randn((32, H, W))
    xyxy = torch.cat([torch.randint(0, H // 2, (n, 2)).float(), torch.randint(H // 2, H, (n, 2)).float()], dim=1)
    area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    area[area == 0] = 1.0
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 676μs -> 567μs (19.2% faster)


# ------------------------ Determinism Test ------------------------


def test_determinism():
    """Test that repeated calls with same input yield the same result."""
    n, H, W = 2, 16, 16
    torch.manual_seed(999)
    gt_mask = torch.randint(0, 2, (n, H, W)).float()
    pred = torch.randn((n, 32))
    proto = torch.randn((32, H, W))
    xyxy = torch.tensor([[2, 2, 10, 10], [0, 0, 16, 16]], dtype=torch.float)
    area = torch.tensor([(10 - 2) * (10 - 2), 16 * 16], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss1 = codeflash_output  # 142μs -> 122μs (16.3% faster)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss2 = codeflash_output  # 68.6μs -> 54.7μs (25.3% faster)


# ------------------------ Type/Device Test Cases ------------------------


def test_float16_support():
    """Test that function works with float16 inputs."""
    n, H, W = 2, 8, 8
    gt_mask = torch.ones((n, H, W)).half()
    pred = torch.ones((n, 32)).half()
    proto = torch.ones((32, H, W)).half()
    xyxy = torch.tensor([[0, 0, H, W], [0, 0, H, W]], dtype=torch.float16)
    area = torch.tensor([H * W, H * W], dtype=torch.float16)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)
    loss = codeflash_output  # 138μs -> 114μs (20.8% faster)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_loss_sensitive_to_prediction():
    """Test that loss decreases when prediction gets closer to ground truth."""
    n, H, W = 1, 8, 8
    gt_mask = torch.ones((n, H, W))
    pred_bad = torch.zeros((n, 32))
    pred_good = torch.ones((n, 32))
    proto = torch.ones((32, H, W))
    xyxy = torch.tensor([[0, 0, H, W]], dtype=torch.float)
    area = torch.tensor([H * W], dtype=torch.float)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred_bad, proto, xyxy, area)
    loss_bad = codeflash_output  # 140μs -> 120μs (17.2% faster)
    codeflash_output = v8SegmentationLoss.single_mask_loss(gt_mask, pred_good, proto, xyxy, area)
    loss_good = codeflash_output  # 64.3μs -> 49.8μs (29.3% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-v8SegmentationLoss.single_mask_loss-mirhjj6d` and push.

codeflash-ai bot requested a review from mashraf-222 on Dec 4, 2025 at 13:42
codeflash-ai bot added the `⚡️ codeflash` (Optimization PR opened by Codeflash AI) and `🎯 Quality: High` (Optimization Quality according to Codeflash) labels on Dec 4, 2025