Fix namespace collision in Segment, OBB and Pose classes (#11186)

Co-authored-by: malopez <miguelangel.lopez@solute.es>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Miguel Angel Lopez 2024-05-03 03:22:55 +02:00 committed by GitHub
parent dda7869205
commit 62f55bae26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -98,7 +98,6 @@ class Segment(Detect):
self.nm = nm # number of masks self.nm = nm # number of masks
self.npr = npr # number of protos self.npr = npr # number of protos
self.proto = Proto(ch[0], self.npr, self.nm) # protos self.proto = Proto(ch[0], self.npr, self.nm) # protos
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.nm) c4 = max(ch[0] // 4, self.nm)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch) self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
@ -109,7 +108,7 @@ class Segment(Detect):
bs = p.shape[0] # batch size bs = p.shape[0] # batch size
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
x = self.detect(self, x) x = Detect.forward(self, x)
if self.training: if self.training:
return x, mc, p return x, mc, p
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p)) return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
@ -122,7 +121,6 @@ class OBB(Detect):
"""Initialize OBB with number of classes `nc` and layer channels `ch`.""" """Initialize OBB with number of classes `nc` and layer channels `ch`."""
super().__init__(nc, ch) super().__init__(nc, ch)
self.ne = ne # number of extra parameters self.ne = ne # number of extra parameters
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.ne) c4 = max(ch[0] // 4, self.ne)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch) self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
@ -136,7 +134,7 @@ class OBB(Detect):
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2] # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
if not self.training: if not self.training:
self.angle = angle self.angle = angle
x = self.detect(self, x) x = Detect.forward(self, x)
if self.training: if self.training:
return x, angle return x, angle
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle)) return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
@ -154,7 +152,6 @@ class Pose(Detect):
super().__init__(nc, ch) super().__init__(nc, ch)
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.nk) c4 = max(ch[0] // 4, self.nk)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch) self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
@ -163,7 +160,7 @@ class Pose(Detect):
"""Perform forward pass through YOLO model and return predictions.""" """Perform forward pass through YOLO model and return predictions."""
bs = x[0].shape[0] # batch size bs = x[0].shape[0] # batch size
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
x = self.detect(self, x) x = Detect.forward(self, x)
if self.training: if self.training:
return x, kpt return x, kpt
pred_kpt = self.kpts_decode(bs, kpt) pred_kpt = self.kpts_decode(bs, kpt)