8.0.60 new HUB training syntax (#1753)

Co-authored-by: Rafael Pierre <97888102+rafaelvp-db@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Semih Demirel <85176438+semihhdemirel@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-03 02:36:58 +02:00 committed by GitHub
parent e7876e1ba9
commit 84948651cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 405 additions and 122 deletions

View file

@ -6,17 +6,62 @@ from time import sleep
import requests
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
def __init__(self, model_id, auth):
Args:
url (str): Model identifier used to initialize the HUB training session.
Attributes:
agent_id (str): Identifier for the instance communicating with the server.
model_id (str): Identifier for the YOLOv5 model being trained.
model_url (str): URL for the model in Ultralytics HUB.
api_url (str): API URL for the model in Ultralytics HUB.
auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
rate_limits (Dict): Rate limits for different API calls (in seconds).
timers (Dict): Timers for rate limiting.
metrics_queue (Dict): Queue for the model's metrics.
model (Dict): Model data fetched from Ultralytics HUB.
alive (bool): Indicates if the heartbeat loop is active.
"""
def __init__(self, url):
"""
Initialize the HUBTrainingSession with the provided model identifier.
Args:
url (str): Model identifier used to initialize the HUB training session.
It can be a URL string or a model key with specific format.
Raises:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
"""
from ultralytics.hub.auth import Auth
# Parse input
if url.startswith('https://hub.ultralytics.com/models/'):
url = url.split('https://hub.ultralytics.com/models/')[-1]
if [len(x) for x in url.split('_')] == [42, 20]:
key, model_id = url.split('_')
elif len(url) == 20:
key, model_id = '', url
else:
raise ValueError(f'Invalid HUBTrainingSession input: {url}')
# Authorize
auth = Auth(key)
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
@ -26,16 +71,17 @@ class HUBTrainingSession:
self.alive = True
self._start_heartbeat() # start heartbeats
self._register_signal_handlers()
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def _register_signal_handlers(self):
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
def _handle_signal(self, signum, frame):
"""
Prevent heartbeats from being sent on Colab after kill.
This method does not use frame, it is included as it is
passed by signal.
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
This method does not use frame, it is included as it is passed by signal.
"""
if self.alive is True:
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
@ -43,15 +89,16 @@ class HUBTrainingSession:
sys.exit(signum)
def _stop_heartbeat(self):
"""End the heartbeat loop"""
"""Terminate the heartbeat loop."""
self.alive = False
def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
def _get_model(self):
# Returns model from database by id
"""Fetch and return model data from Ultralytics HUB."""
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
try:
@ -59,9 +106,7 @@ class HUBTrainingSession:
data = response.json().get('data', None)
if data.get('status', None) == 'trained':
raise ValueError(
emojis(f'Model is already trained and uploaded to '
f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
if not data.get('data', None):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
@ -88,11 +133,21 @@ class HUBTrainingSession:
raise
def check_disk_space(self):
if not check_dataset_disk_space(self.model['data']):
"""Check if there is enough disk space for the dataset."""
if not check_dataset_disk_space(url=self.model['data']):
raise MemoryError('Not enough disk space')
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
"""
Upload a model checkpoint to Ultralytics HUB.
Args:
epoch (int): The current training epoch.
weights (str): Path to the model weights file.
is_best (bool): Indicates if the current model is the best one so far.
map (float): Mean average precision of the model.
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
@ -120,6 +175,7 @@ class HUBTrainingSession:
@threaded
def _start_heartbeat(self):
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
while self.alive:
r = smart_request('post',
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',