diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 3945f0391af..2856aac71db 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -367,7 +367,9 @@ def single_mask_loss( The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the predicted masks from the prototype masks and predicted mask coefficients. """ - pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) + # matmul for efficiency: (n,32) x (32,HW) -> (n,HW) -> (n,H,W) + proto_flat = proto.view(proto.shape[0], -1) + pred_mask = torch.matmul(pred, proto_flat).view(-1, proto.shape[1], proto.shape[2]) loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum() diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index 0f0c8c07b7e..97030b950ab 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -669,12 +669,16 @@ def crop_mask(masks, boxes): Returns: (torch.Tensor): Cropped masks. """ - _, h, w = masks.shape - x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) - r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) - c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) - - return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + n, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, dim=1) # x1 shape(n,1,1) + rows = torch.arange(w, device=masks.device, dtype=x1.dtype).view(1, 1, w) + cols = torch.arange(h, device=masks.device, dtype=x1.dtype).view(1, h, 1) + + # Combine the boolean comparisons with `&` (elementwise AND); broadcasting expands the result to (n, h, w).
+ mask_x = (rows >= x1) & (rows < x2) # shape (n, 1, w): x1, x2 broadcast against rows + mask_y = (cols >= y1) & (cols < y2) # shape (n, h, 1): y1, y2 broadcast against cols + crop = mask_x & mask_y # elementwise AND, broadcast to shape (n, h, w) + return masks * crop def process_mask(protos, masks_in, bboxes, shape, upsample=False):