OBB: Fix distorted plotting (#9899)
This commit is contained in:
parent
e040ce0618
commit
417c429ec4
2 changed files with 13 additions and 8 deletions
|
|
@ -975,17 +975,22 @@ class Format:
|
||||||
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
|
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
|
||||||
)
|
)
|
||||||
labels["masks"] = masks
|
labels["masks"] = masks
|
||||||
if self.normalize:
|
|
||||||
instances.normalize(w, h)
|
|
||||||
labels["img"] = self._format_img(img)
|
labels["img"] = self._format_img(img)
|
||||||
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||||
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||||
if self.return_keypoint:
|
if self.return_keypoint:
|
||||||
labels["keypoints"] = torch.from_numpy(instances.keypoints)
|
labels["keypoints"] = torch.from_numpy(instances.keypoints)
|
||||||
|
if self.normalize:
|
||||||
|
labels["keypoints"][..., 0] /= w
|
||||||
|
labels["keypoints"][..., 1] /= h
|
||||||
if self.return_obb:
|
if self.return_obb:
|
||||||
labels["bboxes"] = (
|
labels["bboxes"] = (
|
||||||
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
|
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
|
||||||
)
|
)
|
||||||
|
# NOTE: need to normalize obb in xywhr format for width-height consistency
|
||||||
|
if self.normalize:
|
||||||
|
labels["bboxes"][:, [0, 2]] /= w
|
||||||
|
labels["bboxes"][:, [1, 3]] /= h
|
||||||
# Then we can use collate_fn
|
# Then we can use collate_fn
|
||||||
if self.batch_idx:
|
if self.batch_idx:
|
||||||
labels["batch_idx"] = torch.zeros(nl)
|
labels["batch_idx"] = torch.zeros(nl)
|
||||||
|
|
|
||||||
|
|
@ -838,16 +838,16 @@ def plot_images(
|
||||||
if len(bboxes):
|
if len(bboxes):
|
||||||
boxes = bboxes[idx]
|
boxes = bboxes[idx]
|
||||||
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
||||||
is_obb = boxes.shape[-1] == 5 # xywhr
|
|
||||||
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
|
|
||||||
if len(boxes):
|
if len(boxes):
|
||||||
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
|
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
|
||||||
boxes[..., 0::2] *= w # scale to pixels
|
boxes[..., [0, 2]] *= w # scale to pixels
|
||||||
boxes[..., 1::2] *= h
|
boxes[..., [1, 3]] *= h
|
||||||
elif scale < 1: # absolute coords need scale if image scales
|
elif scale < 1: # absolute coords need scale if image scales
|
||||||
boxes[..., :4] *= scale
|
boxes[..., :4] *= scale
|
||||||
boxes[..., 0::2] += x
|
boxes[..., 0] += x
|
||||||
boxes[..., 1::2] += y
|
boxes[..., 1] += y
|
||||||
|
is_obb = boxes.shape[-1] == 5 # xywhr
|
||||||
|
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
|
||||||
for j, box in enumerate(boxes.astype(np.int64).tolist()):
|
for j, box in enumerate(boxes.astype(np.int64).tolist()):
|
||||||
c = classes[j]
|
c = classes[j]
|
||||||
color = colors(c)
|
color = colors(c)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue