diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index a15f8567..d2834c94 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.48" +__version__ = "8.2.49" import os diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index e049313b..931112a3 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -5,7 +5,7 @@ import torch from ultralytics.data import ClassificationDataset, build_dataloader from ultralytics.engine.trainer import BaseTrainer from ultralytics.models import yolo -from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight +from ultralytics.nn.tasks import ClassificationModel from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr from ultralytics.utils.plotting import plot_images, plot_results from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first @@ -60,23 +60,14 @@ class ClassificationTrainer(BaseTrainer): """Load, create or download model for any task.""" import torchvision # scope for faster 'import ultralytics' - if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed - return - - model, ckpt = str(self.model), None - # Load a YOLO model locally, from torchvision, or from Ultralytics assets - if model.endswith(".pt"): - self.model, ckpt = attempt_load_one_weight(model, device="cpu") - for p in self.model.parameters(): - p.requires_grad = True # for training - elif model.split(".")[-1] in {"yaml", "yml"}: - self.model = self.get_model(cfg=model) - elif model in torchvision.models.__dict__: - self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None) + if str(self.model) in torchvision.models.__dict__: + self.model = torchvision.models.__dict__[self.model]( + weights="IMAGENET1K_V1" if self.args.pretrained else None + ) + ckpt = None else: - raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.") + ckpt = super().setup_model() ClassificationModel.reshape_outputs(self.model, self.data["nc"]) - return ckpt def build_dataset(self, img_path, mode="train", batch=None):