Model coverage cleanup (#4585)
This commit is contained in:
parent
c635418a27
commit
deac7575b1
12 changed files with 132 additions and 175 deletions
|
|
@ -82,8 +82,7 @@ class DETRLoss(nn.Module):
|
|||
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
||||
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
||||
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
|
||||
loss = {k: v.squeeze() for k, v in loss.items()}
|
||||
return loss
|
||||
return {k: v.squeeze() for k, v in loss.items()}
|
||||
|
||||
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||
|
|
@ -105,7 +104,8 @@ class DETRLoss(nn.Module):
|
|||
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||
return loss
|
||||
|
||||
def _dice_loss(self, inputs, targets, num_gts):
|
||||
@staticmethod
|
||||
def _dice_loss(inputs, targets, num_gts):
|
||||
inputs = F.sigmoid(inputs)
|
||||
inputs = inputs.flatten(1)
|
||||
targets = targets.flatten(1)
|
||||
|
|
@ -163,7 +163,8 @@ class DETRLoss(nn.Module):
|
|||
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
||||
return loss
|
||||
|
||||
def _get_index(self, match_indices):
|
||||
@staticmethod
|
||||
def _get_index(match_indices):
|
||||
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
||||
src_idx = torch.cat([src for (src, _) in match_indices])
|
||||
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
||||
|
|
@ -257,10 +258,10 @@ class RTDETRDetectionLoss(DETRLoss):
|
|||
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
|
||||
assert len(batch['gt_groups']) == len(dn_pos_idx)
|
||||
|
||||
# denoising match indices
|
||||
# Denoising match indices
|
||||
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
|
||||
|
||||
# compute denoising training loss
|
||||
# Compute denoising training loss
|
||||
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
|
||||
total_loss.update(dn_loss)
|
||||
else:
|
||||
|
|
@ -270,7 +271,8 @@ class RTDETRDetectionLoss(DETRLoss):
|
|||
|
||||
@staticmethod
|
||||
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
|
||||
"""Get the match indices for denoising.
|
||||
"""
|
||||
Get the match indices for denoising.
|
||||
|
||||
Args:
|
||||
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
|
||||
|
|
@ -279,7 +281,6 @@ class RTDETRDetectionLoss(DETRLoss):
|
|||
|
||||
Returns:
|
||||
dn_match_indices (List(tuple)): Matched indices.
|
||||
|
||||
"""
|
||||
dn_match_indices = []
|
||||
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue