ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
parent
e795277391
commit
fe27db2f6e
139 changed files with 6870 additions and 5125 deletions
|
|
@ -12,16 +12,13 @@ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
|
|||
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
|
||||
from ultralytics.utils.errors import HUBModelError
|
||||
|
||||
AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')
|
||||
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
|
||||
|
||||
|
||||
class HUBTrainingSession:
|
||||
"""
|
||||
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
||||
|
||||
Args:
|
||||
url (str): Model identifier used to initialize the HUB training session.
|
||||
|
||||
Attributes:
|
||||
agent_id (str): Identifier for the instance communicating with the server.
|
||||
model_id (str): Identifier for the YOLO model being trained.
|
||||
|
|
@ -40,17 +37,18 @@ class HUBTrainingSession:
|
|||
Initialize the HUBTrainingSession with the provided model identifier.
|
||||
|
||||
Args:
|
||||
url (str): Model identifier used to initialize the HUB training session.
|
||||
It can be a URL string or a model key with specific format.
|
||||
identifier (str): Model identifier used to initialize the HUB training session.
|
||||
It can be a URL string or a model key with specific format.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided model identifier is invalid.
|
||||
ConnectionError: If connecting with global API key is not supported.
|
||||
"""
|
||||
self.rate_limits = {
|
||||
'metrics': 3.0,
|
||||
'ckpt': 900.0,
|
||||
'heartbeat': 300.0, } # rate limits (seconds)
|
||||
"metrics": 3.0,
|
||||
"ckpt": 900.0,
|
||||
"heartbeat": 300.0,
|
||||
} # rate limits (seconds)
|
||||
self.metrics_queue = {} # holds metrics for each epoch until upload
|
||||
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
||||
|
||||
|
|
@ -58,8 +56,8 @@ class HUBTrainingSession:
|
|||
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
||||
|
||||
# Get credentials
|
||||
active_key = api_key or SETTINGS.get('api_key')
|
||||
credentials = {'api_key': active_key} if active_key else None # set credentials
|
||||
active_key = api_key or SETTINGS.get("api_key")
|
||||
credentials = {"api_key": active_key} if active_key else None # set credentials
|
||||
|
||||
# Initialize client
|
||||
self.client = HUBClient(credentials)
|
||||
|
|
@ -72,35 +70,37 @@ class HUBTrainingSession:
|
|||
def load_model(self, model_id):
|
||||
# Initialize model
|
||||
self.model = self.client.model(model_id)
|
||||
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
|
||||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||
|
||||
self._set_train_args()
|
||||
|
||||
# Start heartbeats for HUB to monitor agent
|
||||
self.model.start_heartbeat(self.rate_limits['heartbeat'])
|
||||
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
||||
self.model.start_heartbeat(self.rate_limits["heartbeat"])
|
||||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||
|
||||
def create_model(self, model_args):
|
||||
# Initialize model
|
||||
payload = {
|
||||
'config': {
|
||||
'batchSize': model_args.get('batch', -1),
|
||||
'epochs': model_args.get('epochs', 300),
|
||||
'imageSize': model_args.get('imgsz', 640),
|
||||
'patience': model_args.get('patience', 100),
|
||||
'device': model_args.get('device', ''),
|
||||
'cache': model_args.get('cache', 'ram'), },
|
||||
'dataset': {
|
||||
'name': model_args.get('data')},
|
||||
'lineage': {
|
||||
'architecture': {
|
||||
'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
|
||||
'parent': {}, },
|
||||
'meta': {
|
||||
'name': self.filename}, }
|
||||
"config": {
|
||||
"batchSize": model_args.get("batch", -1),
|
||||
"epochs": model_args.get("epochs", 300),
|
||||
"imageSize": model_args.get("imgsz", 640),
|
||||
"patience": model_args.get("patience", 100),
|
||||
"device": model_args.get("device", ""),
|
||||
"cache": model_args.get("cache", "ram"),
|
||||
},
|
||||
"dataset": {"name": model_args.get("data")},
|
||||
"lineage": {
|
||||
"architecture": {
|
||||
"name": self.filename.replace(".pt", "").replace(".yaml", ""),
|
||||
},
|
||||
"parent": {},
|
||||
},
|
||||
"meta": {"name": self.filename},
|
||||
}
|
||||
|
||||
if self.filename.endswith('.pt'):
|
||||
payload['lineage']['parent']['name'] = self.filename
|
||||
if self.filename.endswith(".pt"):
|
||||
payload["lineage"]["parent"]["name"] = self.filename
|
||||
|
||||
self.model.create_model(payload)
|
||||
|
||||
|
|
@ -109,12 +109,12 @@ class HUBTrainingSession:
|
|||
if not self.model.id:
|
||||
return
|
||||
|
||||
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
|
||||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||
|
||||
# Start heartbeats for HUB to monitor agent
|
||||
self.model.start_heartbeat(self.rate_limits['heartbeat'])
|
||||
self.model.start_heartbeat(self.rate_limits["heartbeat"])
|
||||
|
||||
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
||||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||
|
||||
def _parse_identifier(self, identifier):
|
||||
"""
|
||||
|
|
@ -125,13 +125,13 @@ class HUBTrainingSession:
|
|||
- An identifier containing an API key and a model ID separated by an underscore
|
||||
- An identifier that is solely a model ID of a fixed length
|
||||
- A local filename that ends with '.pt' or '.yaml'
|
||||
|
||||
|
||||
Args:
|
||||
identifier (str): The identifier string to be parsed.
|
||||
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the API key, model ID, and filename as applicable.
|
||||
|
||||
|
||||
Raises:
|
||||
HUBModelError: If the identifier format is not recognized.
|
||||
"""
|
||||
|
|
@ -140,12 +140,12 @@ class HUBTrainingSession:
|
|||
api_key, model_id, filename = None, None, None
|
||||
|
||||
# Check if identifier is a HUB URL
|
||||
if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
|
||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
# Extract the model_id after the HUB_WEB_ROOT URL
|
||||
model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
|
||||
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
||||
else:
|
||||
# Split the identifier based on underscores only if it's not a HUB URL
|
||||
parts = identifier.split('_')
|
||||
parts = identifier.split("_")
|
||||
|
||||
# Check if identifier is in the format of API key and model ID
|
||||
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||
|
|
@ -154,43 +154,46 @@ class HUBTrainingSession:
|
|||
elif len(parts) == 1 and len(parts[0]) == 20:
|
||||
model_id = parts[0]
|
||||
# Check if identifier is a local filename
|
||||
elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
|
||||
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
|
||||
filename = identifier
|
||||
else:
|
||||
raise HUBModelError(
|
||||
f"model='{identifier}' could not be parsed. Check format is correct. "
|
||||
f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
|
||||
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
|
||||
)
|
||||
|
||||
return api_key, model_id, filename
|
||||
|
||||
def _set_train_args(self, **kwargs):
|
||||
if self.model.is_trained():
|
||||
# Model is already trained
|
||||
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
|
||||
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
|
||||
|
||||
if self.model.is_resumable():
|
||||
# Model has saved weights
|
||||
self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
|
||||
self.model_file = self.model.get_weights_url('last')
|
||||
self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
|
||||
self.model_file = self.model.get_weights_url("last")
|
||||
else:
|
||||
# Model has no saved weights
|
||||
def get_train_args(config):
|
||||
return {
|
||||
'batch': config['batchSize'],
|
||||
'epochs': config['epochs'],
|
||||
'imgsz': config['imageSize'],
|
||||
'patience': config['patience'],
|
||||
'device': config['device'],
|
||||
'cache': config['cache'],
|
||||
'data': self.model.get_dataset_url(), }
|
||||
"batch": config["batchSize"],
|
||||
"epochs": config["epochs"],
|
||||
"imgsz": config["imageSize"],
|
||||
"patience": config["patience"],
|
||||
"device": config["device"],
|
||||
"cache": config["cache"],
|
||||
"data": self.model.get_dataset_url(),
|
||||
}
|
||||
|
||||
self.train_args = get_train_args(self.model.data.get('config'))
|
||||
self.train_args = get_train_args(self.model.data.get("config"))
|
||||
# Set the model file as either a *.pt or *.yaml file
|
||||
self.model_file = (self.model.get_weights_url('parent')
|
||||
if self.model.is_pretrained() else self.model.get_architecture())
|
||||
self.model_file = (
|
||||
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
|
||||
)
|
||||
|
||||
if not self.train_args.get('data'):
|
||||
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
||||
if not self.train_args.get("data"):
|
||||
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
|
||||
|
||||
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
|
||||
self.model_id = self.model.id
|
||||
|
|
@ -206,12 +209,11 @@ class HUBTrainingSession:
|
|||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
def retry_request():
|
||||
t0 = time.time() # Record the start time for the timeout
|
||||
for i in range(retry + 1):
|
||||
if (time.time() - t0) > timeout:
|
||||
LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
|
||||
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
|
||||
break # Timeout reached, exit loop
|
||||
|
||||
response = request_func(*args, **kwargs)
|
||||
|
|
@ -219,8 +221,8 @@ class HUBTrainingSession:
|
|||
self._show_upload_progress(progress_total, response)
|
||||
|
||||
if response is None:
|
||||
LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
|
||||
time.sleep(2 ** i) # Exponential backoff before retrying
|
||||
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
|
||||
time.sleep(2**i) # Exponential backoff before retrying
|
||||
continue # Skip further processing and retry
|
||||
|
||||
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
||||
|
|
@ -231,13 +233,13 @@ class HUBTrainingSession:
|
|||
message = self._get_failure_message(response, retry, timeout)
|
||||
|
||||
if verbose:
|
||||
LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')
|
||||
LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
|
||||
|
||||
if not self._should_retry(response.status_code):
|
||||
LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}')
|
||||
LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
|
||||
break # Not an error that should be retried, exit loop
|
||||
|
||||
time.sleep(2 ** i) # Exponential backoff for retries
|
||||
time.sleep(2**i) # Exponential backoff for retries
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -253,7 +255,8 @@ class HUBTrainingSession:
|
|||
retry_codes = {
|
||||
HTTPStatus.REQUEST_TIMEOUT,
|
||||
HTTPStatus.BAD_GATEWAY,
|
||||
HTTPStatus.GATEWAY_TIMEOUT, }
|
||||
HTTPStatus.GATEWAY_TIMEOUT,
|
||||
}
|
||||
return True if status_code in retry_codes else False
|
||||
|
||||
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
|
||||
|
|
@ -269,16 +272,18 @@ class HUBTrainingSession:
|
|||
str: The retry message.
|
||||
"""
|
||||
if self._should_retry(response.status_code):
|
||||
return f'Retrying {retry}x for {timeout}s.' if retry else ''
|
||||
return f"Retrying {retry}x for {timeout}s." if retry else ""
|
||||
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
|
||||
headers = response.headers
|
||||
return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
|
||||
f"Please retry after {headers['Retry-After']}s.")
|
||||
return (
|
||||
f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
|
||||
f"Please retry after {headers['Retry-After']}s."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return response.json().get('message', 'No JSON message.')
|
||||
return response.json().get("message", "No JSON message.")
|
||||
except AttributeError:
|
||||
return 'Unable to read JSON.'
|
||||
return "Unable to read JSON."
|
||||
|
||||
def upload_metrics(self):
|
||||
"""Upload model metrics to Ultralytics HUB."""
|
||||
|
|
@ -303,7 +308,7 @@ class HUBTrainingSession:
|
|||
final (bool): Indicates if the model is the final model after training.
|
||||
"""
|
||||
if Path(weights).is_file():
|
||||
progress_total = (Path(weights).stat().st_size if final else None) # Only show progress if final
|
||||
progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
|
||||
self.request_queue(
|
||||
self.model.upload_model,
|
||||
epoch=epoch,
|
||||
|
|
@ -317,7 +322,7 @@ class HUBTrainingSession:
|
|||
progress_total=progress_total,
|
||||
)
|
||||
else:
|
||||
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
|
||||
|
||||
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
|
||||
"""
|
||||
|
|
@ -330,6 +335,6 @@ class HUBTrainingSession:
|
|||
Returns:
|
||||
(None)
|
||||
"""
|
||||
with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
|
||||
with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
pbar.update(len(data))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue