ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)
Co-authored-by: Yash Khurana <ykhurana6@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Swamita Gupta <swamita2001@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
parent
f702b34a50
commit
072291bc78
52 changed files with 2090 additions and 524 deletions
ultralytics/utils/instance.py
@@ -7,7 +7,7 @@ from typing import List

 import numpy as np

-from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
+from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh


 def _ntuple(n):
@@ -212,19 +212,9 @@ class Instances:
             segments (list | ndarray): segments.
             keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
         """
-        if segments is None:
-            segments = []
         self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
         self.keypoints = keypoints
         self.normalized = normalized
-
-        if len(segments) > 0:
-            # List[np.array(1000, 2)] * num_samples
-            segments = resample_segments(segments)
-            # (N, 1000, 2)
-            segments = np.stack(segments, axis=0)
-        else:
-            segments = np.zeros((0, 1000, 2), dtype=np.float32)
         self.segments = segments

     def convert_bbox(self, format):
ultralytics/utils/loss.py
@@ -6,9 +6,9 @@ import torch.nn.functional as F

 from ultralytics.utils.metrics import OKS_SIGMA
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
-from ultralytics.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
+from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors

-from .metrics import bbox_iou
+from .metrics import bbox_iou, probiou
 from .tal import bbox2dist

@@ -95,6 +95,30 @@ class BboxLoss(nn.Module):
                 F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)


+class RotatedBboxLoss(BboxLoss):
+    """Criterion class for computing training losses during training."""
+
+    def __init__(self, reg_max, use_dfl=False):
+        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+        super().__init__(reg_max, use_dfl)
+
+    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+        """IoU loss."""
+        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+        iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
+        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+        # DFL loss
+        if self.use_dfl:
+            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
+            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
+            loss_dfl = loss_dfl.sum() / target_scores_sum
+        else:
+            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+        return loss_iou, loss_dfl
+
+
 class KeypointLoss(nn.Module):
     """Criterion class for computing training losses."""

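For orientation, here is a minimal standalone sketch of the probiou-based IoU term that RotatedBboxLoss.forward computes, using made-up foreground tensors in place of real assigner outputs (assumes ultralytics 8.0.235+ for the probiou import; all values are illustrative):

import torch
from ultralytics.utils.metrics import probiou

# Hypothetical foreground predictions and targets in xywhr format, shape (n, 5)
pred_rboxes = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.1]])
gt_rboxes = torch.tensor([[52.0, 49.0, 20.0, 10.0, 0.1]])
weight = torch.tensor([[0.8]])  # per-box weight taken from target_scores
target_scores_sum = 1.0         # normalizer, max(target_scores.sum(), 1)

iou = probiou(pred_rboxes, gt_rboxes)  # probabilistic IoU, shape (1, 1)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum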
@@ -243,9 +267,9 @@ class v8SegmentationLoss(v8DetectionLoss):
             except RuntimeError as e:
                 raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
                                 "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                                "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
-                                "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
-                                'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
+                                "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+                                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+                                'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e

         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
@@ -526,3 +550,109 @@ class v8ClassificationLoss:
         loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
         loss_items = loss.detach()
         return loss, loss_items
+
+
+class v8OBBLoss(v8DetectionLoss):
+
+    def __init__(self, model):  # model must be de-paralleled
+        super().__init__(model)
+        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
+
+    def preprocess(self, targets, batch_size, scale_tensor):
+        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
+        if targets.shape[0] == 0:
+            out = torch.zeros(batch_size, 0, 6, device=self.device)
+        else:
+            i = targets[:, 0]  # image index
+            _, counts = i.unique(return_counts=True)
+            counts = counts.to(dtype=torch.int32)
+            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
+            for j in range(batch_size):
+                matches = i == j
+                n = matches.sum()
+                if n:
+                    bboxes = targets[matches, 2:]
+                    bboxes[..., :4].mul_(scale_tensor)
+                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
+        return out
+
+    def __call__(self, preds, batch):
+        """Calculate and return the loss for the YOLO model."""
+        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
+        batch_size = pred_angle.shape[0]  # batch size
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1)
+
+        # b, grids, ..
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # targets
+        try:
+            batch_idx = batch['batch_idx'].view(-1, 1)
+            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1)
+            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
+            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+        except RuntimeError as e:
+            raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n'
+                            "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
+                            "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+                            "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
+                            'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e
+
+        # Pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywhr, (b, h*w, 5)
+
+        bboxes_for_assigner = pred_bboxes.clone().detach()
+        # Only the first four elements need to be scaled
+        bboxes_for_assigner[..., :4] *= stride_tensor
+        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(),
+                                                                    bboxes_for_assigner.type(gt_bboxes.dtype),
+                                                                    anchor_points * stride_tensor, gt_labels, gt_bboxes,
+                                                                    mask_gt)
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # Cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        # Bbox loss
+        if fg_mask.sum():
+            target_bboxes[..., :4] /= stride_tensor
+            loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
+                                              target_scores_sum, fg_mask)
+        else:
+            loss[0] += (pred_angle * 0).sum()
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.cls  # cls gain
+        loss[2] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
+        """
+        Decode predicted object bounding box coordinates from anchor points and distribution.
+
+        Args:
+            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+
+        Returns:
+            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
+        """
+        if self.use_dfl:
+            b, a, c = pred_dist.shape  # batch, anchors, channels
+            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
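The new loss is selected automatically for the 'obb' task. A hedged sketch of the high-level workflow this release enables, matching the train/val/predict flow named in the commit title (the 'dota8.yaml' dataset name is an assumption; substitute any OBB-formatted dataset):

from ultralytics import YOLO

model = YOLO('yolov8n-obb.yaml')  # build an OBB model; v8OBBLoss is used internally during training
model.train(data='dota8.yaml', epochs=3, imgsz=640)  # 'dota8.yaml' is an assumed example dataset
metrics = model.val()  # reports the OBBMetrics added in this commit
results = model.predict('https://ultralytics.com/images/bus.jpg')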
ultralytics/utils/metrics.py
@@ -165,6 +165,92 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
     return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)


+def _get_covariance_matrix(boxes):
+    """
+    Generate covariance matrices from OBBs.
+
+    Args:
+        boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
+
+    Returns:
+        (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
+    """
+    # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
+    gbbs = torch.cat((torch.pow(boxes[:, 2:4], 2) / 12, boxes[:, 4:]), dim=-1)
+    a, b, c = gbbs.split(1, dim=-1)
+    return (
+        a * torch.cos(c) ** 2 + b * torch.sin(c) ** 2,
+        a * torch.sin(c) ** 2 + b * torch.cos(c) ** 2,
+        a * torch.cos(c) * torch.sin(c) - b * torch.sin(c) * torch.cos(c),
+    )
+
+
+def probiou(obb1, obb2, CIoU=False, eps=1e-7):
+    """
+    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+
+    Args:
+        obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
+        obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
+        CIoU (bool, optional): If True, add a CIoU-style aspect-ratio penalty term. Defaults to False.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
+    """
+    x1, y1 = obb1[..., :2].split(1, dim=-1)
+    x2, y2 = obb2[..., :2].split(1, dim=-1)
+    a1, b1, c1 = _get_covariance_matrix(obb1)
+    a2, b2, c2 = _get_covariance_matrix(obb2)
+
+    t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
+          ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
+    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
+    t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
+                   (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
+                                   (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
+    bd = t1 + t2 + t3
+    bd = torch.clamp(bd, eps, 100.0)
+    hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
+    iou = 1 - hd
+    if CIoU:  # only include the wh aspect ratio part
+        w1, h1 = obb1[..., 2:4].split(1, dim=-1)
+        w2, h2 = obb2[..., 2:4].split(1, dim=-1)
+        v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+        with torch.no_grad():
+            alpha = v / (v - iou + (1 + eps))
+        return iou - v * alpha  # CIoU
+    return iou
+
+
+def batch_probiou(obb1, obb2, eps=1e-7):
+    """
+    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+
+    Args:
+        obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
+        obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
+    """
+    x1, y1 = obb1[..., :2].split(1, dim=-1)
+    x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
+    a1, b1, c1 = _get_covariance_matrix(obb1)
+    a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
+
+    t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
+          ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
+    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
+    t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
+                   (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
+                                   (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
+    bd = t1 + t2 + t3
+    bd = torch.clamp(bd, eps, 100.0)
+    hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
+    return 1 - hd
+
+
 def smooth_BCE(eps=0.1):
     """
     Computes smoothed positive and negative Binary Cross-Entropy targets.
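A quick sanity check of the new metric (assumes ultralytics 8.0.235+ installed): identical boxes should score near 1, and a distant, differently rotated box near 0.

import torch
from ultralytics.utils.metrics import probiou, batch_probiou

a = torch.tensor([[100.0, 100.0, 40.0, 20.0, 0.0]])  # xywhr, angle in radians
b = torch.tensor([[100.0, 100.0, 40.0, 20.0, 0.0],
                  [300.0, 300.0, 40.0, 20.0, 1.0]])

print(probiou(a, a))        # ~1.0 for identical boxes, shape (1, 1)
print(batch_probiou(a, b))  # pairwise (N, M) = (1, 2) similarity matrix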
@@ -213,17 +299,17 @@ class ConfusionMatrix:
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1

-    def process_batch(self, detections, labels):
+    def process_batch(self, detections, gt_bboxes, gt_cls):
         """
         Update confusion matrix for object detection task.

         Args:
             detections (Array[N, 6]): Detected bounding boxes and their associated information.
                                       Each row should contain (x1, y1, x2, y2, conf, class).
-            labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
-                                  Each row should contain (class, x1, y1, x2, y2).
+            gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format.
+            gt_cls (Array[M]): The class labels.
         """
-        if labels.size(0) == 0:  # Check if labels is empty
+        if gt_cls.size(0) == 0:  # Check if labels is empty
             if detections is not None:
                 detections = detections[detections[:, 4] > self.conf]
                 detection_classes = detections[:, 5].int()
@@ -231,15 +317,15 @@ class ConfusionMatrix:
                 self.matrix[dc, self.nc] += 1  # false positives
             return
         if detections is None:
-            gt_classes = labels.int()
+            gt_classes = gt_cls.int()
             for gc in gt_classes:
                 self.matrix[self.nc, gc] += 1  # background FN
             return

         detections = detections[detections[:, 4] > self.conf]
-        gt_classes = labels[:, 0].int()
+        gt_classes = gt_cls.int()
         detection_classes = detections[:, 5].int()
-        iou = box_iou(labels[:, 1:], detections[:, :4])
+        iou = box_iou(gt_bboxes, detections[:, :4])

         x = torch.where(iou > self.iou_thres)
         if x[0].shape[0]:
@@ -814,12 +900,12 @@ class SegmentMetrics(SimpleClass):
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
         self.task = 'segment'

-    def process(self, tp_b, tp_m, conf, pred_cls, target_cls):
+    def process(self, tp, tp_m, conf, pred_cls, target_cls):
         """
         Processes the detection and segmentation metrics over the given set of predictions.

         Args:
-            tp_b (list): List of True Positive boxes.
+            tp (list): List of True Positive boxes.
             tp_m (list): List of True Positive masks.
             conf (list): List of confidence scores.
             pred_cls (list): List of predicted classes.
@@ -837,7 +923,7 @@ class SegmentMetrics(SimpleClass):
                                    prefix='Mask')[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-        results_box = ap_per_class(tp_b,
+        results_box = ap_per_class(tp,
                                    conf,
                                    pred_cls,
                                    target_cls,
@@ -938,12 +1024,12 @@ class PoseMetrics(SegmentMetrics):
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
         self.task = 'pose'

-    def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
+    def process(self, tp, tp_p, conf, pred_cls, target_cls):
         """
         Processes the detection and pose metrics over the given set of predictions.

         Args:
-            tp_b (list): List of True Positive boxes.
+            tp (list): List of True Positive boxes.
             tp_p (list): List of True Positive keypoints.
             conf (list): List of confidence scores.
             pred_cls (list): List of predicted classes.
@@ -961,7 +1047,7 @@ class PoseMetrics(SegmentMetrics):
                                     prefix='Pose')[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-        results_box = ap_per_class(tp_b,
+        results_box = ap_per_class(tp,
                                    conf,
                                    pred_cls,
                                    target_cls,
@@ -1067,3 +1153,70 @@ class ClassifyMetrics(SimpleClass):
     def curves_results(self):
         """Returns a list of curves for accessing specific metrics curves."""
         return []
+
+
+class OBBMetrics(SimpleClass):
+
+    def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
+        self.save_dir = save_dir
+        self.plot = plot
+        self.on_plot = on_plot
+        self.names = names
+        self.box = Metric()
+        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
+
+    def process(self, tp, conf, pred_cls, target_cls):
+        """Process predicted results for object detection and update metrics."""
+        results = ap_per_class(tp,
+                               conf,
+                               pred_cls,
+                               target_cls,
+                               plot=self.plot,
+                               save_dir=self.save_dir,
+                               names=self.names,
+                               on_plot=self.on_plot)[2:]
+        self.box.nc = len(self.names)
+        self.box.update(results)
+
+    @property
+    def keys(self):
+        """Returns a list of keys for accessing specific metrics."""
+        return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
+
+    def mean_results(self):
+        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
+        return self.box.mean_results()
+
+    def class_result(self, i):
+        """Return the result of evaluating the performance of an object detection model on a specific class."""
+        return self.box.class_result(i)
+
+    @property
+    def maps(self):
+        """Returns mean Average Precision (mAP) scores per class."""
+        return self.box.maps
+
+    @property
+    def fitness(self):
+        """Returns the fitness of box object."""
+        return self.box.fitness()
+
+    @property
+    def ap_class_index(self):
+        """Returns the average precision index per class."""
+        return self.box.ap_class_index
+
+    @property
+    def results_dict(self):
+        """Returns dictionary of computed performance metrics and statistics."""
+        return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []
+
+    @property
+    def curves_results(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []
ultralytics/utils/ops.py
@@ -12,6 +12,7 @@ import torch.nn.functional as F
 import torchvision

 from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import batch_probiou


 class Profile(contextlib.ContextDecorator):
@@ -80,10 +81,10 @@ def segment2box(segment, width=640, height=640):
         4, dtype=segment.dtype)  # xyxy


-def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
     """
-    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
-    (img1_shape) to the shape of a different image (img0_shape).
+    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
+    specified in (img1_shape) to the shape of a different image (img0_shape).

     Args:
         img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
@@ -93,6 +94,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
             calculated based on the size difference between the two images.
         padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
             rescaling.
+        xywh (bool): The box format is xywh or not, default=False.

     Returns:
         boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
@@ -106,8 +108,11 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
         pad = ratio_pad[1]

     if padding:
-        boxes[..., [0, 2]] -= pad[0]  # x padding
-        boxes[..., [1, 3]] -= pad[1]  # y padding
+        boxes[..., 0] -= pad[0]  # x padding
+        boxes[..., 1] -= pad[1]  # y padding
+        if not xywh:
+            boxes[..., 2] -= pad[0]  # x padding
+            boxes[..., 3] -= pad[1]  # y padding
     boxes[..., :4] /= gain
     return clip_boxes(boxes, img0_shape)

@@ -128,19 +133,40 @@ def make_divisible(x, divisor):
     return math.ceil(x / divisor) * divisor


+def nms_rotated(boxes, scores, threshold=0.45):
+    """
+    NMS for OBBs, powered by probiou and fast-nms.
+
+    Args:
+        boxes (torch.Tensor): (N, 5), xywhr.
+        scores (torch.Tensor): (N, ).
+        threshold (float): IoU threshold.
+
+    Returns:
+        (torch.Tensor): Indices of the boxes to keep after NMS.
+    """
+    if len(boxes) == 0:
+        return np.empty((0, ), dtype=np.int8)
+    sorted_idx = torch.argsort(scores, descending=True)
+    boxes = boxes[sorted_idx]
+    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
+    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+    return sorted_idx[pick]
+
+
 def non_max_suppression(
-        prediction,
-        conf_thres=0.25,
-        iou_thres=0.45,
-        classes=None,
-        agnostic=False,
-        multi_label=False,
-        labels=(),
-        max_det=300,
-        nc=0,  # number of classes (optional)
-        max_time_img=0.05,
-        max_nms=30000,
-        max_wh=7680,
+    prediction,
+    conf_thres=0.25,
+    iou_thres=0.45,
+    classes=None,
+    agnostic=False,
+    multi_label=False,
+    labels=(),
+    max_det=300,
+    nc=0,  # number of classes (optional)
+    max_time_img=0.05,
+    max_nms=30000,
+    max_wh=7680,
+    rotated=False,
 ):
     """
     Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
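A minimal sketch of the new rotated NMS on dummy boxes (assumes ultralytics 8.0.235+; values are illustrative):

import torch
from ultralytics.utils.ops import nms_rotated

boxes = torch.tensor([[100.0, 100.0, 40.0, 20.0, 0.00],   # xywhr
                      [101.0, 100.0, 40.0, 20.0, 0.05],   # near-duplicate of box 0
                      [300.0, 300.0, 40.0, 20.0, 1.00]])  # far away
scores = torch.tensor([0.9, 0.8, 0.7])

keep = nms_rotated(boxes, scores, threshold=0.45)
print(keep)  # expected: tensor([0, 2]); the near-duplicate is suppressed by its probiou overlap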
@@ -190,7 +216,8 @@ def non_max_suppression(
     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

     prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
-    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
+    if not rotated:
+        prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy

     t = time.time()
     output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
@@ -200,7 +227,7 @@ def non_max_suppression(
         x = x[xc[xi]]  # confidence

         # Cat apriori labels if autolabelling
-        if labels and len(labels[xi]):
+        if labels and len(labels[xi]) and not rotated:
             lb = labels[xi]
             v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
             v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
@@ -234,8 +261,13 @@ def non_max_suppression(

         # Batched NMS
         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
-        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
-        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+        scores = x[:, 4]  # scores
+        if rotated:
+            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -2:-1]), dim=-1)  # xywhr
+            i = nms_rotated(boxes, scores, iou_thres)
+        else:
+            boxes = x[:, :4] + c  # boxes (offset by class)
+            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
         i = i[:max_det]  # limit detections

         # # Experimental
@@ -320,7 +352,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
         gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
         pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
     else:
-        gain = ratio_pad[0][0]
+        # gain = ratio_pad[0][0]
         pad = ratio_pad[1]
     top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)))  # y, x
     bottom, right = (int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1)))
@@ -476,7 +508,8 @@ def ltwh2xywh(x):

 def xyxyxyxy2xywhr(corners):
     """
-    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
+    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
+    expected in degrees from 0 to 90.

     Args:
         corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
@@ -484,61 +517,46 @@ def xyxyxyxy2xywhr(corners):
     Returns:
         (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
     """
-    is_numpy = isinstance(corners, np.ndarray)
-    atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)
-
-    x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
-    cx = (x1 + x3) / 2
-    cy = (y1 + y3) / 2
-    dx21 = x2 - x1
-    dy21 = y2 - y1
-
-    w = sqrt(dx21 ** 2 + dy21 ** 2)
-    h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
-
-    rotation = atan2(-dy21, dx21)
-    rotation *= 180.0 / math.pi  # radians to degrees
-
-    return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
+    is_torch = isinstance(corners, torch.Tensor)
+    points = corners.cpu().numpy() if is_torch else corners
+    points = points.reshape(len(corners), -1, 2)
+    rboxes = []
+    for pts in points:
+        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
+        # especially some objects are cut off by augmentations in dataloader.
+        (x, y), (w, h), angle = cv2.minAreaRect(pts)
+        rboxes.append([x, y, w, h, angle / 180 * np.pi])
+    rboxes = torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) if is_torch else np.asarray(
+        rboxes, dtype=points.dtype)
+    return rboxes


 def xywhr2xyxyxyxy(center):
     """
-    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
+    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
+    be in degrees from 0 to 90.

     Args:
-        center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
+        center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).

     Returns:
-        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
+        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
     """
     is_numpy = isinstance(center, np.ndarray)
     cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)

-    cx, cy, w, h, rotation = center.T
-    rotation *= math.pi / 180.0  # degrees to radians
-
-    dx = w / 2
-    dy = h / 2
-
-    cos_rot = cos(rotation)
-    sin_rot = sin(rotation)
-    dx_cos_rot = dx * cos_rot
-    dx_sin_rot = dx * sin_rot
-    dy_cos_rot = dy * cos_rot
-    dy_sin_rot = dy * sin_rot
-
-    x1 = cx - dx_cos_rot - dy_sin_rot
-    y1 = cy + dx_sin_rot - dy_cos_rot
-    x2 = cx + dx_cos_rot - dy_sin_rot
-    y2 = cy - dx_sin_rot - dy_cos_rot
-    x3 = cx + dx_cos_rot + dy_sin_rot
-    y3 = cy - dx_sin_rot + dy_cos_rot
-    x4 = cx - dx_cos_rot + dy_sin_rot
-    y4 = cy + dx_sin_rot + dy_cos_rot
-
-    return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
-        (x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
+    ctr = center[..., :2]
+    w, h, angle = (center[..., i:i + 1] for i in range(2, 5))
+    cos_value, sin_value = cos(angle), sin(angle)
+    vec1 = [w / 2 * cos_value, w / 2 * sin_value]
+    vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
+    vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
+    vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
+    pt1 = ctr + vec1 + vec2
+    pt2 = ctr + vec1 - vec2
+    pt3 = ctr - vec1 - vec2
+    pt4 = ctr - vec1 + vec2
+    return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)


 def ltwh2xyxy(x):
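A round-trip sketch of the two rewritten converters, treating angles as radians (which is what the arithmetic in both functions uses at this snapshot, despite the degree wording in the docstrings). Note that xywhr2xyxyxyxy now returns (n, 4, 2) corner points, and that xyxyxyxy2xywhr goes through cv2.minAreaRect, so width/height/angle may come back re-parameterized:

import torch
from ultralytics.utils.ops import xywhr2xyxyxyxy, xyxyxyxy2xywhr

rboxes = torch.tensor([[100.0, 100.0, 50.0, 20.0, 0.3]])  # xywhr
corners = xywhr2xyxyxyxy(rboxes)                          # (1, 4, 2) corner points
recovered = xyxyxyxy2xywhr(corners.reshape(1, 8))         # back to (1, 5) via cv2.minAreaRect
print(corners.shape, recovered)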
ultralytics/utils/plotting.py
@@ -100,25 +100,35 @@ class Annotator:
         self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
         self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]

-    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
+    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
         """Add one xyxy box to image with label."""
         if isinstance(box, torch.Tensor):
             box = box.tolist()
         if self.pil or not is_ascii(label):
-            self.draw.rectangle(box, width=self.lw, outline=color)  # box
+            if rotated:
+                p1 = box[0]
+                # NOTE: PIL-version polygon needs tuple type.
+                self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color)
+            else:
+                p1 = (box[0], box[1])
+                self.draw.rectangle(box, width=self.lw, outline=color)  # box
             if label:
                 w, h = self.font.getsize(label)  # text width, height
-                outside = box[1] - h >= 0  # label fits outside box
+                outside = p1[1] - h >= 0  # label fits outside box
                 self.draw.rectangle(
-                    (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
-                     box[1] + 1 if outside else box[1] + h + 1),
+                    (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
                     fill=color,
                 )
                 # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
-                self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
+                self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
         else:  # cv2
-            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
-            cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
+            if rotated:
+                p1 = [int(b) for b in box[0]]
+                # NOTE: cv2-version polylines needs np.asarray type.
+                cv2.polylines(self.im, [np.asarray(box, dtype=np.int)], True, color, self.lw)
+            else:
+                p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+                cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
             if label:
                 w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
                 outside = p1[1] - h >= 3
@@ -563,6 +573,7 @@ def plot_images(images,
                 batch_idx,
                 cls,
                 bboxes=np.zeros(0, dtype=np.float32),
+                confs=None,
                 masks=np.zeros(0, dtype=np.uint8),
                 kpts=np.zeros((0, 51), dtype=np.float32),
                 paths=None,
@@ -618,27 +629,29 @@ def plot_images(images,
         if len(cls) > 0:
             idx = batch_idx == i
             classes = cls[idx].astype('int')
+            labels = confs is None

             if len(bboxes):
-                boxes = ops.xywh2xyxy(bboxes[idx, :4]).T
-                labels = bboxes.shape[1] == 4  # labels if no conf column
-                conf = None if labels else bboxes[idx, 4]  # check for confidence presence (label vs pred)
-
-                if boxes.shape[1]:
-                    if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
-                        boxes[[0, 2]] *= w  # scale to pixels
-                        boxes[[1, 3]] *= h
+                boxes = bboxes[idx]
+                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
+                if len(boxes):
+                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
+                        boxes[:, [0, 2]] *= w  # scale to pixels
+                        boxes[:, [1, 3]] *= h
                     elif scale < 1:  # absolute coords need scale if image scales
-                        boxes *= scale
-                boxes[[0, 2]] += x
-                boxes[[1, 3]] += y
-                for j, box in enumerate(boxes.T.tolist()):
+                        boxes[:, :4] *= scale
+                boxes[:, 0] += x
+                boxes[:, 1] += y
+                is_obb = boxes.shape[-1] == 5  # xywhr
+                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
+                for j, box in enumerate(boxes.astype(np.int64).tolist()):
                     c = classes[j]
                     color = colors(c)
                     c = names.get(c, c) if names else c
                     if labels or conf[j] > 0.25:  # 0.25 conf thresh
                         label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
-                        annotator.box_label(box, label, color=color)
+                        annotator.box_label(box, label, color=color, rotated=is_obb)

             elif len(classes):
                 for c in classes:
                     color = colors(c)
@@ -847,7 +860,18 @@ def output_to_target(output, max_det=300):
             j = torch.full((conf.shape[0], 1), i)
             targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
         targets = torch.cat(targets, 0).numpy()
-    return targets[:, 0], targets[:, 1], targets[:, 2:]
+    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
+
+
+def output_to_rotated_target(output, max_det=300):
+    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
+    targets = []
+    for i, o in enumerate(output):
+        box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
+        j = torch.full((conf.shape[0], 1), i)
+        targets.append(torch.cat((j, cls, box, angle, conf), 1))
+    targets = torch.cat(targets, 0).numpy()
+    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]


 def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
ultralytics/utils/tal.py
@@ -4,59 +4,12 @@ import torch
 import torch.nn as nn

 from .checks import check_version
-from .metrics import bbox_iou
+from .metrics import bbox_iou, probiou
+from .ops import xywhr2xyxyxyxy

 TORCH_1_10 = check_version(torch.__version__, '1.10.0')


-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-    """
-    Select the positive anchor center in gt.
-
-    Args:
-        xy_centers (Tensor): shape(h*w, 2)
-        gt_bboxes (Tensor): shape(b, n_boxes, 4)
-
-    Returns:
-        (Tensor): shape(b, n_boxes, h*w)
-    """
-    n_anchors = xy_centers.shape[0]
-    bs, n_boxes, _ = gt_bboxes.shape
-    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
-    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
-    # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
-    return bbox_deltas.amin(3).gt_(eps)
-
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-    """
-    If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
-
-    Args:
-        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
-        overlaps (Tensor): shape(b, n_max_boxes, h*w)
-
-    Returns:
-        target_gt_idx (Tensor): shape(b, h*w)
-        fg_mask (Tensor): shape(b, h*w)
-        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
-    """
-    # (b, n_max_boxes, h*w) -> (b, h*w)
-    fg_mask = mask_pos.sum(-2)
-    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
-        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
-        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
-
-        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
-        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
-        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
-        fg_mask = mask_pos.sum(-2)
-    # Find each grid serve which gt(index)
-    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
-    return target_gt_idx, fg_mask, mask_pos
-
-
 class TaskAlignedAssigner(nn.Module):
     """
     A task-aligned assigner for object detection.
@@ -115,7 +68,7 @@ class TaskAlignedAssigner(nn.Module):
         mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
                                                              mask_gt)

-        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
+        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

         # Assigned target
         target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
@@ -131,7 +84,7 @@ class TaskAlignedAssigner(nn.Module):

     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
         """Get in_gts mask, (b, max_num_obj, h*w)."""
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
         # Get anchor_align metric, (b, max_num_obj, h*w)
         align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
         # Get topk_metric mask, (b, max_num_obj, h*w)
@@ -157,11 +110,15 @@ class TaskAlignedAssigner(nn.Module):
         # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
         pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
         gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
-        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

         align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
         return align_metric, overlaps

+    def iou_calculation(self, gt_bboxes, pd_bboxes):
+        """IoU calculation for horizontal bounding boxes."""
+        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
     def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
         """
         Select the top-k candidates based on the given metrics.
@@ -229,7 +186,7 @@ class TaskAlignedAssigner(nn.Module):
         target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

         # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
-        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

         # Assigned target scores
         target_labels.clamp_(0)
@@ -245,6 +202,89 @@ class TaskAlignedAssigner(nn.Module):

         return target_labels, target_bboxes, target_scores

+    @staticmethod
+    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+        """
+        Select the positive anchor center in gt.
+
+        Args:
+            xy_centers (Tensor): shape(h*w, 2)
+            gt_bboxes (Tensor): shape(b, n_boxes, 4)
+
+        Returns:
+            (Tensor): shape(b, n_boxes, h*w)
+        """
+        n_anchors = xy_centers.shape[0]
+        bs, n_boxes, _ = gt_bboxes.shape
+        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
+        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
+        # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
+        return bbox_deltas.amin(3).gt_(eps)
+
+    @staticmethod
+    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+        """
+        If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
+
+        Args:
+            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
+            overlaps (Tensor): shape(b, n_max_boxes, h*w)
+
+        Returns:
+            target_gt_idx (Tensor): shape(b, h*w)
+            fg_mask (Tensor): shape(b, h*w)
+            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
+        """
+        # (b, n_max_boxes, h*w) -> (b, h*w)
+        fg_mask = mask_pos.sum(-2)
+        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+            fg_mask = mask_pos.sum(-2)
+        # Find each grid serve which gt(index)
+        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+        return target_gt_idx, fg_mask, mask_pos
+
+
+class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
+
+    def iou_calculation(self, gt_bboxes, pd_bboxes):
+        """IoU calculation for rotated bounding boxes."""
+        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
+
+    @staticmethod
+    def select_candidates_in_gts(xy_centers, gt_bboxes):
+        """
+        Select the positive anchor center in gt for rotated bounding boxes.
+
+        Args:
+            xy_centers (Tensor): shape(h*w, 2)
+            gt_bboxes (Tensor): shape(b, n_boxes, 5)
+
+        Returns:
+            (Tensor): shape(b, n_boxes, h*w)
+        """
+        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
+        corners = xywhr2xyxyxyxy(gt_bboxes)
+        # (b, n_boxes, 1, 2)
+        a, b, _, d = corners.split(1, dim=-2)
+        ab = b - a
+        ad = d - a
+
+        # (b, n_boxes, h*w, 2)
+        ap = xy_centers - a
+        norm_ab = (ab * ab).sum(dim=-1)
+        norm_ad = (ad * ad).sum(dim=-1)
+        ap_dot_ab = (ap * ab).sum(dim=-1)
+        ap_dot_ad = (ap * ad).sum(dim=-1)
+        is_in_box = (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)
+        return is_in_box
+
+
 def make_anchors(feats, strides, grid_cell_offset=0.5):
     """Generate anchors from features."""
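The rotated candidate test above is the classic projection check: a point P lies inside rectangle ABCD iff 0 <= AP·AB <= |AB|^2 and 0 <= AP·AD <= |AD|^2. A standalone illustration with one box and two points (values are made up):

import torch

a = torch.tensor([0.0, 0.0])   # corner A
b = torch.tensor([4.0, 2.0])   # corner B, so AB is one edge
d = torch.tensor([-1.0, 2.0])  # corner D, so AD is the perpendicular edge

def in_box(p):
    ab, ad, ap = b - a, d - a, p - a
    return bool((0 <= ap @ ab <= ab @ ab) and (0 <= ap @ ad <= ad @ ad))

print(in_box(torch.tensor([1.5, 1.5])))  # True: inside the rotated rectangle
print(in_box(torch.tensor([5.0, 5.0])))  # False: outside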
@@ -277,3 +317,23 @@ def bbox2dist(anchor_points, bbox, reg_max):
     """Transform bbox(xyxy) to dist(ltrb)."""
     x1y1, x2y2 = bbox.chunk(2, -1)
     return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)
+
+
+def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
+    """
+    Decode predicted object bounding box coordinates from anchor points and distribution.
+
+    Args:
+        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+        anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+
+    Returns:
+        (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
+    """
+    lt, rb = pred_dist.split(2, dim=dim)
+    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
+    # (bs, h*w, 1)
+    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
+    x, y = xf * cos - yf * sin, xf * sin + yf * cos
+    xy = torch.cat([x, y], dim=dim) + anchor_points
+    return torch.cat([xy, lt + rb], dim=dim)
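To see the decoding in action, a tiny sketch with a single anchor and zero rotation (assumes ultralytics 8.0.235+): with lt=(2, 1) and rb=(4, 3), the center lands at anchor + ((r-l)/2, (b-t)/2) = (11, 11) and the size is (l+r, t+b) = (6, 4).

import torch
from ultralytics.utils.tal import dist2rbox

anchor_points = torch.tensor([[10.0, 10.0]])        # (h*w, 2)
pred_dist = torch.tensor([[[2.0, 1.0, 4.0, 3.0]]])  # (bs, h*w, 4) as (left, top, right, bottom)
pred_angle = torch.zeros(1, 1, 1)                   # zero rotation

print(dist2rbox(pred_dist, pred_angle, anchor_points))  # tensor([[[11., 11., 6., 4.]]])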