Ruff Docstring formatting (#15793)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-25 04:27:55 +08:00 committed by GitHub
parent d27664216b
commit 776ca86369
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 241 additions and 309 deletions

View file

@ -140,7 +140,6 @@ class TaskAlignedAssigner(nn.Module):
Returns:
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
# (b, max_num_obj, topk)
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
if topk_mask is None:
@ -184,7 +183,6 @@ class TaskAlignedAssigner(nn.Module):
for positive anchor points, where num_classes is the number
of object classes.
"""
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
@ -212,14 +210,19 @@ class TaskAlignedAssigner(nn.Module):
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
"""
Select the positive anchor center in gt.
Select positive anchor centers within ground truth bounding boxes.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 4)
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
Returns:
(Tensor): shape(b, n_boxes, h*w)
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
Note:
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
Bounding box format: [x_min, y_min, x_max, y_max].
"""
n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape
@ -231,18 +234,22 @@ class TaskAlignedAssigner(nn.Module):
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
"""
If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
Select anchor boxes with highest IoU when assigned to multiple ground truths.
Args:
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
overlaps (Tensor): shape(b, n_max_boxes, h*w)
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
n_max_boxes (int): Maximum number of ground truth boxes.
Returns:
target_gt_idx (Tensor): shape(b, h*w)
fg_mask (Tensor): shape(b, h*w)
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
Note:
b: batch size, h: height, w: width.
"""
# (b, n_max_boxes, h*w) -> (b, h*w)
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
@ -328,14 +335,16 @@ def bbox2dist(anchor_points, bbox, reg_max):
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
"""
Decode predicted object bounding box coordinates from anchor points and distribution.
Decode predicted rotated bounding box coordinates from anchor points and distribution.
Args:
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
dim (int, optional): Dimension along which to split. Defaults to -1.
Returns:
(torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
(torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
"""
lt, rb = pred_dist.split(2, dim=dim)
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)