ultralytics 8.0.167 Tuner updates and HUB Pose and Classify fixes (#4656)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8596ee241f
commit
d2cf7acce0
21 changed files with 174 additions and 144 deletions
|
|
@ -84,35 +84,36 @@ class DETRLoss(nn.Module):
|
|||
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
|
||||
return {k: v.squeeze() for k, v in loss.items()}
|
||||
|
||||
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||
name_mask = f'loss_mask{postfix}'
|
||||
name_dice = f'loss_dice{postfix}'
|
||||
# This function is for future RT-DETR Segment models
|
||||
# def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||
# # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||
# name_mask = f'loss_mask{postfix}'
|
||||
# name_dice = f'loss_dice{postfix}'
|
||||
#
|
||||
# loss = {}
|
||||
# if sum(len(a) for a in gt_mask) == 0:
|
||||
# loss[name_mask] = torch.tensor(0., device=self.device)
|
||||
# loss[name_dice] = torch.tensor(0., device=self.device)
|
||||
# return loss
|
||||
#
|
||||
# num_gts = len(gt_mask)
|
||||
# src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
|
||||
# src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
|
||||
# # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
|
||||
# loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
|
||||
# torch.tensor([num_gts], dtype=torch.float32))
|
||||
# loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||
# return loss
|
||||
|
||||
loss = {}
|
||||
if sum(len(a) for a in gt_mask) == 0:
|
||||
loss[name_mask] = torch.tensor(0., device=self.device)
|
||||
loss[name_dice] = torch.tensor(0., device=self.device)
|
||||
return loss
|
||||
|
||||
num_gts = len(gt_mask)
|
||||
src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
|
||||
src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
|
||||
# TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
|
||||
loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
|
||||
torch.tensor([num_gts], dtype=torch.float32))
|
||||
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def _dice_loss(inputs, targets, num_gts):
|
||||
inputs = F.sigmoid(inputs)
|
||||
inputs = inputs.flatten(1)
|
||||
targets = targets.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_gts
|
||||
# This function is for future RT-DETR Segment models
|
||||
# @staticmethod
|
||||
# def _dice_loss(inputs, targets, num_gts):
|
||||
# inputs = F.sigmoid(inputs).flatten(1)
|
||||
# targets = targets.flatten(1)
|
||||
# numerator = 2 * (inputs * targets).sum(1)
|
||||
# denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
# loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
# return loss.sum() / num_gts
|
||||
|
||||
def _get_loss_aux(self,
|
||||
pred_bboxes,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue