ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)

Co-authored-by: Yash Khurana <ykhurana6@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Swamita Gupta <swamita2001@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
Glenn Jocher 2024-01-05 03:00:26 +01:00 committed by GitHub
parent f702b34a50
commit 072291bc78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 2090 additions and 524 deletions

View file

@ -51,6 +51,7 @@ class SegmentationValidator(DetectionValidator):
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[])
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
@ -70,59 +71,62 @@ class SegmentationValidator(DetectionValidator):
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def _prepare_batch(self, si, batch):
    """
    Prepare the ground truth of one image for validation, including its masks.

    Extends the parent class preparation by attaching the ground-truth masks
    that belong to image `si` under the 'masks' key.

    Args:
        si: Index of the image within the batch.
        batch: Batch dictionary; this method reads its 'masks' and 'batch_idx' entries.

    Returns:
        The dictionary produced by the parent `_prepare_batch`, with an added 'masks' entry.
    """
    sample = super()._prepare_batch(si, batch)
    if self.args.overlap_mask:
        # Overlap mode: select the single mask entry stored at this image's index.
        # NOTE(review): presumably one combined, index-encoded mask per image - confirm.
        mask_index = [si]
    else:
        # Otherwise pick every mask whose batch index matches this image.
        mask_index = batch['batch_idx'] == si
    sample['masks'] = batch['masks'][mask_index]
    return sample
def _prepare_pred(self, pred, pbatch, proto):
    """
    Prepare the predictions of one image for evaluation, including its masks.

    Args:
        pred: Predictions for a single image; columns 0-3 are passed to the mask
            processor as boxes and columns 6+ as per-detection mask inputs.
            NOTE(review): columns 6+ are presumably mask coefficients - confirm.
        pbatch: Prepared batch dictionary for the same image; 'imgsz' is read here.
        proto: Prototype tensor for this image, forwarded to `self.process`.

    Returns:
        A tuple `(predn, pred_masks)`: the parent class's prepared predictions and
        the masks produced by `self.process`.
    """
    native_pred = super()._prepare_pred(pred, pbatch)
    boxes = pred[:, :4]
    mask_inputs = pred[:, 6:]
    masks = self.process(proto, mask_inputs, boxes, shape=pbatch['imgsz'])
    return native_pred, masks
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
nl = len(cls)
stat['target_cls'] = cls
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
for k in self.stats.keys():
self.stats[k].append(stat[k])
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
# Masks
midx = [si] if self.args.overlap_mask else idx
gt_masks = batch['masks'][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
gt_masks = pbatch.pop('masks')
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
stat['conf'] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
correct_masks = self._process_batch(predn,
labelsn,
pred_masks,
gt_masks,
overlap=self.args.overlap_mask,
masks=True)
stat['tp'] = self._process_batch(predn, bbox, cls)
stat['tp_m'] = self._process_batch(predn,
bbox,
cls,
pred_masks,
gt_masks,
self.args.overlap_mask,
masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.confusion_matrix.process_batch(predn, bbox, cls)
# Append correct_masks, correct_boxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
for k in self.stats.keys():
self.stats[k].append(stat[k])
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
@ -131,7 +135,7 @@ class SegmentationValidator(DetectionValidator):
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
shape,
pbatch['ori_shape'],
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
@ -142,7 +146,7 @@ class SegmentationValidator(DetectionValidator):
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix.
@ -155,7 +159,7 @@ class SegmentationValidator(DetectionValidator):
"""
if masks:
if overlap:
nl = len(labels)
nl = len(gt_cls)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
@ -164,9 +168,9 @@ class SegmentationValidator(DetectionValidator):
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
iou = box_iou(gt_bboxes, detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
return self.match_predictions(detections[:, 5], gt_cls, iou)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
@ -174,7 +178,7 @@ class SegmentationValidator(DetectionValidator):
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
masks=batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,