ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)

Co-authored-by: Yash Khurana <ykhurana6@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Swamita Gupta <swamita2001@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
Glenn Jocher 2024-01-05 03:00:26 +01:00 committed by GitHub
parent f702b34a50
commit 072291bc78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 2090 additions and 524 deletions

View file

@ -1,8 +1,11 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
import random
from copy import copy
import numpy as np
import torch.nn as nn
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
@ -54,6 +57,16 @@ class DetectionTrainer(BaseTrainer):
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
if self.args.multi_scale:
imgs = batch['img']
sz = (random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) // self.stride *
self.stride) # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride
for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
batch['img'] = imgs
return batch
def set_model_attributes(self):

View file

@ -70,7 +70,7 @@ class DetectionValidator(BaseValidator):
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
self.seen = 0
self.jdict = []
self.stats = []
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
@ -86,51 +86,68 @@ class DetectionValidator(BaseValidator):
agnostic=self.args.single_cls,
max_det=self.args.max_det)
def _prepare_batch(self, si, batch):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx]
ori_shape = batch['ori_shape'][si]
imgsz = batch['img'].shape[2:]
ratio_pad = batch['ratio_pad'][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
return prepared_batch
def _prepare_pred(self, pred, pbatch):
predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'],
ratio_pad=pbatch['ratio_pad']) # native-space pred
return predn
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
nl = len(cls)
stat['target_cls'] = cls
if npr == 0:
if nl:
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
for k in self.stats.keys():
self.stats[k].append(stat[k])
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != 'obb':
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
predn = self._prepare_pred(pred, pbatch)
stat['conf'] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
stat['tp'] = self._process_batch(predn, bbox, cls)
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != 'obb':
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, shape, file)
self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
@ -139,10 +156,11 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
self.metrics.process(*stats)
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
if len(stats) and stats['tp'].any():
self.metrics.process(**stats)
self.nt_per_class = np.bincount(stats['target_cls'].astype(int),
minlength=self.nc) # number of targets per class
return self.metrics.results_dict
def print_results(self):
@ -165,7 +183,7 @@ class DetectionValidator(BaseValidator):
normalize=normalize,
on_plot=self.on_plot)
def _process_batch(self, detections, labels):
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
Return correct prediction matrix.
@ -178,8 +196,8 @@ class DetectionValidator(BaseValidator):
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
iou = box_iou(labels[:, 1:], detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
iou = box_iou(gt_bboxes, detections[:, :4])
return self.match_predictions(detections[:, 5], gt_cls, iou)
def build_dataset(self, img_path, mode='val', batch=None):
"""