ultralytics 8.1.16 OBB ConfusionMatrix support (#8299)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-02-19 23:59:24 +08:00 committed by GitHub
parent 42744a1717
commit de01212465
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 26 additions and 20 deletions

View file

@ -326,9 +326,10 @@ class ConfusionMatrix:
Update confusion matrix for object detection task.
Args:
detections (Array[N, 6]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class).
gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format.
detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class)
or with an additional element `angle` when it's obb.
gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
gt_cls (Array[M]): The class labels.
"""
if gt_cls.shape[0] == 0: # Check if labels is empty
@ -347,7 +348,12 @@ class ConfusionMatrix:
detections = detections[detections[:, 4] > self.conf]
gt_classes = gt_cls.int()
detection_classes = detections[:, 5].int()
iou = box_iou(gt_bboxes, detections[:, :4])
is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension
iou = (
batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
if is_obb
else box_iou(gt_bboxes, detections[:, :4])
)
x = torch.where(iou > self.iou_thres)
if x[0].shape[0]: