Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>

parent: c7aa83da31
commit: 7517667a33

90 changed files with 1396 additions and 497 deletions
ultralytics/models/utils/ops.py

@@ -11,8 +11,8 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
 class HungarianMatcher(nn.Module):
     """
-    A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in
-    an end-to-end fashion.
+    A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
+    end-to-end fashion.
 
     HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
     function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
 
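Note: the docstring above describes optimal one-to-one assignment between predicted and ground-truth boxes under a combined cost. A minimal sketch of that idea using scipy.optimize.linear_sum_assignment on a toy cost matrix; the numbers and the flat cost are illustrative, not the module's actual computation.

import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy example: 3 predictions vs 2 ground-truth boxes; each entry is a combined
# matching cost (classification + box terms folded into one number).
cost = np.array([
    [0.9, 0.2],  # prediction 0
    [0.1, 0.8],  # prediction 1
    [0.5, 0.6],  # prediction 2
])

# Hungarian algorithm: choose one prediction per ground truth at minimum total cost.
pred_idx, gt_idx = linear_sum_assignment(cost)
print(list(zip(pred_idx, gt_idx)))  # pairs prediction 0 with gt 1 and prediction 1 with gt 0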
@@ -32,6 +32,9 @@ class HungarianMatcher(nn.Module):
     """
 
     def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+        """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
+        gamma factors.
+        """
         super().__init__()
         if cost_gain is None:
             cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
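The new constructor docstring summarizes the arguments visible in this hunk. A hedged usage sketch; the import path is assumed from the module shown in this diff, and the keyword values simply restate the defaults above.

# Assumed import path for the ops module in this diff.
from ultralytics.models.utils.ops import HungarianMatcher

# Restates the defaults from the hunk above: focal loss enabled, no mask costs.
matcher = HungarianMatcher(cost_gain={'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1},
                           use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0)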
@@ -45,8 +48,8 @@ class HungarianMatcher(nn.Module):
     def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
         """
         Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
-        (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching
-        between predictions and ground truth based on these costs.
+        (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between
+        predictions and ground truth based on these costs.
 
         Args:
             pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
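The forward docstring names three cost terms (classification, L1 between boxes, GIoU between boxes). A self-contained sketch of how such terms can be weighted into one [num_preds, num_gts] cost matrix before the assignment step; the helper name is made up, the weights reuse the cost_gain defaults above, and the GIoU term is omitted for brevity.

import torch

def toy_match_cost(pred_boxes, pred_scores, gt_boxes, gt_cls, gain):
    # Classification cost: negative predicted probability of each ground-truth class.
    cost_class = -pred_scores[:, gt_cls]                                         # [num_preds, num_gts]
    # L1 cost between every predicted box and every ground-truth box (xywh).
    cost_bbox = (pred_boxes.unsqueeze(1) - gt_boxes.unsqueeze(0)).abs().sum(-1)  # [num_preds, num_gts]
    return gain['class'] * cost_class + gain['bbox'] * cost_bbox

pred_boxes = torch.rand(4, 4)                # 4 predicted boxes, normalized xywh
pred_scores = torch.rand(4, 80).softmax(-1)  # 80-class probabilities per prediction
gt_boxes = torch.rand(2, 4)                  # 2 ground-truth boxes
gt_cls = torch.tensor([3, 17])               # ground-truth class ids
cost = toy_match_cost(pred_boxes, pred_scores, gt_boxes, gt_cls, {'class': 1, 'bbox': 5})
print(cost.shape)  # torch.Size([4, 2])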
@@ -153,9 +156,9 @@ def get_cdn_group(batch,
                   box_noise_scale=1.0,
                   training=False):
     """
-    Get contrastive denoising training group. This function creates a contrastive denoising training group with
-    positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding
-    box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information.
+    Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
+    and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
+    and returns the modified labels, bounding boxes, attention mask and meta information.
 
     Args:
         batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
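The reflowed docstring describes two noising steps: perturbing class labels and perturbing box coordinates. A toy, standalone illustration of those steps; the tensors and the exact noise recipe here are illustrative, and only the cls_noise_ratio/box_noise_scale names come from the signature above.

import torch

cls_noise_ratio, box_noise_scale = 0.5, 1.0
gt_cls = torch.tensor([2, 5, 7, 1])   # ground-truth class labels
gt_xywh = torch.rand(4, 4)            # ground-truth boxes, normalized xywh

# Class noise: flip a random subset of labels to random classes.
flip = torch.rand(gt_cls.shape) < cls_noise_ratio * 0.5
noised_cls = torch.where(flip, torch.randint(0, 80, gt_cls.shape), gt_cls)

# Box noise: jitter each box proportionally to its own width and height, then clamp.
jitter = (torch.rand_like(gt_xywh) * 2 - 1) * box_noise_scale * 0.5
noised_xywh = (gt_xywh + jitter * gt_xywh[:, 2:].repeat(1, 2)).clip(0.0, 1.0)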
@@ -191,12 +194,12 @@ def get_cdn_group(batch,
     gt_bbox = batch['bboxes']  # bs*num, 4
     b_idx = batch['batch_idx']
 
-    # each group has positive and negative queries.
+    # Each group has positive and negative queries.
     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
     dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
     dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
 
-    # positive and negative mask
+    # Positive and negative mask
     # (bs*num*num_group, ), the second total_num*num_group part as negative samples
     neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
 
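A small worked example of the indexing on the last line of this hunk, assuming total_num counts the ground-truth boxes in the batch: the repeated queries are laid out positives first, then an equal-sized block of negatives.

import torch

total_num, num_group = 3, 2             # toy values, assumed
dn_queries = 2 * num_group * total_num  # 12 denoising queries in this toy case
neg_idx = torch.arange(total_num * num_group) + num_group * total_num
print(neg_idx.tolist())                 # [6, 7, 8, 9, 10, 11], i.e. the second half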
@@ -220,10 +223,9 @@ def get_cdn_group(batch,
         known_bbox += rand_part * diff
         known_bbox.clip_(min=0.0, max=1.0)
         dn_bbox = xyxy2xywh(known_bbox)
-        dn_bbox = inverse_sigmoid(dn_bbox)
+        dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid
 
-    # total denoising queries
-    num_dn = int(max_nums * 2 * num_group)
+    num_dn = int(max_nums * 2 * num_group)  # total denoising queries
     # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
     dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
     padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
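This hunk replaces a call to the local inverse_sigmoid helper (deleted in the final hunk below) with torch.logit, which computes log(x / (1 - x)) with built-in eps clamping. A quick check that the built-in matches a hand-written clamped logit:

import torch

x = torch.rand(5)
xc = x.clip(min=1e-6, max=1 - 1e-6)
manual = torch.log(xc / (1 - xc))
print(torch.allclose(manual, torch.logit(x, eps=1e-6)))  # True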
@@ -256,9 +258,3 @@ def get_cdn_group(batch,
 
     return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
         class_embed.device), dn_meta
-
-
-def inverse_sigmoid(x, eps=1e-6):
-    """Inverse sigmoid function."""
-    x = x.clip(min=0., max=1.)
-    return torch.log(x / (1 - x + eps) + eps)