Fix missing nc attribute error on NMS export (#19083)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2025-02-06 10:43:26 +08:00 committed by GitHub
parent 39b037408f
commit 10d435c4fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1556,8 +1556,8 @@ class NMSModel(torch.nn.Module):
preds = self.model(x) preds = self.model(x)
pred = preds[0] if isinstance(preds, tuple) else preds pred = preds[0] if isinstance(preds, tuple) else preds
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose
boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2) boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
scores, classes = scores.max(dim=-1) scores, classes = scores.max(dim=-1)
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape). # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).