ultralytics 8.3.47 fix softmax and Classify head commonality (#18085)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-12-07 20:55:59 +08:00 committed by GitHub
parent 95a9796fee
commit 3d52c4fa10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 17 additions and 4 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.44"
__version__ = "8.3.47"
import os

View file

@ -73,7 +73,7 @@ from ultralytics.data import build_dataloader
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.utils import (
ARM64,
@ -287,6 +287,8 @@ class Exporter:
model = FXModel(model)
for m in model.modules():
if isinstance(m, Classify):
m.export = True
if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB
m.dynamic = self.args.dynamic
m.export = True

View file

@ -53,7 +53,8 @@ class ClassificationPredictor(BasePredictor):
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
return [
Results(orig_img, path=img_path, names=self.model.names, probs=pred.softmax(0))
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]

View file

@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator):
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir
def postprocess(self, preds):
"""Preprocesses the classification predictions."""
return preds[0] if isinstance(preds, (list, tuple)) else preds
def get_stats(self):
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)

View file

@ -282,6 +282,8 @@ class Pose(Detect):
class Classify(nn.Module):
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
export = False # export mode
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
super().__init__()
@ -296,7 +298,10 @@ class Classify(nn.Module):
if isinstance(x, list):
x = torch.cat(x, 1)
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
return x
if self.training:
return x
y = x.softmax(1) # get final output
return y if self.export else (y, x)
class WorldDetect(Detect):

View file

@ -604,6 +604,7 @@ class v8ClassificationLoss:
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
loss_items = loss.detach()
return loss, loss_items