ultralytics 8.2.40 refactor HUB code into callbacks (#13896)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
3df9d278dc
commit
c636fe0f35
8 changed files with 75 additions and 57 deletions
|
|
@ -4,9 +4,24 @@ import requests
|
|||
|
||||
from ultralytics.data.utils import HUBDatasetStats
|
||||
from ultralytics.hub.auth import Auth
|
||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events
|
||||
from ultralytics.utils import LOGGER, SETTINGS, checks
|
||||
|
||||
__all__ = (
|
||||
"PREFIX",
|
||||
"HUB_WEB_ROOT",
|
||||
"HUBTrainingSession",
|
||||
"login",
|
||||
"logout",
|
||||
"reset_model",
|
||||
"export_fmts_hub",
|
||||
"export_model",
|
||||
"get_export",
|
||||
"check_dataset",
|
||||
"events",
|
||||
)
|
||||
|
||||
|
||||
def login(api_key: str = None, save=True) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -19,16 +19,12 @@ class HUBTrainingSession:
|
|||
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
||||
|
||||
Attributes:
|
||||
agent_id (str): Identifier for the instance communicating with the server.
|
||||
model_id (str): Identifier for the YOLO model being trained.
|
||||
model_url (str): URL for the model in Ultralytics HUB.
|
||||
api_url (str): API URL for the model in Ultralytics HUB.
|
||||
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
|
||||
rate_limits (dict): Rate limits for different API calls (in seconds).
|
||||
timers (dict): Timers for rate limiting.
|
||||
metrics_queue (dict): Queue for the model's metrics.
|
||||
model (dict): Model data fetched from Ultralytics HUB.
|
||||
alive (bool): Indicates if the heartbeat loop is active.
|
||||
"""
|
||||
|
||||
def __init__(self, identifier):
|
||||
|
|
@ -46,14 +42,12 @@ class HUBTrainingSession:
|
|||
"""
|
||||
from hub_sdk import HUBClient
|
||||
|
||||
self.rate_limits = {
|
||||
"metrics": 3.0,
|
||||
"ckpt": 900.0,
|
||||
"heartbeat": 300.0,
|
||||
} # rate limits (seconds)
|
||||
self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300} # rate limits (seconds)
|
||||
self.metrics_queue = {} # holds metrics for each epoch until upload
|
||||
self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed
|
||||
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
||||
self.model = None
|
||||
self.model_url = None
|
||||
|
||||
# Parse input
|
||||
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
||||
|
|
@ -65,10 +59,26 @@ class HUBTrainingSession:
|
|||
# Initialize client
|
||||
self.client = HUBClient(credentials)
|
||||
|
||||
if model_id:
|
||||
self.load_model(model_id) # load existing model
|
||||
else:
|
||||
self.model = self.client.model() # load empty model
|
||||
# Load models if authenticated
|
||||
if self.client.authenticated:
|
||||
if model_id:
|
||||
self.load_model(model_id) # load existing model
|
||||
else:
|
||||
self.model = self.client.model() # load empty model
|
||||
|
||||
@classmethod
|
||||
def create_session(cls, identifier, args=None):
|
||||
"""Class method to create an authenticated HUBTrainingSession or return None."""
|
||||
try:
|
||||
session = cls(identifier)
|
||||
assert session.client.authenticated, "HUB not authenticated"
|
||||
if args:
|
||||
session.create_model(args)
|
||||
assert session.model.id, "HUB model not loaded correctly"
|
||||
return session
|
||||
# PermissionError and ModuleNotFoundError indicate hub-sdk not installed
|
||||
except (PermissionError, ModuleNotFoundError, AssertionError):
|
||||
return None
|
||||
|
||||
def load_model(self, model_id):
|
||||
"""Loads an existing model from Ultralytics HUB using the provided model identifier."""
|
||||
|
|
@ -92,14 +102,12 @@ class HUBTrainingSession:
|
|||
"epochs": model_args.get("epochs", 300),
|
||||
"imageSize": model_args.get("imgsz", 640),
|
||||
"patience": model_args.get("patience", 100),
|
||||
"device": model_args.get("device", ""),
|
||||
"cache": model_args.get("cache", "ram"),
|
||||
"device": str(model_args.get("device", "")), # convert None to string
|
||||
"cache": str(model_args.get("cache", "ram")), # convert True, False, None to string
|
||||
},
|
||||
"dataset": {"name": model_args.get("data")},
|
||||
"lineage": {
|
||||
"architecture": {
|
||||
"name": self.filename.replace(".pt", "").replace(".yaml", ""),
|
||||
},
|
||||
"architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
|
||||
"parent": {},
|
||||
},
|
||||
"meta": {"name": self.filename},
|
||||
|
|
@ -113,7 +121,7 @@ class HUBTrainingSession:
|
|||
# Model could not be created
|
||||
# TODO: improve error handling
|
||||
if not self.model.id:
|
||||
return
|
||||
return None
|
||||
|
||||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||
|
||||
|
|
@ -122,7 +130,8 @@ class HUBTrainingSession:
|
|||
|
||||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||
|
||||
def _parse_identifier(self, identifier):
|
||||
@staticmethod
|
||||
def _parse_identifier(identifier):
|
||||
"""
|
||||
Parses the given identifier to determine the type of identifier and extract relevant components.
|
||||
|
||||
|
|
@ -213,13 +222,14 @@ class HUBTrainingSession:
|
|||
thread=True,
|
||||
verbose=True,
|
||||
progress_total=None,
|
||||
stream_reponse=None,
|
||||
stream_response=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
def retry_request():
|
||||
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
|
||||
t0 = time.time() # Record the start time for the timeout
|
||||
response = None
|
||||
for i in range(retry + 1):
|
||||
if (time.time() - t0) > timeout:
|
||||
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
|
||||
|
|
@ -233,7 +243,7 @@ class HUBTrainingSession:
|
|||
|
||||
if progress_total:
|
||||
self._show_upload_progress(progress_total, response)
|
||||
elif stream_reponse:
|
||||
elif stream_response:
|
||||
self._iterate_content(response)
|
||||
|
||||
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
||||
|
|
@ -268,7 +278,8 @@ class HUBTrainingSession:
|
|||
# If running in the main thread, call retry_request directly
|
||||
return retry_request()
|
||||
|
||||
def _should_retry(self, status_code):
|
||||
@staticmethod
|
||||
def _should_retry(status_code):
|
||||
"""Determines if a request should be retried based on the HTTP status code."""
|
||||
retry_codes = {
|
||||
HTTPStatus.REQUEST_TIMEOUT,
|
||||
|
|
@ -338,12 +349,13 @@ class HUBTrainingSession:
|
|||
timeout=3600,
|
||||
thread=not final,
|
||||
progress_total=progress_total,
|
||||
stream_reponse=True,
|
||||
stream_response=True,
|
||||
)
|
||||
else:
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
|
||||
|
||||
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
|
||||
@staticmethod
|
||||
def _show_upload_progress(content_length: int, response: requests.Response) -> None:
|
||||
"""
|
||||
Display a progress bar to track the progress of a file upload.
|
||||
|
||||
|
|
@ -358,7 +370,8 @@ class HUBTrainingSession:
|
|||
for data in response.iter_content(chunk_size=1024):
|
||||
pbar.update(len(data))
|
||||
|
||||
def _iterate_content(self, response: requests.Response) -> None:
|
||||
@staticmethod
|
||||
def _iterate_content(response: requests.Response) -> None:
|
||||
"""
|
||||
Process the streamed HTTP response data.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue