ultralytics 8.2.49 fix classification setup_model (#14199)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
6c13bea7b8
commit
64862f1b69
2 changed files with 8 additions and 17 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.48"
|
__version__ = "8.2.49"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||||
from ultralytics.engine.trainer import BaseTrainer
|
from ultralytics.engine.trainer import BaseTrainer
|
||||||
from ultralytics.models import yolo
|
from ultralytics.models import yolo
|
||||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
from ultralytics.nn.tasks import ClassificationModel
|
||||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
||||||
from ultralytics.utils.plotting import plot_images, plot_results
|
from ultralytics.utils.plotting import plot_images, plot_results
|
||||||
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||||
|
|
@ -60,23 +60,14 @@ class ClassificationTrainer(BaseTrainer):
|
||||||
"""Load, create or download model for any task."""
|
"""Load, create or download model for any task."""
|
||||||
import torchvision # scope for faster 'import ultralytics'
|
import torchvision # scope for faster 'import ultralytics'
|
||||||
|
|
||||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
if str(self.model) in torchvision.models.__dict__:
|
||||||
return
|
self.model = torchvision.models.__dict__[self.model](
|
||||||
|
weights="IMAGENET1K_V1" if self.args.pretrained else None
|
||||||
model, ckpt = str(self.model), None
|
)
|
||||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
ckpt = None
|
||||||
if model.endswith(".pt"):
|
|
||||||
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
|
||||||
for p in self.model.parameters():
|
|
||||||
p.requires_grad = True # for training
|
|
||||||
elif model.split(".")[-1] in {"yaml", "yml"}:
|
|
||||||
self.model = self.get_model(cfg=model)
|
|
||||||
elif model in torchvision.models.__dict__:
|
|
||||||
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
|
ckpt = super().setup_model()
|
||||||
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
|
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
|
||||||
|
|
||||||
return ckpt
|
return ckpt
|
||||||
|
|
||||||
def build_dataset(self, img_path, mode="train", batch=None):
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue