Fix missing nc attribute error on NMS export (#19083)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
39b037408f
commit
10d435c4fd
1 changed files with 2 additions and 2 deletions
|
|
@ -1556,8 +1556,8 @@ class NMSModel(torch.nn.Module):
|
|||
preds = self.model(x)
|
||||
pred = preds[0] if isinstance(preds, tuple) else preds
|
||||
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||
extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose
|
||||
boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
|
||||
extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose
|
||||
boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
|
||||
scores, classes = scores.max(dim=-1)
|
||||
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
|
||||
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue