⚡️ Speed up method SAM2Model._apply_non_overlapping_constraints by 95%
#50
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 95% (0.95x) speedup for
SAM2Model._apply_non_overlapping_constraintsinultralytics/models/sam/modules/sam.py⏱️ Runtime :
3.12 milliseconds→1.60 milliseconds(best of66runs)📝 Explanation and details
The optimized code achieves a 95% speedup by making two key optimizations to the
_apply_non_overlapping_constraintsmethod:Key Optimizations:
Replaced
torch.argmaxwithtorch.max: Changed fromtorch.argmax(pred_masks, dim=0, keepdim=True)to_, max_obj_inds = torch.max(pred_masks, dim=0, keepdim=True). This is faster becausetorch.maxreturns both values and indices in a single fused operation, whiletorch.argmaxonly computes indices but still needs to traverse the data similarly.Pre-computed clamped tensor: Moved
torch.clamp(pred_masks, max=-10.0)outside thetorch.wherecall by pre-computing it asmin_mask. This avoids redundant tensor clamping operations within the conditional assignment.Performance Impact:
torch.argmaxline dropped from 2.66ms to 1.19ms (55% reduction)torch.whereoperation became more efficient, dropping from 764μs to 422μsTest Case Performance:
The optimization shows particularly strong gains on larger tensors:
Why This Works:
In PyTorch,
torch.maxis optimized as a single kernel operation that finds both maximum values and their indices simultaneously, whiletorch.argmaxperforms similar work but discards the values. Pre-computing the clamped tensor reduces redundant memory operations in the conditional assignment. These optimizations are especially effective for the multi-object segmentation use case where batch sizes and spatial dimensions are typically large.✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
To edit these changes
git checkout codeflash/optimize-SAM2Model._apply_non_overlapping_constraints-mirdk9aland push.