Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions ultralytics/models/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,24 +193,22 @@ def get_cdn_group(
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

if cls_noise_ratio > 0:
# Half of bbox prob
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# Randomly put a new one here
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
mask = torch.rand(dn_cls.shape, device=dn_cls.device) < (cls_noise_ratio * 0.5)
idx = mask.nonzero(as_tuple=True)[0]
new_label = torch.randint(0, num_classes, idx.shape, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label

if box_noise_scale > 0:
known_bbox = xywh2xyxy(dn_bbox)

diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4

rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(dn_bbox)
rand_sign = torch.randint(0, 2, dn_bbox.shape, device=dn_bbox.device, dtype=dn_bbox.dtype) * 2.0 - 1.0
rand_part = torch.rand(dn_bbox.shape, device=dn_bbox.device, dtype=dn_bbox.dtype)
rand_part[neg_idx] += 1.0
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
known_bbox.clamp_(min=0.0, max=1.0)
dn_bbox = xyxy2xywh(known_bbox)
dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid

Expand All @@ -220,35 +218,39 @@ def get_cdn_group(
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
# Build the index ranges directly on dn_b_idx's device: one torch.arange per entry of gt_groups (still a Python-level loop), then concatenate
map_indices = torch.cat([torch.arange(num, device=dn_b_idx.device) for num in gt_groups])
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
map_indices_full = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
padding_cls[(dn_b_idx, map_indices_full)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices_full)] = dn_bbox

tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool, device=gt_cls.device)
# Matching queries (rows from num_dn onward) cannot see the reconstruction part (first num_dn columns)
attn_mask[num_dn:, :num_dn] = True
# Reconstruction groups cannot see each other
for i in range(num_group):
left = max_nums * 2 * i
right = max_nums * 2 * (i + 1)
if i == 0:
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
attn_mask[left:right, right:num_dn] = True
if i == num_group - 1:
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
attn_mask[left:right, : max_nums * i * 2] = True
else:
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
attn_mask[left:right, right:num_dn] = True
attn_mask[left:right, :left] = True

dn_meta = {
"dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
"dn_num_group": num_group,
"dn_num_split": [num_dn, num_queries],
}

return (
padding_cls.to(class_embed.device),
padding_bbox.to(class_embed.device),
attn_mask.to(class_embed.device),
padding_cls,
padding_bbox,
attn_mask,
dn_meta,
)
12 changes: 8 additions & 4 deletions ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,14 @@ def xyxy2xywh(x):
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = empty_like(x) # faster than clone/copy
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
y[..., 2] = x[..., 2] - x[..., 0] # width
y[..., 3] = x[..., 3] - x[..., 1] # height
if hasattr(x, "dtype") and "torch" in str(type(x)):
# torch.Tensor path
y[..., 0:2] = (x[..., 0:2] + x[..., 2:4]) / 2
y[..., 2:4] = x[..., 2:4] - x[..., 0:2]
else:
# numpy path
y[..., 0:2] = (x[..., 0:2] + x[..., 2:4]) / 2
y[..., 2:4] = x[..., 2:4] - x[..., 0:2]
return y


Expand Down