ultralytics 8.3.47 fix softmax and Classify head commonality (#18085)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-12-07 20:55:59 +08:00 committed by GitHub
parent 95a9796fee
commit 3d52c4fa10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 17 additions and 4 deletions

View file

@ -282,6 +282,8 @@ class Pose(Detect):
class Classify(nn.Module):
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
export = False # export mode
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
super().__init__()
@ -296,7 +298,10 @@ class Classify(nn.Module):
if isinstance(x, list):
x = torch.cat(x, 1)
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
return x
if self.training:
return x
y = x.softmax(1) # get final output
return y if self.export else (y, x)
class WorldDetect(Detect):