ultralytics 8.3.29 Sony IMX500 export (#14878)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Francesco Mattioli <Francesco.mttl@gmail.com> Co-authored-by: Lakshantha Dissanayake <lakshantha@ultralytics.com> Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com> Co-authored-by: Chizkiyahu Raful <37312901+Chizkiyahu@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com> Co-authored-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
This commit is contained in:
parent
2c6cd68144
commit
0fa1d7d5a6
16 changed files with 281 additions and 17 deletions
|
|
@ -23,6 +23,7 @@ class Detect(nn.Module):
|
|||
|
||||
dynamic = False # force grid reconstruction
|
||||
export = False # export mode
|
||||
format = None # export format
|
||||
end2end = False # end2end
|
||||
max_det = 300 # max_det
|
||||
shape = None
|
||||
|
|
@ -101,7 +102,7 @@ class Detect(nn.Module):
|
|||
# Inference path
|
||||
shape = x[0].shape # BCHW
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
if self.dynamic or self.shape != shape:
|
||||
if self.format != "imx" and (self.dynamic or self.shape != shape):
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
|
|
@ -119,6 +120,11 @@ class Detect(nn.Module):
|
|||
grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
|
||||
norm = self.strides / (self.stride[0] * grid_size)
|
||||
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
|
||||
elif self.export and self.format == "imx":
|
||||
dbox = self.decode_bboxes(
|
||||
self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
|
||||
)
|
||||
return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
|
||||
else:
|
||||
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
|
||||
|
||||
|
|
@ -137,9 +143,9 @@ class Detect(nn.Module):
|
|||
a[-1].bias.data[:] = 1.0 # box
|
||||
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
||||
|
||||
def decode_bboxes(self, bboxes, anchors):
|
||||
def decode_bboxes(self, bboxes, anchors, xywh=True):
|
||||
"""Decode bounding boxes."""
|
||||
return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
|
||||
return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1)
|
||||
|
||||
@staticmethod
|
||||
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue