ultralytics 8.2.40 refactor HUB code into callbacks (#13896)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
3df9d278dc
commit
c636fe0f35
8 changed files with 75 additions and 57 deletions
|
|
@ -9,7 +9,7 @@ import torch
|
|||
|
||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
from ultralytics.utils import (
|
||||
ARGV,
|
||||
|
|
@ -17,7 +17,6 @@ from ultralytics.utils import (
|
|||
DEFAULT_CFG_DICT,
|
||||
LOGGER,
|
||||
RANK,
|
||||
SETTINGS,
|
||||
callbacks,
|
||||
checks,
|
||||
emojis,
|
||||
|
|
@ -76,7 +75,6 @@ class Model(nn.Module):
|
|||
add_callback: Adds a callback function for an event.
|
||||
clear_callback: Clears all callbacks for an event.
|
||||
reset_callbacks: Resets all callbacks to their default functions.
|
||||
_get_hub_session: Retrieves or creates an Ultralytics HUB session.
|
||||
is_triton_model: Checks if a model is a Triton Server model.
|
||||
is_hub_model: Checks if a model is an Ultralytics HUB model.
|
||||
_reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
|
||||
|
|
@ -136,7 +134,7 @@ class Model(nn.Module):
|
|||
if self.is_hub_model(model):
|
||||
# Fetch model from HUB
|
||||
checks.check_requirements("hub-sdk>=0.0.6")
|
||||
self.session = self._get_hub_session(model)
|
||||
self.session = HUBTrainingSession.create_session(model)
|
||||
model = self.session.model_file
|
||||
|
||||
# Check if Triton Server model
|
||||
|
|
@ -175,14 +173,6 @@ class Model(nn.Module):
|
|||
"""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_hub_session(model: str):
|
||||
"""Creates a session for Hub Training."""
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
|
||||
session = HUBTrainingSession(model)
|
||||
return session if session.client.authenticated else None
|
||||
|
||||
@staticmethod
|
||||
def is_triton_model(model: str) -> bool:
|
||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||
|
|
@ -656,19 +646,6 @@ class Model(nn.Module):
|
|||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
|
||||
if SETTINGS["hub"] is True and not self.session:
|
||||
# Create a model in HUB
|
||||
try:
|
||||
self.session = self._get_hub_session(self.model_name)
|
||||
if self.session:
|
||||
self.session.create_model(args)
|
||||
# Check model was created
|
||||
if not getattr(self.session.model, "id", None):
|
||||
self.session = None
|
||||
except (PermissionError, ModuleNotFoundError):
|
||||
# Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
|
||||
pass
|
||||
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# Update model and cfg after training
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue