HUB setup (#108)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c6eb6720de
commit
2bc9a5c87e
16 changed files with 631 additions and 122 deletions
121
ultralytics/hub/session.py
Normal file
121
ultralytics/hub/session.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
|
||||
import requests
|
||||
|
||||
from ultralytics import __version__
|
||||
from ultralytics.hub.config import HUB_API_ROOT
|
||||
from ultralytics.hub.utils import check_dataset_disk_space, smart_request
|
||||
from ultralytics.yolo.utils import LOGGER, is_colab, threaded
|
||||
|
||||
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
||||
|
||||
session = None
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
""" Confirm exit """
|
||||
global hub_logger
|
||||
LOGGER.info(f'Signal received. {signum} {frame}')
|
||||
if isinstance(session, HubTrainingSession):
|
||||
hub_logger.alive = False
|
||||
del hub_logger
|
||||
sys.exit(signum)
|
||||
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
class HubTrainingSession:
|
||||
|
||||
def __init__(self, model_id, auth):
|
||||
self.agent_id = None # identifies which instance is communicating with server
|
||||
self.model_id = model_id
|
||||
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
||||
self.auth_header = auth.get_auth_header()
|
||||
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
||||
self.t = {} # rate limit timers (seconds)
|
||||
self.metrics_queue = {} # metrics queue
|
||||
self.alive = True # for heartbeats
|
||||
self.model = self._get_model()
|
||||
self._heartbeats() # start heartbeats
|
||||
|
||||
def __del__(self):
|
||||
# Class destructor
|
||||
self.alive = False
|
||||
|
||||
def upload_metrics(self):
|
||||
payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"}
|
||||
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
|
||||
|
||||
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
||||
# Upload a model to HUB
|
||||
file = None
|
||||
if Path(weights).is_file():
|
||||
with open(weights, "rb") as f:
|
||||
file = f.read()
|
||||
if final:
|
||||
smart_request(f'{self.api_url}/upload',
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"type": "final",
|
||||
"map": map},
|
||||
files={"best.pt": file},
|
||||
headers=self.auth_header,
|
||||
retry=10,
|
||||
timeout=3600,
|
||||
code=4)
|
||||
else:
|
||||
smart_request(f'{self.api_url}/upload',
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"type": "epoch",
|
||||
"isBest": bool(is_best)},
|
||||
headers=self.auth_header,
|
||||
files={"last.pt": file},
|
||||
code=3)
|
||||
|
||||
def _get_model(self):
|
||||
# Returns model from database by id
|
||||
api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
|
||||
headers = self.auth_header
|
||||
|
||||
try:
|
||||
r = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
|
||||
data = r.json().get("data", None)
|
||||
if not data:
|
||||
return
|
||||
assert data['data'], 'ERROR: Dataset may still be processing. Please wait a minute and try again.' # RF fix
|
||||
self.model_id = data["id"]
|
||||
|
||||
return data
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
|
||||
|
||||
def check_disk_space(self):
|
||||
if not check_dataset_disk_space(self.model['data']):
|
||||
raise MemoryError("Not enough disk space")
|
||||
|
||||
# COMMENT: Should not be needed as HUB is now considered an integration and is in integrations_callbacks
|
||||
# import ultralytics.yolo.utils.callbacks.hub as hub_callbacks
|
||||
# @staticmethod
|
||||
# def register_callbacks(trainer):
|
||||
# for k, v in hub_callbacks.callbacks.items():
|
||||
# trainer.add_callback(k, v)
|
||||
|
||||
@threaded
|
||||
def _heartbeats(self):
|
||||
while self.alive:
|
||||
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
||||
json={
|
||||
"agent": AGENT_NAME,
|
||||
"agentId": self.agent_id},
|
||||
headers=self.auth_header,
|
||||
retry=0,
|
||||
code=5,
|
||||
thread=False)
|
||||
self.agent_id = r.json().get('data', {}).get('agentId', None)
|
||||
sleep(self.rate_limits['heartbeat'])
|
||||
Loading…
Add table
Add a link
Reference in a new issue