Ruff Docstring formatting (#15793)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-25 04:27:55 +08:00 committed by GitHub
parent d27664216b
commit 776ca86369
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 241 additions and 309 deletions

View file

@ -34,15 +34,19 @@ class DETRLoss(nn.Module):
self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
):
"""
DETR loss function.
Initialize DETR loss function with customizable components and gains.
Uses default loss_gain if not provided. Initializes HungarianMatcher with
preset cost gains. Supports auxiliary losses and various loss types.
Args:
nc (int): The number of classes.
loss_gain (dict): The coefficient of loss.
aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
use_vfl (bool): Use VarifocalLoss or not.
use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
uni_match_ind (int): The fixed indices of a layer.
nc (int): Number of classes.
loss_gain (dict): Coefficients for different loss components.
aux_loss (bool): Use auxiliary losses from each decoder layer.
use_fl (bool): Use FocalLoss.
use_vfl (bool): Use VarifocalLoss.
use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
uni_match_ind (int): Index of fixed layer for uni_match.
"""
super().__init__()
@ -82,9 +86,7 @@ class DETRLoss(nn.Module):
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
boxes.
"""
"""Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f"loss_bbox{postfix}"
name_giou = f"loss_giou{postfix}"
@ -250,14 +252,24 @@ class DETRLoss(nn.Module):
def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
"""
Calculate loss for predicted bounding boxes and scores.
Args:
pred_bboxes (torch.Tensor): [l, b, query, 4]
pred_scores (torch.Tensor): [l, b, query, num_classes]
batch (dict): A dict includes:
gt_cls (torch.Tensor) with shape [num_gts, ],
gt_bboxes (torch.Tensor): [num_gts, 4],
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
postfix (str): postfix of loss name.
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
batch (dict): Batch information containing:
cls (torch.Tensor): Ground truth classes, shape [num_gts].
bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
gt_groups (List[int]): Number of ground truths for each image in the batch.
postfix (str): Postfix for loss names.
**kwargs (Any): Additional arguments, may include 'match_indices'.
Returns:
(dict): Computed losses, including main and auxiliary (if enabled).
Note:
Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
self.aux_loss is True.
"""
self.device = pred_bboxes.device
match_indices = kwargs.get("match_indices", None)

View file

@ -32,9 +32,7 @@ class HungarianMatcher(nn.Module):
"""
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
"""Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
gamma factors.
"""
"""Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
super().__init__()
if cost_gain is None:
cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
@ -70,7 +68,6 @@ class HungarianMatcher(nn.Module):
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, nq, nc = pred_scores.shape
if sum(gt_groups) == 0:
@ -175,7 +172,6 @@ def get_cdn_group(
bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
is less than or equal to 0, the function returns None for all elements in the tuple.
"""
if (not training) or num_dn <= 0:
return None, None, None, None
gt_groups = batch["gt_groups"]