ultralytics 8.2.97 robust HUB model downloads (#16347)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
7834d19776
commit
6dcc4a0610
4 changed files with 29 additions and 46 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.96"
|
||||
__version__ = "8.2.97"
|
||||
|
||||
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -712,6 +712,7 @@ def entrypoint(debug=""):
|
|||
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
"hub": lambda: handle_yolo_hub(args[1:]),
|
||||
"login": lambda: handle_yolo_hub(args),
|
||||
"logout": lambda: handle_yolo_hub(args),
|
||||
"copy-cfg": copy_default_cfg,
|
||||
"explorer": lambda: handle_explorer(args[1:]),
|
||||
"streamlit-predict": lambda: handle_streamlit_inference(),
|
||||
|
|
|
|||
|
|
@ -206,33 +206,21 @@ class Model(nn.Module):
|
|||
Check if the provided model is an Ultralytics HUB model.
|
||||
|
||||
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
||||
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
|
||||
or a standalone model ID.
|
||||
identifier.
|
||||
|
||||
Args:
|
||||
model (str): The model identifier to check. This can be a URL, an API key and model ID
|
||||
combination, or a standalone model ID.
|
||||
model (str): The model string to check.
|
||||
|
||||
Returns:
|
||||
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
|
||||
|
||||
Examples:
|
||||
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
|
||||
>>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
|
||||
True
|
||||
>>> Model.is_hub_model("api_key_example_model_id")
|
||||
True
|
||||
>>> Model.is_hub_model("example_model_id")
|
||||
True
|
||||
>>> Model.is_hub_model("not_a_hub_model.pt")
|
||||
>>> Model.is_hub_model("yolov8n.pt")
|
||||
False
|
||||
"""
|
||||
return any(
|
||||
(
|
||||
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
|
||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
|
||||
)
|
||||
)
|
||||
return model.startswith(f"{HUB_WEB_ROOT}/models/")
|
||||
|
||||
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import threading
|
|||
import time
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -77,7 +78,6 @@ class HUBTrainingSession:
|
|||
if not session.client.authenticated:
|
||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
|
||||
exit()
|
||||
return None
|
||||
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
|
||||
session.create_model(args)
|
||||
|
|
@ -96,7 +96,8 @@ class HUBTrainingSession:
|
|||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||
if self.model.is_trained():
|
||||
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
|
||||
self.model_file = self.model.get_weights_url("best")
|
||||
url = self.model.get_weights_url("best") # download URL with auth
|
||||
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
|
||||
return
|
||||
|
||||
# Set training args and start heartbeats for HUB to monitor agent
|
||||
|
|
@ -146,9 +147,8 @@ class HUBTrainingSession:
|
|||
Parses the given identifier to determine the type of identifier and extract relevant components.
|
||||
|
||||
The method supports different identifier formats:
|
||||
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
|
||||
- An identifier containing an API key and a model ID separated by an underscore
|
||||
- An identifier that is solely a model ID of a fixed length
|
||||
- A HUB model URL https://hub.ultralytics.com/models/MODEL
|
||||
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
|
||||
- A local filename that ends with '.pt' or '.yaml'
|
||||
|
||||
Args:
|
||||
|
|
@ -160,32 +160,26 @@ class HUBTrainingSession:
|
|||
Raises:
|
||||
HUBModelError: If the identifier format is not recognized.
|
||||
"""
|
||||
# Initialize variables
|
||||
api_key, model_id, filename = None, None, None
|
||||
|
||||
# Check if identifier is a HUB URL
|
||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
# Extract the model_id after the HUB_WEB_ROOT URL
|
||||
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
||||
# path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
||||
# parts = path.split("_")
|
||||
# if Path(path).suffix in {".pt", ".yaml"}:
|
||||
# filename = path
|
||||
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||
# api_key, model_id = parts
|
||||
# elif len(path) == 20:
|
||||
# model_id = path
|
||||
|
||||
if Path(identifier).suffix in {".pt", ".yaml"}:
|
||||
filename = identifier
|
||||
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||
parsed_url = urlparse(identifier)
|
||||
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
|
||||
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
|
||||
api_key = query_params.get("api_key", [None])[0]
|
||||
else:
|
||||
# Split the identifier based on underscores only if it's not a HUB URL
|
||||
parts = identifier.split("_")
|
||||
|
||||
# Check if identifier is in the format of API key and model ID
|
||||
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||
api_key, model_id = parts
|
||||
# Check if identifier is a single model ID
|
||||
elif len(parts) == 1 and len(parts[0]) == 20:
|
||||
model_id = parts[0]
|
||||
# Check if identifier is a local filename
|
||||
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
|
||||
filename = identifier
|
||||
else:
|
||||
raise HUBModelError(
|
||||
f"model='{identifier}' could not be parsed. Check format is correct. "
|
||||
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
|
||||
)
|
||||
|
||||
raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
|
||||
return api_key, model_id, filename
|
||||
|
||||
def _set_train_args(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue