ultralytics 8.0.54 TFLite export improvements and fixes (#1447)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
30fc4b537f
commit
701fba4770
30 changed files with 198 additions and 166 deletions
|
|
@ -411,12 +411,12 @@ class Detect(nn.Module):
|
|||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
if self.export and self.format == 'edgetpu': # FlexSplitV ops issue
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
|
||||
box = x_cat[:, :self.reg_max * 4]
|
||||
cls = x_cat[:, self.reg_max * 4:]
|
||||
else:
|
||||
box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||
return y if self.export else (y, x)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue