diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 9d19f6ab..bd608b1d 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.44" +__version__ = "8.3.47" import os diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index c0e29e7e..ae84cab9 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -73,7 +73,7 @@ from ultralytics.data import build_dataloader from ultralytics.data.dataset import YOLODataset from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.nn.autobackend import check_class_names, default_class_names -from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder +from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel from ultralytics.utils import ( ARM64, @@ -287,6 +287,8 @@ class Exporter: model = FXModel(model) for m in model.modules(): + if isinstance(m, Classify): + m.export = True if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB m.dynamic = self.args.dynamic m.export = True diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py index b75a1949..385f75bc 100644 --- a/ultralytics/models/yolo/classify/predict.py +++ b/ultralytics/models/yolo/classify/predict.py @@ -53,7 +53,8 @@ class ClassificationPredictor(BasePredictor): if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + preds = preds[0] if isinstance(preds, (list, tuple)) else preds return [ - Results(orig_img, path=img_path, names=self.model.names, probs=pred.softmax(0)) + Results(orig_img, path=img_path, names=self.model.names, probs=pred) for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]) ] diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py index e54f0411..67333f26 100644 --- a/ultralytics/models/yolo/classify/val.py +++ b/ultralytics/models/yolo/classify/val.py @@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator): self.metrics.confusion_matrix = self.confusion_matrix self.metrics.save_dir = self.save_dir + def postprocess(self, preds): + """Preprocesses the classification predictions.""" + return preds[0] if isinstance(preds, (list, tuple)) else preds + def get_stats(self): """Returns a dictionary of metrics obtained by processing targets and predictions.""" self.metrics.process(self.targets, self.pred) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 25964ac2..0afb5fd1 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -282,6 +282,8 @@ class Pose(Detect): class Classify(nn.Module): """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).""" + export = False # export mode + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.""" super().__init__() @@ -296,7 +298,10 @@ class Classify(nn.Module): if isinstance(x, list): x = torch.cat(x, 1) x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) - return x + if self.training: + return x + y = x.softmax(1) # get final output + return y if self.export else (y, x) class WorldDetect(Detect): diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 73952868..c6557df4 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -604,6 +604,7 @@ class v8ClassificationLoss: def __call__(self, preds, batch): """Compute the classification loss between predictions and true labels.""" + preds = preds[1] if isinstance(preds, (list, tuple)) else preds loss = F.cross_entropy(preds, batch["cls"], reduction="mean") loss_items = loss.detach() return loss, loss_items