ultralytics 8.2.97 robust HUB model downloads (#16347)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-09-19 01:19:22 +02:00 committed by GitHub
parent 7834d19776
commit 6dcc4a0610
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 29 additions and 46 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.96" __version__ = "8.2.97"
import os import os

View file

@ -712,6 +712,7 @@ def entrypoint(debug=""):
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH), "cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
"hub": lambda: handle_yolo_hub(args[1:]), "hub": lambda: handle_yolo_hub(args[1:]),
"login": lambda: handle_yolo_hub(args), "login": lambda: handle_yolo_hub(args),
"logout": lambda: handle_yolo_hub(args),
"copy-cfg": copy_default_cfg, "copy-cfg": copy_default_cfg,
"explorer": lambda: handle_explorer(args[1:]), "explorer": lambda: handle_explorer(args[1:]),
"streamlit-predict": lambda: handle_streamlit_inference(), "streamlit-predict": lambda: handle_streamlit_inference(),

View file

@ -206,33 +206,21 @@ class Model(nn.Module):
Check if the provided model is an Ultralytics HUB model. Check if the provided model is an Ultralytics HUB model.
This static method determines whether the given model string represents a valid Ultralytics HUB model This static method determines whether the given model string represents a valid Ultralytics HUB model
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination, identifier.
or a standalone model ID.
Args: Args:
model (str): The model identifier to check. This can be a URL, an API key and model ID model (str): The model string to check.
combination, or a standalone model ID.
Returns: Returns:
(bool): True if the model is a valid Ultralytics HUB model, False otherwise. (bool): True if the model is a valid Ultralytics HUB model, False otherwise.
Examples: Examples:
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model") >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
True True
>>> Model.is_hub_model("api_key_example_model_id") >>> Model.is_hub_model("yolov8n.pt")
True
>>> Model.is_hub_model("example_model_id")
True
>>> Model.is_hub_model("not_a_hub_model.pt")
False False
""" """
return any( return model.startswith(f"{HUB_WEB_ROOT}/models/")
(
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
)
)
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
""" """

View file

@ -5,6 +5,7 @@ import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from urllib.parse import parse_qs, urlparse
import requests import requests
@ -77,7 +78,6 @@ class HUBTrainingSession:
if not session.client.authenticated: if not session.client.authenticated:
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.") LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
exit()
return None return None
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
session.create_model(args) session.create_model(args)
@ -96,7 +96,8 @@ class HUBTrainingSession:
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
if self.model.is_trained(): if self.model.is_trained():
print(emojis(f"Loading trained HUB model {self.model_url} 🚀")) print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
self.model_file = self.model.get_weights_url("best") url = self.model.get_weights_url("best") # download URL with auth
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
return return
# Set training args and start heartbeats for HUB to monitor agent # Set training args and start heartbeats for HUB to monitor agent
@ -146,9 +147,8 @@ class HUBTrainingSession:
Parses the given identifier to determine the type of identifier and extract relevant components. Parses the given identifier to determine the type of identifier and extract relevant components.
The method supports different identifier formats: The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/' - A HUB model URL https://hub.ultralytics.com/models/MODEL
- An identifier containing an API key and a model ID separated by an underscore - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml' - A local filename that ends with '.pt' or '.yaml'
Args: Args:
@ -160,32 +160,26 @@ class HUBTrainingSession:
Raises: Raises:
HUBModelError: If the identifier format is not recognized. HUBModelError: If the identifier format is not recognized.
""" """
# Initialize variables
api_key, model_id, filename = None, None, None api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL # path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # parts = path.split("_")
# Extract the model_id after the HUB_WEB_ROOT URL # if Path(path).suffix in {".pt", ".yaml"}:
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] # filename = path
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
# api_key, model_id = parts
# elif len(path) == 20:
# model_id = path
if Path(identifier).suffix in {".pt", ".yaml"}:
filename = identifier
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
parsed_url = urlparse(identifier)
model_id = Path(parsed_url.path).stem # handle a possible trailing slash in the URL path robustly
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
api_key = query_params.get("api_key", [None])[0]
else: else:
# Split the identifier based on underscores only if it's not a HUB URL raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
return api_key, model_id, filename return api_key, model_id, filename
def _set_train_args(self): def _set_train_args(self):