Add RTDETR Trainer (#2745)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia 2023-06-17 17:16:18 +05:30 committed by GitHub
parent 03bce07848
commit a0ba8ef5f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 989 additions and 314 deletions

View file

@ -210,7 +210,9 @@ class BaseModel(nn.Module):
"""
if not hasattr(self, 'criterion'):
self.criterion = self.init_criterion()
return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
preds = self.forward(batch['img']) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self):
    """
    Build the loss criterion for this model; must be overridden by task-specific models.

    `loss()` calls this hook once, the first time it runs, and stores the result on `self.criterion`.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # The hook subclasses must provide is init_criterion(); name it in the message so the error is actionable.
    raise NotImplementedError('init_criterion() needs to be implemented by task heads')
@ -410,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
"""Compute the classification loss between predictions and true labels."""
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True)
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
def loss(self, batch, preds=None):
if not hasattr(self, 'criterion'):
@ -420,31 +422,36 @@ class RTDETRDetectionModel(DetectionModel):
# NOTE: preprocess gt_bbox and gt_labels to list.
bs = len(img)
batch_idx = batch['batch_idx']
gt_bbox, gt_class = [], []
gt_groups = []
for i in range(bs):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long))
targets = {'cls': gt_class, 'bboxes': gt_bbox}
gt_groups.append((batch_idx == i).sum().item())
targets = {
'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
'bboxes': batch['bboxes'].to(device=img.device),
'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
'gt_groups': gt_groups}
preds = self.predict(img, batch=targets) if preds is None else preds
dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds
# NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported.
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
if dn_meta is None:
return 0, torch.zeros(3, device=dec_out_bboxes.device)
dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2)
dn_bboxes, dn_scores = None, None
else:
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits])
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
loss = self.criterion((out_bboxes, out_logits),
loss = self.criterion((dec_bboxes, dec_scores),
targets,
dn_out_bboxes=dn_out_bboxes,
dn_out_logits=dn_out_logits,
dn_bboxes=dn_bboxes,
dn_scores=dn_scores,
dn_meta=dn_meta)
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']])
# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
device=img.device)
def predict(self, x, profile=False, visualize=False, batch=None):
def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
"""
Perform a forward pass through the network.