ultralytics 8.2.49 fix classification setup_model (#14199)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
6c13bea7b8
commit
64862f1b69
2 changed files with 8 additions and 17 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.48"
|
||||
__version__ = "8.2.49"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch
|
|||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
||||
from ultralytics.nn.tasks import ClassificationModel
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
||||
from ultralytics.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||
|
|
@ -60,23 +60,14 @@ class ClassificationTrainer(BaseTrainer):
|
|||
"""Load, create or download model for any task."""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model, ckpt = str(self.model), None
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith(".pt"):
|
||||
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.split(".")[-1] in {"yaml", "yml"}:
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
||||
if str(self.model) in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[self.model](
|
||||
weights="IMAGENET1K_V1" if self.args.pretrained else None
|
||||
)
|
||||
ckpt = None
|
||||
else:
|
||||
raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
|
||||
ckpt = super().setup_model()
|
||||
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
|
||||
|
||||
return ckpt
|
||||
|
||||
def build_dataset(self, img_path, mode="train", batch=None):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue