Fix missing nc attribute error on NMS export (#19083)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
39b037408f
commit
10d435c4fd
1 changed files with 2 additions and 2 deletions
|
|
@ -1556,8 +1556,8 @@ class NMSModel(torch.nn.Module):
|
||||||
preds = self.model(x)
|
preds = self.model(x)
|
||||||
pred = preds[0] if isinstance(preds, tuple) else preds
|
pred = preds[0] if isinstance(preds, tuple) else preds
|
||||||
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||||
extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose
|
extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose
|
||||||
boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
|
boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
|
||||||
scores, classes = scores.max(dim=-1)
|
scores, classes = scores.max(dim=-1)
|
||||||
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
|
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
|
||||||
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue