ultralytics 8.0.197 save P, R, F1 curves to metrics (#5354)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: erminkev1 <83356055+erminkev1@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andy <39454881+yermandy@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-13 02:49:31 +02:00 committed by GitHub
parent 7fd5dcbd86
commit 12e3eef844
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 337 additions and 195 deletions

View file

@ -66,7 +66,7 @@ class DETRLoss(nn.Module):
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f'loss_class{postfix}'
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
@ -90,7 +90,7 @@ class DETRLoss(nn.Module):
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
boxes.
"""
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f'loss_bbox{postfix}'
name_giou = f'loss_giou{postfix}'