ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)
Co-authored-by: Yash Khurana <ykhurana6@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Swamita Gupta <swamita2001@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
parent
f702b34a50
commit
072291bc78
52 changed files with 2090 additions and 524 deletions
|
|
@ -66,57 +66,63 @@ class PoseValidator(DetectionValidator):
|
|||
is_pose = self.kpt_shape == [17, 3]
|
||||
nkpt = self.kpt_shape[0]
|
||||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
|
||||
|
||||
def _prepare_batch(self, si, batch):
|
||||
pbatch = super()._prepare_batch(si, batch)
|
||||
kpts = batch['keypoints'][batch['batch_idx'] == si]
|
||||
h, w = pbatch['imgsz']
|
||||
kpts = kpts.clone()
|
||||
kpts[..., 0] *= w
|
||||
kpts[..., 1] *= h
|
||||
kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
|
||||
pbatch['kpts'] = kpts
|
||||
return pbatch
|
||||
|
||||
def _prepare_pred(self, pred, pbatch):
|
||||
predn = super()._prepare_pred(pred, pbatch)
|
||||
nk = pbatch['kpts'].shape[1]
|
||||
pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
|
||||
ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
|
||||
return predn, pred_kpts
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Metrics."""
|
||||
for si, pred in enumerate(preds):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
bbox = batch['bboxes'][idx]
|
||||
kpts = batch['keypoints'][idx]
|
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||
nk = kpts.shape[1] # number of keypoints
|
||||
shape = batch['ori_shape'][si]
|
||||
correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
self.seen += 1
|
||||
|
||||
npr = len(pred)
|
||||
stat = dict(conf=torch.zeros(0, device=self.device),
|
||||
pred_cls=torch.zeros(0, device=self.device),
|
||||
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
|
||||
tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
|
||||
pbatch = self._prepare_batch(si, batch)
|
||||
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
|
||||
nl = len(cls)
|
||||
stat['target_cls'] = cls
|
||||
if npr == 0:
|
||||
if nl:
|
||||
self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
|
||||
(2, 0), device=self.device), cls.squeeze(-1)))
|
||||
for k in self.stats.keys():
|
||||
self.stats[k].append(stat[k])
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
if self.args.single_cls:
|
||||
pred[:, 5] = 0
|
||||
predn = pred.clone()
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||
pred_kpts = predn[:, 6:].view(npr, nk, -1)
|
||||
ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])
|
||||
predn, pred_kpts = self._prepare_pred(pred, pbatch)
|
||||
stat['conf'] = predn[:, 4]
|
||||
stat['pred_cls'] = predn[:, 5]
|
||||
|
||||
# Evaluate
|
||||
if nl:
|
||||
height, width = batch['img'].shape[2:]
|
||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||
(width, height, width, height), device=self.device) # target boxes
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||
tkpts = kpts.clone()
|
||||
tkpts[..., 0] *= width
|
||||
tkpts[..., 1] *= height
|
||||
tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
correct_bboxes = self._process_batch(predn[:, :6], labelsn)
|
||||
correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
|
||||
stat['tp'] = self._process_batch(predn, bbox, cls)
|
||||
stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
self.confusion_matrix.process_batch(predn, bbox, cls)
|
||||
|
||||
# Append correct_masks, correct_boxes, pconf, pcls, tcls
|
||||
self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
|
||||
for k in self.stats.keys():
|
||||
self.stats[k].append(stat[k])
|
||||
|
||||
# Save
|
||||
if self.args.save_json:
|
||||
|
|
@ -124,7 +130,7 @@ class PoseValidator(DetectionValidator):
|
|||
# if self.args.save_txt:
|
||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||
|
||||
def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
|
||||
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
|
||||
"""
|
||||
Return correct prediction matrix.
|
||||
|
||||
|
|
@ -142,12 +148,12 @@ class PoseValidator(DetectionValidator):
|
|||
"""
|
||||
if pred_kpts is not None and gt_kpts is not None:
|
||||
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
||||
area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
|
||||
area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
|
||||
iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
|
||||
else: # boxes
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
iou = box_iou(gt_bboxes, detections[:, :4])
|
||||
|
||||
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
|
||||
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue