Add RTDETR Trainer (#2745)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia 2023-06-17 17:16:18 +05:30 committed by GitHub
parent 03bce07848
commit a0ba8ef5f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 989 additions and 314 deletions

View file

@ -24,10 +24,34 @@ class VarifocalLoss(nn.Module):
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
weight).sum()
weight).mean(1).sum()
return loss
# Losses
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
super().__init__()
def forward(self, pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = pred.sigmoid() # prob from logits
p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
modulating_factor = (1.0 - p_t) ** gamma
loss *= modulating_factor
if alpha > 0:
alpha_factor = label * alpha + (1 - label) * (1 - alpha)
loss *= alpha_factor
return loss.mean(1).sum()
class BboxLoss(nn.Module):
def __init__(self, reg_max, use_dfl=False):