ultralytics 8.2.97 robust HUB model downloads (#16347)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
7834d19776
commit
6dcc4a0610
4 changed files with 29 additions and 46 deletions
|
|
@ -5,6 +5,7 @@ import threading
|
|||
import time
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -77,7 +78,6 @@ class HUBTrainingSession:
|
|||
if not session.client.authenticated:
|
||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
|
||||
exit()
|
||||
return None
|
||||
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
|
||||
session.create_model(args)
|
||||
|
|
@ -96,7 +96,8 @@ class HUBTrainingSession:
|
|||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||
if self.model.is_trained():
|
||||
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
|
||||
self.model_file = self.model.get_weights_url("best")
|
||||
url = self.model.get_weights_url("best") # download URL with auth
|
||||
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
|
||||
return
|
||||
|
||||
# Set training args and start heartbeats for HUB to monitor agent
|
||||
|
|
@ -146,9 +147,8 @@ class HUBTrainingSession:
|
|||
Parses the given identifier to determine the type of identifier and extract relevant components.
|
||||
|
||||
The method supports different identifier formats:
|
||||
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
|
||||
- An identifier containing an API key and a model ID separated by an underscore
|
||||
- An identifier that is solely a model ID of a fixed length
|
||||
- A HUB model URL https://hub.ultralytics.com/models/MODEL
|
||||
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
|
||||
- A local filename that ends with '.pt' or '.yaml'
|
||||
|
||||
Args:
|
||||
|
|
@ -160,32 +160,26 @@ class HUBTrainingSession:
|
|||
Raises:
|
||||
HUBModelError: If the identifier format is not recognized.
|
||||
"""
|
||||
# Initialize variables
|
||||
api_key, model_id, filename = None, None, None
|
||||
|
||||
# Check if identifier is a HUB URL
|
||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
# Extract the model_id after the HUB_WEB_ROOT URL
|
||||
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
||||
# path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
||||
# parts = path.split("_")
|
||||
# if Path(path).suffix in {".pt", ".yaml"}:
|
||||
# filename = path
|
||||
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||
# api_key, model_id = parts
|
||||
# elif len(path) == 20:
|
||||
# model_id = path
|
||||
|
||||
if Path(identifier).suffix in {".pt", ".yaml"}:
|
||||
filename = identifier
|
||||
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
parsed_url = urlparse(identifier)
|
||||
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
|
||||
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
|
||||
api_key = query_params.get("api_key", [None])[0]
|
||||
else:
|
||||
# Split the identifier based on underscores only if it's not a HUB URL
|
||||
parts = identifier.split("_")
|
||||
|
||||
# Check if identifier is in the format of API key and model ID
|
||||
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||
api_key, model_id = parts
|
||||
# Check if identifier is a single model ID
|
||||
elif len(parts) == 1 and len(parts[0]) == 20:
|
||||
model_id = parts[0]
|
||||
# Check if identifier is a local filename
|
||||
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
|
||||
filename = identifier
|
||||
else:
|
||||
raise HUBModelError(
|
||||
f"model='{identifier}' could not be parsed. Check format is correct. "
|
||||
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
|
||||
)
|
||||
|
||||
raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
|
||||
return api_key, model_id, filename
|
||||
|
||||
def _set_train_args(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue