ultralytics 8.3.47 fix softmax and Classify head commonality (#18085)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
95a9796fee
commit
3d52c4fa10
6 changed files with 17 additions and 4 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.3.44"
|
||||
__version__ = "8.3.47"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ from ultralytics.data import build_dataloader
|
|||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.autobackend import check_class_names, default_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
|
||||
from ultralytics.utils import (
|
||||
ARM64,
|
||||
|
|
@ -287,6 +287,8 @@ class Exporter:
|
|||
|
||||
model = FXModel(model)
|
||||
for m in model.modules():
|
||||
if isinstance(m, Classify):
|
||||
m.export = True
|
||||
if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB
|
||||
m.dynamic = self.args.dynamic
|
||||
m.export = True
|
||||
|
|
|
|||
|
|
@ -53,7 +53,8 @@ class ClassificationPredictor(BasePredictor):
|
|||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||
return [
|
||||
Results(orig_img, path=img_path, names=self.model.names, probs=pred.softmax(0))
|
||||
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
|
||||
]
|
||||
|
|
|
|||
|
|
@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator):
|
|||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
self.metrics.save_dir = self.save_dir
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Preprocesses the classification predictions."""
|
||||
return preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
|
|
|
|||
|
|
@ -282,6 +282,8 @@ class Pose(Detect):
|
|||
class Classify(nn.Module):
|
||||
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||
|
||||
export = False # export mode
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
|
||||
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
|
||||
super().__init__()
|
||||
|
|
@ -296,7 +298,10 @@ class Classify(nn.Module):
|
|||
if isinstance(x, list):
|
||||
x = torch.cat(x, 1)
|
||||
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
||||
return x
|
||||
if self.training:
|
||||
return x
|
||||
y = x.softmax(1) # get final output
|
||||
return y if self.export else (y, x)
|
||||
|
||||
|
||||
class WorldDetect(Detect):
|
||||
|
|
|
|||
|
|
@ -604,6 +604,7 @@ class v8ClassificationLoss:
|
|||
|
||||
def __call__(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
|
||||
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue