ultralytics 8.2.49 fix classification setup_model (#14199)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-07-05 00:58:04 +08:00 committed by GitHub
parent 6c13bea7b8
commit 64862f1b69
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 17 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.48" __version__ = "8.2.49"
import os import os

View file

@ -5,7 +5,7 @@ import torch
from ultralytics.data import ClassificationDataset, build_dataloader from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils.plotting import plot_images, plot_results from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
@ -60,23 +60,14 @@ class ClassificationTrainer(BaseTrainer):
"""Load, create or download model for any task.""" """Load, create or download model for any task."""
import torchvision # scope for faster 'import ultralytics' import torchvision # scope for faster 'import ultralytics'
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed if str(self.model) in torchvision.models.__dict__:
return self.model = torchvision.models.__dict__[self.model](
weights="IMAGENET1K_V1" if self.args.pretrained else None
model, ckpt = str(self.model), None )
# Load a YOLO model locally, from torchvision, or from Ultralytics assets ckpt = None
if model.endswith(".pt"):
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.split(".")[-1] in {"yaml", "yml"}:
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
else: else:
raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.") ckpt = super().setup_model()
ClassificationModel.reshape_outputs(self.model, self.data["nc"]) ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt return ckpt
def build_dataset(self, img_path, mode="train", batch=None): def build_dataset(self, img_path, mode="train", batch=None):