ultralytics 8.0.197 save P, R, F1 curves to metrics (#5354)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: erminkev1 <83356055+erminkev1@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andy <39454881+yermandy@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-13 02:49:31 +02:00 committed by GitHub
parent 7fd5dcbd86
commit 12e3eef844
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 337 additions and 195 deletions

View file

@ -66,7 +66,7 @@ class DETRLoss(nn.Module):
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f'loss_class{postfix}'
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
@ -90,7 +90,7 @@ class DETRLoss(nn.Module):
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
boxes.
"""
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f'loss_bbox{postfix}'
name_giou = f'loss_giou{postfix}'

View file

@ -188,7 +188,7 @@ def get_cdn_group(batch,
num_group = num_dn // max_nums
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
# Pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch['cls'] # (bs*num, )
gt_bbox = batch['bboxes'] # bs*num, 4
@ -204,10 +204,10 @@ def get_cdn_group(batch,
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
if cls_noise_ratio > 0:
# half of bbox prob
# Half of bbox prob
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# randomly put a new one here
# Randomly put a new one here
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label
@ -240,9 +240,9 @@ def get_cdn_group(batch,
tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
# match query cannot see the reconstruct
# Match query cannot see the reconstruct
attn_mask[num_dn:, :num_dn] = True
# reconstruct cannot see each other
# Reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True