Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ultralytics/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ def single_mask_loss(
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
predicted masks from the prototype masks and predicted mask coefficients.
"""
pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
# matmul for efficiency: (n,32) x (32,HW) -> (n,HW) -> (n,H,W)
proto_flat = proto.view(proto.shape[0], -1)
pred_mask = torch.matmul(pred, proto_flat).view(-1, proto.shape[1], proto.shape[2])
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

Expand Down
16 changes: 10 additions & 6 deletions ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,12 +669,16 @@ def crop_mask(masks, boxes):
Returns:
(torch.Tensor): Cropped masks.
"""
_, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)

return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, dim=1) # x1 shape(n,1,1)
rows = torch.arange(w, device=masks.device, dtype=x1.dtype).view(1, 1, w)
cols = torch.arange(h, device=masks.device, dtype=x1.dtype).view(1, h, 1)

# Use torch.logical_and for more efficient mask computation and broadcasting.
mask_x = (rows >= x1) & (rows < x2) # shape [n, h, w], x1,x2 broadcasted
mask_y = (cols >= y1) & (cols < y2) # shape [n, h, w], y1,y2 broadcasted
crop = mask_x & mask_y # elementwise-and for the crop mask
return masks * crop


def process_mask(protos, masks_in, bboxes, shape, upsample=False):
Expand Down