ultralytics 8.3.76 fix dynamic batch inference with NMS export (#19249)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
0f81777af5
commit
e16593336b
2 changed files with 11 additions and 11 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||||
|
|
||||||
__version__ = "8.3.75"
|
__version__ = "8.3.76"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1560,20 +1560,20 @@ class NMSModel(torch.nn.Module):
|
||||||
|
|
||||||
preds = self.model(x)
|
preds = self.model(x)
|
||||||
pred = preds[0] if isinstance(preds, tuple) else preds
|
pred = preds[0] if isinstance(preds, tuple) else preds
|
||||||
|
kwargs = dict(device=pred.device, dtype=pred.dtype)
|
||||||
|
bs = pred.shape[0]
|
||||||
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||||
extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose
|
extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose
|
||||||
|
if self.args.dynamic and self.args.batch > 1: # batch size needs to always be same due to loop unroll
|
||||||
|
pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs)
|
||||||
|
pred = torch.cat((pred, pad))
|
||||||
boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
|
boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
|
||||||
scores, classes = scores.max(dim=-1)
|
scores, classes = scores.max(dim=-1)
|
||||||
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
|
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
|
||||||
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
||||||
out = torch.zeros(
|
out = torch.zeros(bs, self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs)
|
||||||
boxes.shape[0],
|
for i in range(bs):
|
||||||
self.args.max_det,
|
box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i]
|
||||||
boxes.shape[-1] + 2 + extra_shape,
|
|
||||||
device=boxes.device,
|
|
||||||
dtype=boxes.dtype,
|
|
||||||
)
|
|
||||||
for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
|
|
||||||
mask = score > self.args.conf
|
mask = score > self.args.conf
|
||||||
if self.is_tf:
|
if self.is_tf:
|
||||||
# TFLite GatherND error if mask is empty
|
# TFLite GatherND error if mask is empty
|
||||||
|
|
@ -1593,7 +1593,7 @@ class NMSModel(torch.nn.Module):
|
||||||
if self.args.format == "tflite": # TFLite is already normalized
|
if self.args.format == "tflite": # TFLite is already normalized
|
||||||
nmsbox *= multiplier
|
nmsbox *= multiplier
|
||||||
else:
|
else:
|
||||||
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
|
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max()
|
||||||
if not self.args.agnostic_nms: # class-specific NMS
|
if not self.args.agnostic_nms: # class-specific NMS
|
||||||
end = 2 if self.obb else 4
|
end = 2 if self.obb else 4
|
||||||
# fully explicit expansion otherwise reshape error
|
# fully explicit expansion otherwise reshape error
|
||||||
|
|
@ -1624,4 +1624,4 @@ class NMSModel(torch.nn.Module):
|
||||||
# Zero-pad to max_det size to avoid reshape error
|
# Zero-pad to max_det size to avoid reshape error
|
||||||
pad = (0, 0, 0, self.args.max_det - dets.shape[0])
|
pad = (0, 0, 0, self.args.max_det - dets.shape[0])
|
||||||
out[i] = torch.nn.functional.pad(dets, pad)
|
out[i] = torch.nn.functional.pad(dets, pad)
|
||||||
return (out, preds[1]) if self.model.task == "segment" else out
|
return (out[:bs], preds[1]) if self.model.task == "segment" else out[:bs]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue