ultralytics 8.0.197 save P, R, F1 curves to metrics (#5354)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: erminkev1 <83356055+erminkev1@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andy <39454881+yermandy@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-13 02:49:31 +02:00 committed by GitHub
parent 7fd5dcbd86
commit 12e3eef844
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 337 additions and 195 deletions

View file

@ -188,7 +188,7 @@ def get_cdn_group(batch,
num_group = num_dn // max_nums
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
# Pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch['cls'] # (bs*num, )
gt_bbox = batch['bboxes'] # bs*num, 4
@ -204,10 +204,10 @@ def get_cdn_group(batch,
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
if cls_noise_ratio > 0:
# half of bbox prob
# Half of bbox prob
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# randomly put a new one here
# Randomly put a new one here
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label
@ -240,9 +240,9 @@ def get_cdn_group(batch,
tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
# match query cannot see the reconstruct
# Match query cannot see the reconstruct
attn_mask[num_dn:, :num_dn] = True
# reconstruct cannot see each other
# Reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True