Model coverage cleanup (#4585)

This commit is contained in:
Glenn Jocher 2023-08-27 04:19:41 +02:00 committed by GitHub
parent c635418a27
commit deac7575b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 132 additions and 175 deletions

View file

@ -82,8 +82,7 @@ class DETRLoss(nn.Module):
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
loss = {k: v.squeeze() for k, v in loss.items()}
return loss
return {k: v.squeeze() for k, v in loss.items()}
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
@ -105,7 +104,8 @@ class DETRLoss(nn.Module):
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
return loss
def _dice_loss(self, inputs, targets, num_gts):
@staticmethod
def _dice_loss(inputs, targets, num_gts):
inputs = F.sigmoid(inputs)
inputs = inputs.flatten(1)
targets = targets.flatten(1)
@ -163,7 +163,8 @@ class DETRLoss(nn.Module):
# loss[f'loss_dice_aux{postfix}'] = loss[4]
return loss
def _get_index(self, match_indices):
@staticmethod
def _get_index(match_indices):
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
src_idx = torch.cat([src for (src, _) in match_indices])
dst_idx = torch.cat([dst for (_, dst) in match_indices])
@ -257,10 +258,10 @@ class RTDETRDetectionLoss(DETRLoss):
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
assert len(batch['gt_groups']) == len(dn_pos_idx)
# denoising match indices
# Denoising match indices
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
# compute denoising training loss
# Compute denoising training loss
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
total_loss.update(dn_loss)
else:
@ -270,7 +271,8 @@ class RTDETRDetectionLoss(DETRLoss):
@staticmethod
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
"""Get the match indices for denoising.
"""
Get the match indices for denoising.
Args:
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
@ -279,7 +281,6 @@ class RTDETRDetectionLoss(DETRLoss):
Returns:
dn_match_indices (List(tuple)): Matched indices.
"""
dn_match_indices = []
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)