ultralytics 8.3.47 fix softmax and Classify head commonality (#18085)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
95a9796fee
commit
3d52c4fa10
6 changed files with 17 additions and 4 deletions
|
|
@@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.3.44"
|
__version__ = "8.3.47"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -73,7 +73,7 @@ from ultralytics.data import build_dataloader
|
||||||
from ultralytics.data.dataset import YOLODataset
|
from ultralytics.data.dataset import YOLODataset
|
||||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.nn.autobackend import check_class_names, default_class_names
|
from ultralytics.nn.autobackend import check_class_names, default_class_names
|
||||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
|
||||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
|
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
|
||||||
from ultralytics.utils import (
|
from ultralytics.utils import (
|
||||||
ARM64,
|
ARM64,
|
||||||
|
|
@@ -287,6 +287,8 @@ class Exporter:
|
||||||
|
|
||||||
model = FXModel(model)
|
model = FXModel(model)
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
|
if isinstance(m, Classify):
|
||||||
|
m.export = True
|
||||||
if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB
|
if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB
|
||||||
m.dynamic = self.args.dynamic
|
m.dynamic = self.args.dynamic
|
||||||
m.export = True
|
m.export = True
|
||||||
|
|
|
||||||
|
|
@@ -53,7 +53,8 @@ class ClassificationPredictor(BasePredictor):
|
||||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||||
|
|
||||||
|
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||||
return [
|
return [
|
||||||
Results(orig_img, path=img_path, names=self.model.names, probs=pred.softmax(0))
|
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
|
||||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
|
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator):
|
||||||
self.metrics.confusion_matrix = self.confusion_matrix
|
self.metrics.confusion_matrix = self.confusion_matrix
|
||||||
self.metrics.save_dir = self.save_dir
|
self.metrics.save_dir = self.save_dir
|
||||||
|
|
||||||
|
def postprocess(self, preds):
|
||||||
|
"""Postprocesses the classification predictions."""
|
||||||
|
return preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||||
|
|
||||||
def get_stats(self):
|
def get_stats(self):
|
||||||
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
|
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
|
||||||
self.metrics.process(self.targets, self.pred)
|
self.metrics.process(self.targets, self.pred)
|
||||||
|
|
|
||||||
|
|
@@ -282,6 +282,8 @@ class Pose(Detect):
|
||||||
class Classify(nn.Module):
|
class Classify(nn.Module):
|
||||||
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||||
|
|
||||||
|
export = False # export mode
|
||||||
|
|
||||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
|
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
|
||||||
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
|
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@@ -296,7 +298,10 @@ class Classify(nn.Module):
|
||||||
if isinstance(x, list):
|
if isinstance(x, list):
|
||||||
x = torch.cat(x, 1)
|
x = torch.cat(x, 1)
|
||||||
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
||||||
return x
|
if self.training:
|
||||||
|
return x
|
||||||
|
y = x.softmax(1) # get final output
|
||||||
|
return y if self.export else (y, x)
|
||||||
|
|
||||||
|
|
||||||
class WorldDetect(Detect):
|
class WorldDetect(Detect):
|
||||||
|
|
|
||||||
|
|
@@ -604,6 +604,7 @@ class v8ClassificationLoss:
|
||||||
|
|
||||||
def __call__(self, preds, batch):
|
def __call__(self, preds, batch):
|
||||||
"""Compute the classification loss between predictions and true labels."""
|
"""Compute the classification loss between predictions and true labels."""
|
||||||
|
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
|
||||||
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
|
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
|
||||||
loss_items = loss.detach()
|
loss_items = loss.detach()
|
||||||
return loss, loss_items
|
return loss, loss_items
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue