Fix namespace collision in Segment, OBB and Pose classes (#11186)
Co-authored-by: malopez <miguelangel.lopez@solute.es> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
dda7869205
commit
62f55bae26
1 changed files with 3 additions and 6 deletions
|
|
@ -98,7 +98,6 @@ class Segment(Detect):
|
||||||
self.nm = nm # number of masks
|
self.nm = nm # number of masks
|
||||||
self.npr = npr # number of protos
|
self.npr = npr # number of protos
|
||||||
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
||||||
self.detect = Detect.forward
|
|
||||||
|
|
||||||
c4 = max(ch[0] // 4, self.nm)
|
c4 = max(ch[0] // 4, self.nm)
|
||||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
||||||
|
|
@ -109,7 +108,7 @@ class Segment(Detect):
|
||||||
bs = p.shape[0] # batch size
|
bs = p.shape[0] # batch size
|
||||||
|
|
||||||
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
||||||
x = self.detect(self, x)
|
x = Detect.forward(self, x)
|
||||||
if self.training:
|
if self.training:
|
||||||
return x, mc, p
|
return x, mc, p
|
||||||
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
||||||
|
|
@ -122,7 +121,6 @@ class OBB(Detect):
|
||||||
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
|
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
|
||||||
super().__init__(nc, ch)
|
super().__init__(nc, ch)
|
||||||
self.ne = ne # number of extra parameters
|
self.ne = ne # number of extra parameters
|
||||||
self.detect = Detect.forward
|
|
||||||
|
|
||||||
c4 = max(ch[0] // 4, self.ne)
|
c4 = max(ch[0] // 4, self.ne)
|
||||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
|
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
|
||||||
|
|
@ -136,7 +134,7 @@ class OBB(Detect):
|
||||||
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
|
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
|
||||||
if not self.training:
|
if not self.training:
|
||||||
self.angle = angle
|
self.angle = angle
|
||||||
x = self.detect(self, x)
|
x = Detect.forward(self, x)
|
||||||
if self.training:
|
if self.training:
|
||||||
return x, angle
|
return x, angle
|
||||||
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
|
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
|
||||||
|
|
@ -154,7 +152,6 @@ class Pose(Detect):
|
||||||
super().__init__(nc, ch)
|
super().__init__(nc, ch)
|
||||||
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
||||||
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
|
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
|
||||||
self.detect = Detect.forward
|
|
||||||
|
|
||||||
c4 = max(ch[0] // 4, self.nk)
|
c4 = max(ch[0] // 4, self.nk)
|
||||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
|
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
|
||||||
|
|
@ -163,7 +160,7 @@ class Pose(Detect):
|
||||||
"""Perform forward pass through YOLO model and return predictions."""
|
"""Perform forward pass through YOLO model and return predictions."""
|
||||||
bs = x[0].shape[0] # batch size
|
bs = x[0].shape[0] # batch size
|
||||||
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
||||||
x = self.detect(self, x)
|
x = Detect.forward(self, x)
|
||||||
if self.training:
|
if self.training:
|
||||||
return x, kpt
|
return x, kpt
|
||||||
pred_kpt = self.kpts_decode(bs, kpt)
|
pred_kpt = self.kpts_decode(bs, kpt)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue