diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index df136c14..2cf3d26c 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -1556,8 +1556,8 @@ class NMSModel(torch.nn.Module): preds = self.model(x) pred = preds[0] if isinstance(preds, tuple) else preds pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) - extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose - boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2) + extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose + boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2) scores, classes = scores.max(dim=-1) self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).