Feature: Create HUB Models from CLI or Python Script (#7316)

Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Kalen Michael 2024-01-10 02:36:14 +01:00 committed by GitHub
parent a92adf8231
commit b54055a2c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 356 additions and 154 deletions

View file

@ -1,17 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import signal
import sys
import threading
import time
from http import HTTPStatus
from pathlib import Path
from time import sleep
import requests
from hub_sdk import HUB_WEB_ROOT, HUBClient
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')
class HUBTrainingSession:
@ -34,7 +35,7 @@ class HUBTrainingSession:
alive (bool): Indicates if the heartbeat loop is active.
"""
def __init__(self, url):
def __init__(self, identifier):
"""
Initialize the HUBTrainingSession with the provided model identifier.
@ -46,98 +47,251 @@ class HUBTrainingSession:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
"""
from ultralytics.hub.auth import Auth
self.rate_limits = {
'metrics': 3.0,
'ckpt': 900.0,
'heartbeat': 300.0, } # rate limits (seconds)
self.metrics_queue = {} # holds metrics for each epoch until upload
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
# Parse input
if url.startswith(f'{HUB_WEB_ROOT}/models/'):
url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
if [len(x) for x in url.split('_')] == [42, 20]:
key, model_id = url.split('_')
elif len(url) == 20:
key, model_id = '', url
else:
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
api_key, model_id, self.filename = self._parse_identifier(identifier)
# Authorize
auth = Auth(key)
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self.timers = {} # rate limit timers (seconds)
self.metrics_queue = {} # metrics queue
self.model = self._get_model()
self.alive = True
self._start_heartbeat() # start heartbeats
self._register_signal_handlers()
# Get credentials
active_key = api_key or SETTINGS.get('api_key')
credentials = {'api_key': active_key} if active_key else None # set credentials
# Initialize client
self.client = HUBClient(credentials)
if model_id:
self.load_model(model_id) # load existing model
else:
self.model = self.client.model() # load empty model
def load_model(self, model_id):
# Initialize model
self.model = self.client.model(model_id)
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
self._set_train_args()
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits['heartbeat'])
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def _register_signal_handlers(self):
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
def create_model(self, model_args):
# Initialize model
payload = {
'config': {
'batchSize': model_args.get('batch', -1),
'epochs': model_args.get('epochs', 300),
'imageSize': model_args.get('imgsz', 640),
'patience': model_args.get('patience', 100),
'device': model_args.get('device', ''),
'cache': model_args.get('cache', 'ram'), },
'dataset': {
'name': model_args.get('data')},
'lineage': {
'architecture': {
'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
'parent': {}, },
'meta': {
'name': self.filename}, }
def _handle_signal(self, signum, frame):
if self.filename.endswith('.pt'):
payload['lineage']['parent']['name'] = self.filename
self.model.create_model(payload)
# Model could not be created
# TODO: improve error handling
if not self.model.id:
return
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits['heartbeat'])
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def _parse_identifier(self, identifier):
"""
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
Parses the given identifier to determine the type of identifier and extract relevant components.
This method does not use frame, it is included as it is passed by signal.
The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml'
Args:
identifier (str): The identifier string to be parsed.
Returns:
(tuple): A tuple containing the API key, model ID, and filename as applicable.
Raises:
HUBModelError: If the identifier format is not recognized.
"""
if self.alive is True:
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
self._stop_heartbeat()
sys.exit(signum)
def _stop_heartbeat(self):
"""Terminate the heartbeat loop."""
self.alive = False
# Initialize variables
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
else:
# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split('_')
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
return api_key, model_id, filename
def _set_train_args(self, **kwargs):
if self.model.is_trained():
# Model is already trained
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
if self.model.is_resumable():
# Model has saved weights
self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
self.model_file = self.model.get_weights_url('last')
else:
# Model has no saved weights
def get_train_args(config):
return {
'batch': config['batchSize'],
'epochs': config['epochs'],
'imgsz': config['imageSize'],
'patience': config['patience'],
'device': config['device'],
'cache': config['cache'],
'data': self.model.get_dataset_url(), }
self.train_args = get_train_args(self.model.data.get('config'))
# Set the model file as either a *.pt or *.yaml file
self.model_file = (self.model.get_weights_url('parent')
if self.model.is_pretrained() else self.model.get_architecture())
if not self.train_args.get('data'):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id
def request_queue(
self,
request_func,
retry=3,
timeout=30,
thread=True,
verbose=True,
progress_total=None,
*args,
**kwargs,
):
def retry_request():
t0 = time.time() # Record the start time for the timeout
for i in range(retry + 1):
if (time.time() - t0) > timeout:
LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
break # Timeout reached, exit loop
response = request_func(*args, **kwargs)
if progress_total:
self._show_upload_progress(progress_total, response)
if response is None:
LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
time.sleep(2 ** i) # Exponential backoff before retrying
continue # Skip further processing and retry
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
return response # Success, no need to retry
if i == 0:
# Initial attempt, check status code and provide messages
message = self._get_failure_message(response, retry, timeout)
if verbose:
LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')
if not self._should_retry(response.status_code):
LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}')
break # Not an error that should be retried, exit loop
time.sleep(2 ** i) # Exponential backoff for retries
return response
if thread:
# Start a new thread to run the retry_request function
threading.Thread(target=retry_request, daemon=True).start()
else:
# If running in the main thread, call retry_request directly
return retry_request()
def _should_retry(self, status_code):
# Status codes that trigger retries
retry_codes = {
HTTPStatus.REQUEST_TIMEOUT,
HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT, }
return True if status_code in retry_codes else False
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
"""
Generate a retry message based on the response status code.
Args:
response: The HTTP response object.
retry: The number of retry attempts allowed.
timeout: The maximum timeout duration.
Returns:
str: The retry message.
"""
if self._should_retry(response.status_code):
return f'Retrying {retry}x for {timeout}s.' if retry else ''
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
headers = response.headers
return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
f"Please retry after {headers['Retry-After']}s.")
else:
try:
return response.json().get('message', 'No JSON message.')
except AttributeError:
return 'Unable to read JSON.'
def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
def _get_model(self):
"""Fetch and return model data from Ultralytics HUB."""
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
try:
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
data = response.json().get('data', None)
if data.get('status', None) == 'trained':
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
if not data.get('data', None):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_id = data['id']
if data['status'] == 'new': # new model to start training
self.train_args = {
'batch': data['batch_size'], # note HUB argument is slightly different
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
'device': data['device'],
'cache': data['cache'],
'data': data['data']}
self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
elif data['status'] == 'training': # existing model to resume training
self.train_args = {'data': data['data'], 'resume': True}
self.model_file = data['resume']
return data
except requests.exceptions.ConnectionError as e:
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
except Exception:
raise
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
def upload_model(
self,
epoch: int,
weights: str,
is_best: bool = False,
map: float = 0.0,
final: bool = False,
) -> None:
"""
Upload a model checkpoint to Ultralytics HUB.
@ -149,43 +303,33 @@ class HUBTrainingSession:
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
progress_total = (Path(weights).stat().st_size if final else None) # Only show progress if final
self.request_queue(
self.model.upload_model,
epoch=epoch,
weights=weights,
is_best=is_best,
map=map,
final=final,
retry=10,
timeout=3600,
thread=not final,
progress_total=progress_total,
)
else:
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
file = None
url = f'{self.api_url}/upload'
# url = 'http://httpbin.org/post' # for debug
data = {'epoch': epoch}
if final:
data.update({'type': 'final', 'map': map})
filesize = Path(weights).stat().st_size
smart_request('post',
url,
data=data,
files={'best.pt': file},
headers=self.auth_header,
retry=10,
timeout=3600,
thread=False,
progress=filesize,
code=4)
else:
data.update({'type': 'epoch', 'isBest': bool(is_best)})
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
@threaded
def _start_heartbeat(self):
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
while self.alive:
r = smart_request('post',
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False) # already in a thread
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self.rate_limits['heartbeat'])
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
"""
Display a progress bar to track the upload progress of a file download.
Args:
content_length (int): The total size of the content to be downloaded in bytes.
response (requests.Response): The response object from the file download request.
Returns:
(None)
"""
with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))