ultralytics 8.0.62 HUB Syntax updates and fixes (#1795)

Co-authored-by: Danny Kim <imbird0312@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: MagicCodess <32194768+MagicCodess@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Amjad Alsharafi <26300843+Amjad50@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-05 02:14:52 +02:00 committed by GitHub
parent 4198570a4b
commit 37274c9845
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 189 additions and 97 deletions

View file

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import sys
from copy import deepcopy
from pathlib import Path
from typing import Union
@ -77,7 +78,7 @@ class YOLO:
task (Any, optional): Task type for the YOLO model. Defaults to None.
"""
self._reset_callbacks()
self.callbacks = deepcopy(callbacks.default_callbacks)
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
@ -91,7 +92,7 @@ class YOLO:
model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if model.startswith('https://hub.ultralytics.com/models/'):
if self.is_hub_model(model):
from ultralytics.hub.session import HUBTrainingSession
self.session = HUBTrainingSession(model)
model = self.session.model_file
@ -112,6 +113,13 @@ class YOLO:
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@staticmethod
def is_hub_model(model):
return any((
model.startswith('https://hub.ultralytics.com/models/'),
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
(len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID
def _new(self, cfg: str, task=None, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
@ -220,8 +228,7 @@ class YOLO:
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
('predict' in sys.argv or 'mode=predict' in sys.argv)
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
overrides = self.overrides.copy()
overrides['conf'] = 0.25
@ -231,7 +238,7 @@ class YOLO:
overrides['save'] = kwargs.get('save', False) # not save files by default
if not self.predictor:
self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
@ -380,19 +387,17 @@ class YOLO:
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@staticmethod
def add_callback(event: str, func):
def add_callback(self, event: str, func):
"""
Add callback
"""
callbacks.default_callbacks[event].append(func)
self.callbacks[event].append(func)
@staticmethod
def _reset_ckpt_args(args):
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
@staticmethod
def _reset_callbacks():
def _reset_callbacks(self):
for event in callbacks.default_callbacks.keys():
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
self.callbacks[event] = [callbacks.default_callbacks[event][0]]