ultralytics 8.3.47 fix softmax and Classify head commonality (#18085)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
95a9796fee
commit
3d52c4fa10
6 changed files with 17 additions and 4 deletions
|
|
@ -282,6 +282,8 @@ class Pose(Detect):
|
|||
class Classify(nn.Module):
|
||||
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||
|
||||
export = False # export mode
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
|
||||
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
|
||||
super().__init__()
|
||||
|
|
@ -296,7 +298,10 @@ class Classify(nn.Module):
|
|||
if isinstance(x, list):
|
||||
x = torch.cat(x, 1)
|
||||
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
||||
return x
|
||||
if self.training:
|
||||
return x
|
||||
y = x.softmax(1) # get final output
|
||||
return y if self.export else (y, x)
|
||||
|
||||
|
||||
class WorldDetect(Detect):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue