Fix namespace collision in Segment, OBB and Pose classes (#11186)
Co-authored-by: malopez <miguelangel.lopez@solute.es> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
dda7869205
commit
62f55bae26
1 changed files with 3 additions and 6 deletions
|
|
@ -98,7 +98,6 @@ class Segment(Detect):
|
|||
self.nm = nm # number of masks
|
||||
self.npr = npr # number of protos
|
||||
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
||||
self.detect = Detect.forward
|
||||
|
||||
c4 = max(ch[0] // 4, self.nm)
|
||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
||||
|
|
@ -109,7 +108,7 @@ class Segment(Detect):
|
|||
bs = p.shape[0] # batch size
|
||||
|
||||
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
||||
x = self.detect(self, x)
|
||||
x = Detect.forward(self, x)
|
||||
if self.training:
|
||||
return x, mc, p
|
||||
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
||||
|
|
@ -122,7 +121,6 @@ class OBB(Detect):
|
|||
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
|
||||
super().__init__(nc, ch)
|
||||
self.ne = ne # number of extra parameters
|
||||
self.detect = Detect.forward
|
||||
|
||||
c4 = max(ch[0] // 4, self.ne)
|
||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
|
||||
|
|
@ -136,7 +134,7 @@ class OBB(Detect):
|
|||
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
|
||||
if not self.training:
|
||||
self.angle = angle
|
||||
x = self.detect(self, x)
|
||||
x = Detect.forward(self, x)
|
||||
if self.training:
|
||||
return x, angle
|
||||
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
|
||||
|
|
@ -154,7 +152,6 @@ class Pose(Detect):
|
|||
super().__init__(nc, ch)
|
||||
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
||||
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
|
||||
self.detect = Detect.forward
|
||||
|
||||
c4 = max(ch[0] // 4, self.nk)
|
||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
|
||||
|
|
@ -163,7 +160,7 @@ class Pose(Detect):
|
|||
"""Perform forward pass through YOLO model and return predictions."""
|
||||
bs = x[0].shape[0] # batch size
|
||||
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
||||
x = self.detect(self, x)
|
||||
x = Detect.forward(self, x)
|
||||
if self.training:
|
||||
return x, kpt
|
||||
pred_kpt = self.kpts_decode(bs, kpt)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue