Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions ultralytics/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,31 @@ def __init__(self, reg_max):

def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
    """Compute IoU and DFL losses for rotated bounding boxes.

    Args:
        pred_dist: Predicted distribution values; the foreground entries are
            reshaped to ``(-1, reg_max)`` before being passed to the DFL loss.
        pred_bboxes: Predicted rotated boxes; indexed with ``fg_mask``.
        anchor_points: Anchor points forwarded to ``bbox2dist``.
        target_bboxes: Target rotated boxes; indexed with ``fg_mask``. Only the
            first four channels are converted with ``xywh2xyxy`` for the DFL target.
        target_scores: Target scores; summed over the last dimension to obtain
            one weight per foreground entry.
        target_scores_sum: Normaliser that both losses are divided by.
        fg_mask: Boolean foreground mask, or ``None``.

    Returns:
        Tuple ``(loss_iou, loss_dfl)`` of scalar tensors. Both are zero when
        ``fg_mask`` is ``None`` or selects nothing; ``loss_dfl`` is also zero
        when ``self.dfl_loss`` is falsy.
    """
    device = pred_dist.device

    # Guard clause: with no foreground entries there is nothing to score, so
    # skip all indexing and return two independent zero scalars.
    if fg_mask is None or not fg_mask.any():
        return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)

    # Per-entry weight, reused by both the IoU and the DFL terms.
    weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)

    # IoU loss on the foreground boxes only.
    iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
    loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

    # DFL loss
    if self.dfl_loss:
        reg_max = self.dfl_loss.reg_max
        # The distance targets are built from the full (unmasked) tensors and
        # masked afterwards, exactly like the predictions.
        target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), reg_max - 1)
        loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, reg_max), target_ltrb[fg_mask]) * weight
        loss_dfl = loss_dfl.sum() / target_scores_sum
    else:
        loss_dfl = torch.tensor(0.0, device=device)

    return loss_iou, loss_dfl

Expand Down
44 changes: 30 additions & 14 deletions ultralytics/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,43 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
- OBB format: [center_x, center_y, width, height, rotation_angle].
- Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
"""
x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = obb2[..., :2].split(1, dim=-1)
# Faster slicing instead of split
x1 = obb1[..., 0:1]
y1 = obb1[..., 1:2]
x2 = obb2[..., 0:1]
y2 = obb2[..., 1:2]
a1, b1, c1 = _get_covariance_matrix(obb1)
a2, b2, c2 = _get_covariance_matrix(obb2)

t1 = (
((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
) * 0.25
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
t3 = (
((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
/ (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
+ eps
).log() * 0.5
a1_a2 = a1 + a2
b1_b2 = b1 + b2
c1_c2 = c1 + c2
denom = (a1_a2 * b1_b2 - c1_c2.pow(2)) + eps # shared denominator

# Only compute each piece once
dx = x1 - x2
dy = y1 - y2

t1 = ((a1_a2 * dy.pow(2) + b1_b2 * dx.pow(2)) / denom) * 0.25
t2 = ((c1_c2 * (x2 - x1) * (y1 - y2)) / denom) * 0.5

# Precompute terms for t3 for better memory/batch parallelization
a1b1 = (a1 * b1 - c1.pow(2)).clamp_(0)
a2b2 = (a2 * b2 - c2.pow(2)).clamp_(0)
sqrt_ab = (a1b1 * a2b2).sqrt()
numer = (a1_a2 * b1_b2 - c1_c2.pow(2)) + eps
denom2 = 4 * sqrt_ab + eps
t3 = (numer / denom2).log() * 0.5

bd = (t1 + t2 + t3).clamp(eps, 100.0)
hd = (1.0 - (-bd).exp() + eps).sqrt()
iou = 1 - hd
if CIoU: # only include the wh aspect ratio part
w1, h1 = obb1[..., 2:4].split(1, dim=-1)
w2, h2 = obb2[..., 2:4].split(1, dim=-1)
v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
w1 = obb1[..., 2:3]
h1 = obb1[..., 3:4]
w2 = obb2[..., 2:3]
h2 = obb2[..., 3:4]
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - v * alpha # CIoU
Expand Down