ultralytics 8.2.85 YOLOv10 max_det arg fix (#15917)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Francesco Mattioli 2024-09-01 02:04:22 +02:00 committed by GitHub
parent 0546c08102
commit ea13dc6208
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 5 additions and 4 deletions

View file

@ -144,12 +144,12 @@ class Detect(nn.Module):
(torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
dimension format [x, y, w, h, max_class_prob, class_index].
"""
batch_size, anchors, predictions = preds.shape # i.e. shape(16,8400,84)
batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
boxes, scores = preds.split([4, nc], dim=-1)
index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
scores, index = scores.flatten(1).topk(max_det)
scores, index = scores.flatten(1).topk(min(max_det, anchors))
i = torch.arange(batch_size)[..., None] # batch indices
return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)