ultralytics 8.0.62 HUB Syntax updates and fixes (#1795)
Co-authored-by: Danny Kim <imbird0312@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: MagicCodess <32194768+MagicCodess@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Amjad Alsharafi <26300843+Amjad50@users.noreply.github.com>
This commit is contained in:
parent
4198570a4b
commit
37274c9845
19 changed files with 189 additions and 97 deletions
|
|
@ -1,6 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
|
@ -77,7 +78,7 @@ class YOLO:
|
|||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.callbacks = deepcopy(callbacks.default_callbacks)
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
|
|
@ -91,7 +92,7 @@ class YOLO:
|
|||
model = str(model).strip() # strip spaces
|
||||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if model.startswith('https://hub.ultralytics.com/models/'):
|
||||
if self.is_hub_model(model):
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
self.session = HUBTrainingSession(model)
|
||||
model = self.session.model_file
|
||||
|
|
@ -112,6 +113,13 @@ class YOLO:
|
|||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@staticmethod
|
||||
def is_hub_model(model):
|
||||
return any((
|
||||
model.startswith('https://hub.ultralytics.com/models/'),
|
||||
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
||||
(len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID
|
||||
|
||||
def _new(self, cfg: str, task=None, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
|
@ -220,8 +228,7 @@ class YOLO:
|
|||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
|
||||
('predict' in sys.argv or 'mode=predict' in sys.argv)
|
||||
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
|
|
@ -231,7 +238,7 @@ class YOLO:
|
|||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
|
|
@ -380,19 +387,17 @@ class YOLO:
|
|||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
callbacks.default_callbacks[event].append(func)
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
@staticmethod
|
||||
def _reset_callbacks():
|
||||
def _reset_callbacks(self):
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue