Feature: Create HUB Models from CLI or Python Script (#7316)
Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
a92adf8231
commit
b54055a2c7
7 changed files with 356 additions and 154 deletions
|
|
@ -5,10 +5,11 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from hub_sdk.config import HUB_WEB_ROOT
|
||||
|
||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
|
@ -76,8 +77,8 @@ class Model(nn.Module):
|
|||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if self.is_hub_model(model):
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
self.session = HUBTrainingSession(model)
|
||||
# Fetch model from HUB
|
||||
self.session = self._get_hub_session(model)
|
||||
model = self.session.model_file
|
||||
|
||||
# Check if Triton Server model
|
||||
|
|
@ -93,10 +94,20 @@ class Model(nn.Module):
|
|||
else:
|
||||
self._load(model, task)
|
||||
|
||||
self.model_name = model
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the predict() method with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_hub_session(model: str):
|
||||
"""Creates a session for Hub Training."""
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
|
||||
session = HUBTrainingSession(model)
|
||||
return session if session.client.authenticated else None
|
||||
|
||||
@staticmethod
|
||||
def is_triton_model(model):
|
||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||
|
|
@ -336,10 +347,11 @@ class Model(nn.Module):
|
|||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if self.session: # Ultralytics HUB session
|
||||
if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model
|
||||
if any(kwargs):
|
||||
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||
kwargs = self.session.train_args
|
||||
kwargs = self.session.train_args # overwrite kwargs
|
||||
|
||||
checks.check_pip_update_available()
|
||||
|
||||
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
|
||||
|
|
@ -352,6 +364,20 @@ class Model(nn.Module):
|
|||
if not args.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
|
||||
if SETTINGS['hub'] is True and not self.session:
|
||||
# Create a model in HUB
|
||||
try:
|
||||
self.session = self._get_hub_session(self.model_name)
|
||||
if self.session:
|
||||
self.session.create_model(args)
|
||||
# Check model was created
|
||||
if not getattr(self.session.model, 'id', None):
|
||||
self.session = None
|
||||
except PermissionError:
|
||||
# Ignore permission error
|
||||
pass
|
||||
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# Update model and cfg after training
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue