Strip dfl_loss from BboxLoss (#14041)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
f533d77611
commit
f5ccddf5df
3 changed files with 51 additions and 38 deletions
|
|
@ -61,39 +61,22 @@ class FocalLoss(nn.Module):
|
|||
return loss.mean(1).sum()
|
||||
|
||||
|
||||
class BboxLoss(nn.Module):
|
||||
"""Criterion class for computing training losses during training."""
|
||||
class DFLoss(nn.Module):
|
||||
"""Criterion class for computing DFL losses during training."""
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
def __init__(self, reg_max=16) -> None:
|
||||
"""Initialize the DFL module."""
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
self.use_dfl = use_dfl
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.use_dfl:
|
||||
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
|
||||
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
@staticmethod
|
||||
def _df_loss(pred_dist, target):
|
||||
def __call__(self, pred_dist, target):
|
||||
"""
|
||||
Return sum of left and right DFL losses.
|
||||
|
||||
Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
|
||||
https://ieeexplore.ieee.org/document/9792391
|
||||
"""
|
||||
target = target.clamp_(0, self.reg_max - 1 - 0.01)
|
||||
tl = target.long() # target left
|
||||
tr = tl + 1 # target right
|
||||
wl = tr - target # weight left
|
||||
|
|
@ -104,12 +87,37 @@ class BboxLoss(nn.Module):
|
|||
).mean(-1, keepdim=True)
|
||||
|
||||
|
||||
class BboxLoss(nn.Module):
|
||||
"""Criterion class for computing training losses during training."""
|
||||
|
||||
def __init__(self, reg_max=16):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__()
|
||||
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.dfl_loss:
|
||||
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
|
||||
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
|
||||
class RotatedBboxLoss(BboxLoss):
|
||||
"""Criterion class for computing training losses during training."""
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
def __init__(self, reg_max):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__(reg_max, use_dfl)
|
||||
super().__init__(reg_max)
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
|
|
@ -118,9 +126,9 @@ class RotatedBboxLoss(BboxLoss):
|
|||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.use_dfl:
|
||||
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
|
||||
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
|
||||
if self.dfl_loss:
|
||||
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
|
||||
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
|
@ -165,18 +173,19 @@ class v8DetectionLoss:
|
|||
self.use_dfl = m.reg_max > 1
|
||||
|
||||
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
|
||||
self.bbox_loss = BboxLoss(m.reg_max).to(device)
|
||||
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
||||
|
||||
def preprocess(self, targets, batch_size, scale_tensor):
|
||||
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
|
||||
if targets.shape[0] == 0:
|
||||
out = torch.zeros(batch_size, 0, 5, device=self.device)
|
||||
nl, ne = targets.shape
|
||||
if nl == 0:
|
||||
out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
|
||||
else:
|
||||
i = targets[:, 0] # image index
|
||||
_, counts = i.unique(return_counts=True)
|
||||
counts = counts.to(dtype=torch.int32)
|
||||
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
|
||||
out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
|
||||
for j in range(batch_size):
|
||||
matches = i == j
|
||||
n = matches.sum()
|
||||
|
|
@ -592,7 +601,7 @@ class v8ClassificationLoss:
|
|||
|
||||
def __call__(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
|
||||
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
||||
|
||||
|
|
@ -606,7 +615,7 @@ class v8OBBLoss(v8DetectionLoss):
|
|||
"""
|
||||
super().__init__(model)
|
||||
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
|
||||
self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
|
||||
|
||||
def preprocess(self, targets, batch_size, scale_tensor):
|
||||
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue