ultralytics 8.1.16 OBB ConfusionMatrix support (#8299)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Laughing 2024-02-19 23:59:24 +08:00 committed by GitHub
parent 42744a1717
commit de01212465
12 changed files with 26 additions and 20 deletions

ultralytics/__init__.py

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.15"
__version__ = "8.1.16"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

ultralytics/cfg/default.yaml

@@ -9,7 +9,7 @@ model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
data: # (str, optional) path to data file, i.e. coco128.yaml
epochs: 100 # (int) number of epochs to train for
time: # (float, optional) number of hours to train for, overrides epochs if supplied
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training
batch: 16 # (int) number of images per batch (-1 for AutoBatch)
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
save: True # (bool) save train checkpoints and predict results
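The early-stopping patience default doubles here from 50 to 100 epochs; individual runs can still override it. A minimal sketch of that override from the Python API (model and dataset names are the stock samples mentioned in the config comments above, swap in your own):

```python
from ultralytics import YOLO

# Early stopping now waits 100 epochs by default; pass `patience` to override.
# "yolov8n.pt" and "coco128.yaml" are the sample names used in the config comments.
model = YOLO("yolov8n.pt")
model.train(data="coco128.yaml", epochs=300, patience=50)  # restore the previous 50-epoch patience
```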

ultralytics/models/yolo/detect/val.py

@@ -132,8 +132,7 @@ class DetectionValidator(BaseValidator):
if nl:
for k in self.stats.keys():
self.stats[k].append(stat[k])
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != "obb":
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
@@ -147,8 +146,7 @@ class DetectionValidator(BaseValidator):
# Evaluate
if nl:
stat["tp"] = self._process_batch(predn, bbox, cls)
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != "obb":
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])
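With the task guard removed, confusion-matrix accumulation now runs for every detection-style task whenever plotting is enabled, OBB included. A hedged example of how that surfaces to users (the model and dataset names are the stock OBB assets; adjust for your own setup):

```python
from ultralytics import YOLO

# Validating an OBB model with plots enabled should now also produce the
# confusion-matrix plots alongside the usual curves.
model = YOLO("yolov8n-obb.pt")
metrics = model.val(data="DOTAv1.yaml", plots=True)
```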

ultralytics/models/yolo/obb/val.py

@@ -55,10 +55,11 @@ class OBBValidator(DetectionValidator):
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
detections (torch.Tensor): Tensor of shape [N, 7] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class, angle.
gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes.
Each box is of the format: x1, y1, x2, y2, angle.
labels (torch.Tensor): Tensor of shape [M] representing labels.
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
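For orientation, a shape-only sketch of the tensors the revised docstring describes and of the [N, 10] matrix the method returns (all values below are placeholders, not real predictions):

```python
import torch

# Placeholder shapes matching the docstring: 3 rotated predictions, 2 ground truths.
detections = torch.zeros(3, 7)             # each row: x1, y1, x2, y2, conf, class, angle
gt_bboxes = torch.zeros(2, 5)              # each row: x1, y1, x2, y2, angle
labels = torch.zeros(2, dtype=torch.long)  # one class id per ground-truth box

# The correct-prediction matrix carries one column per IoU threshold.
iouv = torch.linspace(0.5, 0.95, 10)       # 0.50, 0.55, ..., 0.95
correct = torch.zeros(detections.shape[0], iouv.numel(), dtype=torch.bool)  # shape [N, 10]
```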

ultralytics/utils/downloads.py

@@ -26,6 +26,7 @@ GITHUB_ASSETS_NAMES = (
+ [f"FastSAM-{k}.pt" for k in "sx"]
+ [f"rtdetr-{k}.pt" for k in "lx"]
+ ["mobile_sam.pt"]
+ ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
)
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
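Since the stems list strips only the final suffix, the new calibration archive appears there with its `.npy` part intact; a quick stdlib-only check (nothing ultralytics-specific assumed):

```python
from pathlib import Path

# Path.stem removes only the last suffix, so ".zip" is dropped but ".npy" remains.
print(Path("calibration_image_sample_data_20x128x128x3_float32.npy.zip").stem)
# calibration_image_sample_data_20x128x128x3_float32.npy
```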

ultralytics/utils/metrics.py

@@ -326,9 +326,10 @@ class ConfusionMatrix:
Update confusion matrix for object detection task.
Args:
detections (Array[N, 6]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class).
gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format.
detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class),
plus a trailing `angle` element for OBB (oriented bounding box) predictions.
gt_bboxes (Array[M, 4] | Array[M, 5]): Ground truth bounding boxes in xyxy format, or xyxyr (with angle) when rotated.
gt_cls (Array[M]): The class labels.
"""
if gt_cls.shape[0] == 0: # Check if labels is empty
@@ -347,7 +348,12 @@ class ConfusionMatrix:
detections = detections[detections[:, 4] > self.conf]
gt_classes = gt_cls.int()
detection_classes = detections[:, 5].int()
iou = box_iou(gt_bboxes, detections[:, :4])
is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension
iou = (
batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
if is_obb
else box_iou(gt_bboxes, detections[:, :4])
)
x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
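The new branch only swaps the IoU backend: rotated boxes go through `batch_probiou`, everything else keeps `box_iou`. A minimal standalone sketch of that dispatch, using made-up tensors purely to show the shapes (both helpers live in `ultralytics.utils.metrics`):

```python
import torch
from ultralytics.utils.metrics import batch_probiou, box_iou

# Made-up boxes: two detections, one ground-truth box (values are illustrative only).
det_aabb = torch.tensor([[10.0, 10.0, 50.0, 50.0, 0.9, 0.0],
                         [60.0, 60.0, 90.0, 90.0, 0.8, 1.0]])  # [N, 6]
gt_aabb = torch.tensor([[12.0, 12.0, 48.0, 48.0]])             # [M, 4]

# OBB variants append an angle: detections become [N, 7], ground truth [M, 5].
det_obb = torch.cat([det_aabb, torch.tensor([[0.1], [0.0]])], dim=-1)
gt_obb = torch.cat([gt_aabb, torch.tensor([[0.1]])], dim=-1)

for det, gt in ((det_aabb, gt_aabb), (det_obb, gt_obb)):
    is_obb = det.shape[1] == 7 and gt.shape[1] == 5  # same check as the patched process_batch()
    iou = (
        batch_probiou(gt, torch.cat([det[:, :4], det[:, -1:]], dim=-1))
        if is_obb
        else box_iou(gt, det[:, :4])
    )
    print(iou.shape)  # torch.Size([1, 2]): one row per ground-truth box, one column per detection
```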