ultralytics 8.2.40 refactor HUB code into callbacks (#13896)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
3df9d278dc
commit
c636fe0f35
8 changed files with 75 additions and 57 deletions
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
|
|
@ -61,6 +61,7 @@ jobs:
|
||||||
run: python docs/build_reference.py
|
run: python docs/build_reference.py
|
||||||
- name: Commit and Push Reference Section Changes
|
- name: Commit and Push Reference Section Changes
|
||||||
run: |
|
run: |
|
||||||
|
git pull origin ${{ github.head_ref || github.ref }}
|
||||||
git add .
|
git add .
|
||||||
git reset HEAD -- .github/workflows/ # workflow changes are not permitted with default token
|
git reset HEAD -- .github/workflows/ # workflow changes are not permitted with default token
|
||||||
git config --global user.name "UltralyticsAssistant"
|
git config --global user.name "UltralyticsAssistant"
|
||||||
|
|
@ -77,6 +78,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
if: always() && github.event_name == 'pull_request'
|
if: always() && github.event_name == 'pull_request'
|
||||||
run: |
|
run: |
|
||||||
|
git pull origin ${{ github.head_ref || github.ref }}
|
||||||
git add --update # only add updated files
|
git add --update # only add updated files
|
||||||
git reset HEAD -- .github/workflows/ # workflow changes are not permitted with default token
|
git reset HEAD -- .github/workflows/ # workflow changes are not permitted with default token
|
||||||
if ! git diff --staged --quiet; then
|
if ! git diff --staged --quiet; then
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
---
|
---
|
||||||
description: Explore the utilities in the Ultralytics Hub. Learn about smart_request, request_with_credentials, and more to enhance your YOLO projects.
|
description: Explore the utilities in the Ultralytics HUB. Learn about smart_request, request_with_credentials, and more to enhance your YOLO projects.
|
||||||
keywords: Ultralytics, Hub, Utilities, YOLO, smart_request, request_with_credentials
|
keywords: Ultralytics, HUB, Utilities, YOLO, smart_request, request_with_credentials
|
||||||
---
|
---
|
||||||
|
|
||||||
# Reference for `ultralytics/hub/utils.py`
|
# Reference for `ultralytics/hub/utils.py`
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,10 @@ keywords: Ultralytics, callbacks, pretrain, model save, train start, train end,
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.utils.callbacks.hub.on_pretrain_routine_start
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
## ::: ultralytics.utils.callbacks.hub.on_pretrain_routine_end
|
## ::: ultralytics.utils.callbacks.hub.on_pretrain_routine_end
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.39"
|
__version__ = "8.2.40"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import torch
|
||||||
|
|
||||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||||
from ultralytics.engine.results import Results
|
from ultralytics.engine.results import Results
|
||||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||||
from ultralytics.utils import (
|
from ultralytics.utils import (
|
||||||
ARGV,
|
ARGV,
|
||||||
|
|
@ -17,7 +17,6 @@ from ultralytics.utils import (
|
||||||
DEFAULT_CFG_DICT,
|
DEFAULT_CFG_DICT,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RANK,
|
RANK,
|
||||||
SETTINGS,
|
|
||||||
callbacks,
|
callbacks,
|
||||||
checks,
|
checks,
|
||||||
emojis,
|
emojis,
|
||||||
|
|
@ -76,7 +75,6 @@ class Model(nn.Module):
|
||||||
add_callback: Adds a callback function for an event.
|
add_callback: Adds a callback function for an event.
|
||||||
clear_callback: Clears all callbacks for an event.
|
clear_callback: Clears all callbacks for an event.
|
||||||
reset_callbacks: Resets all callbacks to their default functions.
|
reset_callbacks: Resets all callbacks to their default functions.
|
||||||
_get_hub_session: Retrieves or creates an Ultralytics HUB session.
|
|
||||||
is_triton_model: Checks if a model is a Triton Server model.
|
is_triton_model: Checks if a model is a Triton Server model.
|
||||||
is_hub_model: Checks if a model is an Ultralytics HUB model.
|
is_hub_model: Checks if a model is an Ultralytics HUB model.
|
||||||
_reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
|
_reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
|
||||||
|
|
@ -136,7 +134,7 @@ class Model(nn.Module):
|
||||||
if self.is_hub_model(model):
|
if self.is_hub_model(model):
|
||||||
# Fetch model from HUB
|
# Fetch model from HUB
|
||||||
checks.check_requirements("hub-sdk>=0.0.6")
|
checks.check_requirements("hub-sdk>=0.0.6")
|
||||||
self.session = self._get_hub_session(model)
|
self.session = HUBTrainingSession.create_session(model)
|
||||||
model = self.session.model_file
|
model = self.session.model_file
|
||||||
|
|
||||||
# Check if Triton Server model
|
# Check if Triton Server model
|
||||||
|
|
@ -175,14 +173,6 @@ class Model(nn.Module):
|
||||||
"""
|
"""
|
||||||
return self.predict(source, stream, **kwargs)
|
return self.predict(source, stream, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_hub_session(model: str):
|
|
||||||
"""Creates a session for Hub Training."""
|
|
||||||
from ultralytics.hub.session import HUBTrainingSession
|
|
||||||
|
|
||||||
session = HUBTrainingSession(model)
|
|
||||||
return session if session.client.authenticated else None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_triton_model(model: str) -> bool:
|
def is_triton_model(model: str) -> bool:
|
||||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||||
|
|
@ -656,19 +646,6 @@ class Model(nn.Module):
|
||||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||||
self.model = self.trainer.model
|
self.model = self.trainer.model
|
||||||
|
|
||||||
if SETTINGS["hub"] is True and not self.session:
|
|
||||||
# Create a model in HUB
|
|
||||||
try:
|
|
||||||
self.session = self._get_hub_session(self.model_name)
|
|
||||||
if self.session:
|
|
||||||
self.session.create_model(args)
|
|
||||||
# Check model was created
|
|
||||||
if not getattr(self.session.model, "id", None):
|
|
||||||
self.session = None
|
|
||||||
except (PermissionError, ModuleNotFoundError):
|
|
||||||
# Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.trainer.hub_session = self.session # attach optional HUB session
|
self.trainer.hub_session = self.session # attach optional HUB session
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
# Update model and cfg after training
|
# Update model and cfg after training
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,24 @@ import requests
|
||||||
|
|
||||||
from ultralytics.data.utils import HUBDatasetStats
|
from ultralytics.data.utils import HUBDatasetStats
|
||||||
from ultralytics.hub.auth import Auth
|
from ultralytics.hub.auth import Auth
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
from ultralytics.hub.session import HUBTrainingSession
|
||||||
|
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events
|
||||||
from ultralytics.utils import LOGGER, SETTINGS, checks
|
from ultralytics.utils import LOGGER, SETTINGS, checks
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"PREFIX",
|
||||||
|
"HUB_WEB_ROOT",
|
||||||
|
"HUBTrainingSession",
|
||||||
|
"login",
|
||||||
|
"logout",
|
||||||
|
"reset_model",
|
||||||
|
"export_fmts_hub",
|
||||||
|
"export_model",
|
||||||
|
"get_export",
|
||||||
|
"check_dataset",
|
||||||
|
"events",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def login(api_key: str = None, save=True) -> bool:
|
def login(api_key: str = None, save=True) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -19,16 +19,12 @@ class HUBTrainingSession:
|
||||||
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
agent_id (str): Identifier for the instance communicating with the server.
|
|
||||||
model_id (str): Identifier for the YOLO model being trained.
|
model_id (str): Identifier for the YOLO model being trained.
|
||||||
model_url (str): URL for the model in Ultralytics HUB.
|
model_url (str): URL for the model in Ultralytics HUB.
|
||||||
api_url (str): API URL for the model in Ultralytics HUB.
|
|
||||||
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
|
|
||||||
rate_limits (dict): Rate limits for different API calls (in seconds).
|
rate_limits (dict): Rate limits for different API calls (in seconds).
|
||||||
timers (dict): Timers for rate limiting.
|
timers (dict): Timers for rate limiting.
|
||||||
metrics_queue (dict): Queue for the model's metrics.
|
metrics_queue (dict): Queue for the model's metrics.
|
||||||
model (dict): Model data fetched from Ultralytics HUB.
|
model (dict): Model data fetched from Ultralytics HUB.
|
||||||
alive (bool): Indicates if the heartbeat loop is active.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, identifier):
|
def __init__(self, identifier):
|
||||||
|
|
@ -46,14 +42,12 @@ class HUBTrainingSession:
|
||||||
"""
|
"""
|
||||||
from hub_sdk import HUBClient
|
from hub_sdk import HUBClient
|
||||||
|
|
||||||
self.rate_limits = {
|
self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300} # rate limits (seconds)
|
||||||
"metrics": 3.0,
|
|
||||||
"ckpt": 900.0,
|
|
||||||
"heartbeat": 300.0,
|
|
||||||
} # rate limits (seconds)
|
|
||||||
self.metrics_queue = {} # holds metrics for each epoch until upload
|
self.metrics_queue = {} # holds metrics for each epoch until upload
|
||||||
self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed
|
self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed
|
||||||
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
||||||
|
self.model = None
|
||||||
|
self.model_url = None
|
||||||
|
|
||||||
# Parse input
|
# Parse input
|
||||||
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
||||||
|
|
@ -65,10 +59,26 @@ class HUBTrainingSession:
|
||||||
# Initialize client
|
# Initialize client
|
||||||
self.client = HUBClient(credentials)
|
self.client = HUBClient(credentials)
|
||||||
|
|
||||||
if model_id:
|
# Load models if authenticated
|
||||||
self.load_model(model_id) # load existing model
|
if self.client.authenticated:
|
||||||
else:
|
if model_id:
|
||||||
self.model = self.client.model() # load empty model
|
self.load_model(model_id) # load existing model
|
||||||
|
else:
|
||||||
|
self.model = self.client.model() # load empty model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_session(cls, identifier, args=None):
|
||||||
|
"""Class method to create an authenticated HUBTrainingSession or return None."""
|
||||||
|
try:
|
||||||
|
session = cls(identifier)
|
||||||
|
assert session.client.authenticated, "HUB not authenticated"
|
||||||
|
if args:
|
||||||
|
session.create_model(args)
|
||||||
|
assert session.model.id, "HUB model not loaded correctly"
|
||||||
|
return session
|
||||||
|
# PermissionError and ModuleNotFoundError indicate hub-sdk not installed
|
||||||
|
except (PermissionError, ModuleNotFoundError, AssertionError):
|
||||||
|
return None
|
||||||
|
|
||||||
def load_model(self, model_id):
|
def load_model(self, model_id):
|
||||||
"""Loads an existing model from Ultralytics HUB using the provided model identifier."""
|
"""Loads an existing model from Ultralytics HUB using the provided model identifier."""
|
||||||
|
|
@ -92,14 +102,12 @@ class HUBTrainingSession:
|
||||||
"epochs": model_args.get("epochs", 300),
|
"epochs": model_args.get("epochs", 300),
|
||||||
"imageSize": model_args.get("imgsz", 640),
|
"imageSize": model_args.get("imgsz", 640),
|
||||||
"patience": model_args.get("patience", 100),
|
"patience": model_args.get("patience", 100),
|
||||||
"device": model_args.get("device", ""),
|
"device": str(model_args.get("device", "")), # convert None to string
|
||||||
"cache": model_args.get("cache", "ram"),
|
"cache": str(model_args.get("cache", "ram")), # convert True, False, None to string
|
||||||
},
|
},
|
||||||
"dataset": {"name": model_args.get("data")},
|
"dataset": {"name": model_args.get("data")},
|
||||||
"lineage": {
|
"lineage": {
|
||||||
"architecture": {
|
"architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
|
||||||
"name": self.filename.replace(".pt", "").replace(".yaml", ""),
|
|
||||||
},
|
|
||||||
"parent": {},
|
"parent": {},
|
||||||
},
|
},
|
||||||
"meta": {"name": self.filename},
|
"meta": {"name": self.filename},
|
||||||
|
|
@ -113,7 +121,7 @@ class HUBTrainingSession:
|
||||||
# Model could not be created
|
# Model could not be created
|
||||||
# TODO: improve error handling
|
# TODO: improve error handling
|
||||||
if not self.model.id:
|
if not self.model.id:
|
||||||
return
|
return None
|
||||||
|
|
||||||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||||
|
|
||||||
|
|
@ -122,7 +130,8 @@ class HUBTrainingSession:
|
||||||
|
|
||||||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||||
|
|
||||||
def _parse_identifier(self, identifier):
|
@staticmethod
|
||||||
|
def _parse_identifier(identifier):
|
||||||
"""
|
"""
|
||||||
Parses the given identifier to determine the type of identifier and extract relevant components.
|
Parses the given identifier to determine the type of identifier and extract relevant components.
|
||||||
|
|
||||||
|
|
@ -213,13 +222,14 @@ class HUBTrainingSession:
|
||||||
thread=True,
|
thread=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
progress_total=None,
|
progress_total=None,
|
||||||
stream_reponse=None,
|
stream_response=None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
def retry_request():
|
def retry_request():
|
||||||
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
|
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
|
||||||
t0 = time.time() # Record the start time for the timeout
|
t0 = time.time() # Record the start time for the timeout
|
||||||
|
response = None
|
||||||
for i in range(retry + 1):
|
for i in range(retry + 1):
|
||||||
if (time.time() - t0) > timeout:
|
if (time.time() - t0) > timeout:
|
||||||
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
|
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
|
||||||
|
|
@ -233,7 +243,7 @@ class HUBTrainingSession:
|
||||||
|
|
||||||
if progress_total:
|
if progress_total:
|
||||||
self._show_upload_progress(progress_total, response)
|
self._show_upload_progress(progress_total, response)
|
||||||
elif stream_reponse:
|
elif stream_response:
|
||||||
self._iterate_content(response)
|
self._iterate_content(response)
|
||||||
|
|
||||||
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
||||||
|
|
@ -268,7 +278,8 @@ class HUBTrainingSession:
|
||||||
# If running in the main thread, call retry_request directly
|
# If running in the main thread, call retry_request directly
|
||||||
return retry_request()
|
return retry_request()
|
||||||
|
|
||||||
def _should_retry(self, status_code):
|
@staticmethod
|
||||||
|
def _should_retry(status_code):
|
||||||
"""Determines if a request should be retried based on the HTTP status code."""
|
"""Determines if a request should be retried based on the HTTP status code."""
|
||||||
retry_codes = {
|
retry_codes = {
|
||||||
HTTPStatus.REQUEST_TIMEOUT,
|
HTTPStatus.REQUEST_TIMEOUT,
|
||||||
|
|
@ -338,12 +349,13 @@ class HUBTrainingSession:
|
||||||
timeout=3600,
|
timeout=3600,
|
||||||
thread=not final,
|
thread=not final,
|
||||||
progress_total=progress_total,
|
progress_total=progress_total,
|
||||||
stream_reponse=True,
|
stream_response=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
|
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
|
||||||
|
|
||||||
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
|
@staticmethod
|
||||||
|
def _show_upload_progress(content_length: int, response: requests.Response) -> None:
|
||||||
"""
|
"""
|
||||||
Display a progress bar to track the upload progress of a file download.
|
Display a progress bar to track the upload progress of a file download.
|
||||||
|
|
||||||
|
|
@ -358,7 +370,8 @@ class HUBTrainingSession:
|
||||||
for data in response.iter_content(chunk_size=1024):
|
for data in response.iter_content(chunk_size=1024):
|
||||||
pbar.update(len(data))
|
pbar.update(len(data))
|
||||||
|
|
||||||
def _iterate_content(self, response: requests.Response) -> None:
|
@staticmethod
|
||||||
|
def _iterate_content(response: requests.Response) -> None:
|
||||||
"""
|
"""
|
||||||
Process the streamed HTTP response data.
|
Process the streamed HTTP response data.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,14 @@
|
||||||
import json
|
import json
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
|
from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events
|
||||||
from ultralytics.utils import LOGGER, SETTINGS
|
from ultralytics.utils import LOGGER, RANK, SETTINGS
|
||||||
|
|
||||||
|
|
||||||
|
def on_pretrain_routine_start(trainer):
|
||||||
|
"""Create a remote Ultralytics HUB session to log local model training."""
|
||||||
|
if RANK in {-1, 0} and SETTINGS["hub"] is True and not getattr(trainer, "hub_session", None):
|
||||||
|
trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)
|
||||||
|
|
||||||
|
|
||||||
def on_pretrain_routine_end(trainer):
|
def on_pretrain_routine_end(trainer):
|
||||||
|
|
@ -91,6 +97,7 @@ def on_export_start(exporter):
|
||||||
|
|
||||||
callbacks = (
|
callbacks = (
|
||||||
{
|
{
|
||||||
|
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||||
"on_pretrain_routine_end": on_pretrain_routine_end,
|
"on_pretrain_routine_end": on_pretrain_routine_end,
|
||||||
"on_fit_epoch_end": on_fit_epoch_end,
|
"on_fit_epoch_end": on_fit_epoch_end,
|
||||||
"on_model_save": on_model_save,
|
"on_model_save": on_model_save,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue