ultralytics 8.2.40 refactor HUB code into callbacks (#13896)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-23 14:52:12 +02:00 committed by GitHub
parent 3df9d278dc
commit c636fe0f35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 75 additions and 57 deletions

View file

@ -9,7 +9,7 @@ import torch
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.engine.results import Results
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import (
ARGV,
@ -17,7 +17,6 @@ from ultralytics.utils import (
DEFAULT_CFG_DICT,
LOGGER,
RANK,
SETTINGS,
callbacks,
checks,
emojis,
@ -76,7 +75,6 @@ class Model(nn.Module):
add_callback: Adds a callback function for an event.
clear_callback: Clears all callbacks for an event.
reset_callbacks: Resets all callbacks to their default functions.
_get_hub_session: Retrieves or creates an Ultralytics HUB session.
is_triton_model: Checks if a model is a Triton Server model.
is_hub_model: Checks if a model is an Ultralytics HUB model.
_reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
@ -136,7 +134,7 @@ class Model(nn.Module):
if self.is_hub_model(model):
# Fetch model from HUB
checks.check_requirements("hub-sdk>=0.0.6")
self.session = self._get_hub_session(model)
self.session = HUBTrainingSession.create_session(model)
model = self.session.model_file
# Check if Triton Server model
@ -175,14 +173,6 @@ class Model(nn.Module):
"""
return self.predict(source, stream, **kwargs)
@staticmethod
def _get_hub_session(model: str):
"""Creates a session for Hub Training."""
from ultralytics.hub.session import HUBTrainingSession
session = HUBTrainingSession(model)
return session if session.client.authenticated else None
@staticmethod
def is_triton_model(model: str) -> bool:
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
@ -656,19 +646,6 @@ class Model(nn.Module):
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
if SETTINGS["hub"] is True and not self.session:
# Create a model in HUB
try:
self.session = self._get_hub_session(self.model_name)
if self.session:
self.session.create_model(args)
# Check model was created
if not getattr(self.session.model, "id", None):
self.session = None
except (PermissionError, ModuleNotFoundError):
# Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
pass
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# Update model and cfg after training