Add RTDETR Trainer (#2745)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
03bce07848
commit
a0ba8ef5f0
23 changed files with 989 additions and 314 deletions
|
|
@ -210,7 +210,9 @@ class BaseModel(nn.Module):
|
|||
"""
|
||||
if not hasattr(self, 'criterion'):
|
||||
self.criterion = self.init_criterion()
|
||||
return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
|
||||
|
||||
preds = self.forward(batch['img']) if preds is None else preds
|
||||
return self.criterion(preds, batch)
|
||||
|
||||
def init_criterion(self):
|
||||
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
|
||||
|
|
@ -410,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
"""Compute the classification loss between predictions and true labels."""
|
||||
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
|
||||
|
||||
return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True)
|
||||
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
if not hasattr(self, 'criterion'):
|
||||
|
|
@ -420,31 +422,36 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
# NOTE: preprocess gt_bbox and gt_labels to list.
|
||||
bs = len(img)
|
||||
batch_idx = batch['batch_idx']
|
||||
gt_bbox, gt_class = [], []
|
||||
gt_groups = []
|
||||
for i in range(bs):
|
||||
gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device))
|
||||
gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long))
|
||||
targets = {'cls': gt_class, 'bboxes': gt_bbox}
|
||||
gt_groups.append((batch_idx == i).sum().item())
|
||||
targets = {
|
||||
'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
|
||||
'bboxes': batch['bboxes'].to(device=img.device),
|
||||
'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
|
||||
'gt_groups': gt_groups}
|
||||
|
||||
preds = self.predict(img, batch=targets) if preds is None else preds
|
||||
dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds
|
||||
# NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported.
|
||||
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
|
||||
if dn_meta is None:
|
||||
return 0, torch.zeros(3, device=dec_out_bboxes.device)
|
||||
dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
|
||||
dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2)
|
||||
dn_bboxes, dn_scores = None, None
|
||||
else:
|
||||
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
|
||||
dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
|
||||
|
||||
out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
|
||||
out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits])
|
||||
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
|
||||
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
|
||||
|
||||
loss = self.criterion((out_bboxes, out_logits),
|
||||
loss = self.criterion((dec_bboxes, dec_scores),
|
||||
targets,
|
||||
dn_out_bboxes=dn_out_bboxes,
|
||||
dn_out_logits=dn_out_logits,
|
||||
dn_bboxes=dn_bboxes,
|
||||
dn_scores=dn_scores,
|
||||
dn_meta=dn_meta)
|
||||
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']])
|
||||
# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
|
||||
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
|
||||
device=img.device)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, batch=None):
|
||||
def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue