ultralytics 8.2.85 YOLOv10 max_det arg fix (#15917)
Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
0546c08102
commit
ea13dc6208
4 changed files with 5 additions and 4 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.84"
|
||||
__version__ = "8.2.85"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -249,6 +249,7 @@ class Exporter:
|
|||
m.dynamic = self.args.dynamic
|
||||
m.export = True
|
||||
m.format = self.args.format
|
||||
m.max_det = self.args.max_det
|
||||
elif isinstance(m, C2f) and not is_tf_format:
|
||||
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
|
||||
m.forward = m.forward_split
|
||||
|
|
|
|||
|
|
@ -144,12 +144,12 @@ class Detect(nn.Module):
|
|||
(torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
|
||||
dimension format [x, y, w, h, max_class_prob, class_index].
|
||||
"""
|
||||
batch_size, anchors, predictions = preds.shape # i.e. shape(16,8400,84)
|
||||
batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
|
||||
boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
|
||||
scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
|
||||
scores, index = scores.flatten(1).topk(max_det)
|
||||
scores, index = scores.flatten(1).topk(min(max_det, anchors))
|
||||
i = torch.arange(batch_size)[..., None] # batch indices
|
||||
return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
|
||||
|
||||
|
|
|
|||
|
|
@ -218,7 +218,7 @@ def non_max_suppression(
|
|||
classes = torch.tensor(classes, device=prediction.device)
|
||||
|
||||
if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6)
|
||||
output = [pred[pred[:, 4] > conf_thres] for pred in prediction]
|
||||
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
|
||||
if classes is not None:
|
||||
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
|
||||
return output
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue