Feature: Create HUB Models from CLI or Python Script (#7316)

Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Kalen Michael 2024-01-10 02:36:14 +01:00 committed by GitHub
parent a92adf8231
commit b54055a2c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 356 additions and 154 deletions

View file

@ -5,10 +5,11 @@ import sys
from pathlib import Path
from typing import Union
from hub_sdk.config import HUB_WEB_ROOT
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
class Model(nn.Module):
@ -76,8 +77,8 @@ class Model(nn.Module):
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model):
from ultralytics.hub.session import HUBTrainingSession
self.session = HUBTrainingSession(model)
# Fetch model from HUB
self.session = self._get_hub_session(model)
model = self.session.model_file
# Check if Triton Server model
@ -93,10 +94,20 @@ class Model(nn.Module):
else:
self._load(model, task)
self.model_name = model
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the predict() method with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
@staticmethod
def _get_hub_session(model: str):
"""Creates a session for Hub Training."""
from ultralytics.hub.session import HUBTrainingSession
session = HUBTrainingSession(model)
return session if session.client.authenticated else None
@staticmethod
def is_triton_model(model):
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
@ -336,10 +347,11 @@ class Model(nn.Module):
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
if self.session: # Ultralytics HUB session
if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model
if any(kwargs):
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
kwargs = self.session.train_args
kwargs = self.session.train_args # overwrite kwargs
checks.check_pip_update_available()
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
@ -352,6 +364,20 @@ class Model(nn.Module):
if not args.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
if SETTINGS['hub'] is True and not self.session:
# Create a model in HUB
try:
self.session = self._get_hub_session(self.model_name)
if self.session:
self.session.create_model(args)
# Check model was created
if not getattr(self.session.model, 'id', None):
self.session = None
except PermissionError:
# Ignore permission error
pass
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# Update model and cfg after training