ultralytics 8.3.26 EdgeTPU Pose models fix (#17281)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
e8743f2ac9
commit
f4e7756bff
3 changed files with 19 additions and 4 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.3.25"
|
__version__ = "8.3.26"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -663,6 +663,9 @@ class AutoBackend(nn.Module):
|
||||||
else:
|
else:
|
||||||
x[:, [0, 2]] *= w
|
x[:, [0, 2]] *= w
|
||||||
x[:, [1, 3]] *= h
|
x[:, [1, 3]] *= h
|
||||||
|
if self.task == "pose":
|
||||||
|
x[:, 5::3] *= w
|
||||||
|
x[:, 6::3] *= h
|
||||||
y.append(x)
|
y.append(x)
|
||||||
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
|
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
|
||||||
if len(y) == 2: # segment with (det, proto) output order reversed
|
if len(y) == 2: # segment with (det, proto) output order reversed
|
||||||
|
|
|
||||||
|
|
@ -246,9 +246,21 @@ class Pose(Detect):
|
||||||
def kpts_decode(self, bs, kpts):
|
def kpts_decode(self, bs, kpts):
|
||||||
"""Decodes keypoints."""
|
"""Decodes keypoints."""
|
||||||
ndim = self.kpt_shape[1]
|
ndim = self.kpt_shape[1]
|
||||||
if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
|
if self.export:
|
||||||
y = kpts.view(bs, *self.kpt_shape, -1)
|
if self.format in {
|
||||||
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
|
"tflite",
|
||||||
|
"edgetpu",
|
||||||
|
}: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
|
||||||
|
# Precompute normalization factor to increase numerical stability
|
||||||
|
y = kpts.view(bs, *self.kpt_shape, -1)
|
||||||
|
grid_h, grid_w = self.shape[2], self.shape[3]
|
||||||
|
grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
|
||||||
|
norm = self.strides / (self.stride[0] * grid_size)
|
||||||
|
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
|
||||||
|
else:
|
||||||
|
# NCNN fix
|
||||||
|
y = kpts.view(bs, *self.kpt_shape, -1)
|
||||||
|
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
|
||||||
if ndim == 3:
|
if ndim == 3:
|
||||||
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
|
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
|
||||||
return a.view(bs, self.nk, -1)
|
return a.view(bs, self.nk, -1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue