ultralytics 8.2.97 robust HUB model downloads (#16347)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
7834d19776
commit
6dcc4a0610
4 changed files with 29 additions and 46 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.96"
|
__version__ = "8.2.97"
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -712,6 +712,7 @@ def entrypoint(debug=""):
|
||||||
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
|
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||||
"hub": lambda: handle_yolo_hub(args[1:]),
|
"hub": lambda: handle_yolo_hub(args[1:]),
|
||||||
"login": lambda: handle_yolo_hub(args),
|
"login": lambda: handle_yolo_hub(args),
|
||||||
|
"logout": lambda: handle_yolo_hub(args),
|
||||||
"copy-cfg": copy_default_cfg,
|
"copy-cfg": copy_default_cfg,
|
||||||
"explorer": lambda: handle_explorer(args[1:]),
|
"explorer": lambda: handle_explorer(args[1:]),
|
||||||
"streamlit-predict": lambda: handle_streamlit_inference(),
|
"streamlit-predict": lambda: handle_streamlit_inference(),
|
||||||
|
|
|
||||||
|
|
@ -206,33 +206,21 @@ class Model(nn.Module):
|
||||||
Check if the provided model is an Ultralytics HUB model.
|
Check if the provided model is an Ultralytics HUB model.
|
||||||
|
|
||||||
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
||||||
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
|
identifier.
|
||||||
or a standalone model ID.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str): The model identifier to check. This can be a URL, an API key and model ID
|
model (str): The model string to check.
|
||||||
combination, or a standalone model ID.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
|
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
|
>>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
|
||||||
True
|
True
|
||||||
>>> Model.is_hub_model("api_key_example_model_id")
|
>>> Model.is_hub_model("yolov8n.pt")
|
||||||
True
|
|
||||||
>>> Model.is_hub_model("example_model_id")
|
|
||||||
True
|
|
||||||
>>> Model.is_hub_model("not_a_hub_model.pt")
|
|
||||||
False
|
False
|
||||||
"""
|
"""
|
||||||
return any(
|
return model.startswith(f"{HUB_WEB_ROOT}/models/")
|
||||||
(
|
|
||||||
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
|
||||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
|
|
||||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
@ -77,7 +78,6 @@ class HUBTrainingSession:
|
||||||
if not session.client.authenticated:
|
if not session.client.authenticated:
|
||||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
|
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
|
||||||
exit()
|
|
||||||
return None
|
return None
|
||||||
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
|
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
|
||||||
session.create_model(args)
|
session.create_model(args)
|
||||||
|
|
@ -96,7 +96,8 @@ class HUBTrainingSession:
|
||||||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||||
if self.model.is_trained():
|
if self.model.is_trained():
|
||||||
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
|
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
|
||||||
self.model_file = self.model.get_weights_url("best")
|
url = self.model.get_weights_url("best") # download URL with auth
|
||||||
|
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Set training args and start heartbeats for HUB to monitor agent
|
# Set training args and start heartbeats for HUB to monitor agent
|
||||||
|
|
@ -146,9 +147,8 @@ class HUBTrainingSession:
|
||||||
Parses the given identifier to determine the type of identifier and extract relevant components.
|
Parses the given identifier to determine the type of identifier and extract relevant components.
|
||||||
|
|
||||||
The method supports different identifier formats:
|
The method supports different identifier formats:
|
||||||
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
|
- A HUB model URL https://hub.ultralytics.com/models/MODEL
|
||||||
- An identifier containing an API key and a model ID separated by an underscore
|
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
|
||||||
- An identifier that is solely a model ID of a fixed length
|
|
||||||
- A local filename that ends with '.pt' or '.yaml'
|
- A local filename that ends with '.pt' or '.yaml'
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -160,32 +160,26 @@ class HUBTrainingSession:
|
||||||
Raises:
|
Raises:
|
||||||
HUBModelError: If the identifier format is not recognized.
|
HUBModelError: If the identifier format is not recognized.
|
||||||
"""
|
"""
|
||||||
# Initialize variables
|
|
||||||
api_key, model_id, filename = None, None, None
|
api_key, model_id, filename = None, None, None
|
||||||
|
|
||||||
# Check if identifier is a HUB URL
|
# path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
||||||
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
# parts = path.split("_")
|
||||||
# Extract the model_id after the HUB_WEB_ROOT URL
|
# if Path(path).suffix in {".pt", ".yaml"}:
|
||||||
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
# filename = path
|
||||||
|
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||||
|
# api_key, model_id = parts
|
||||||
|
# elif len(path) == 20:
|
||||||
|
# model_id = path
|
||||||
|
|
||||||
|
if Path(identifier).suffix in {".pt", ".yaml"}:
|
||||||
|
filename = identifier
|
||||||
|
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
||||||
|
parsed_url = urlparse(identifier)
|
||||||
|
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
|
||||||
|
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
|
||||||
|
api_key = query_params.get("api_key", [None])[0]
|
||||||
else:
|
else:
|
||||||
# Split the identifier based on underscores only if it's not a HUB URL
|
raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
|
||||||
parts = identifier.split("_")
|
|
||||||
|
|
||||||
# Check if identifier is in the format of API key and model ID
|
|
||||||
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
|
||||||
api_key, model_id = parts
|
|
||||||
# Check if identifier is a single model ID
|
|
||||||
elif len(parts) == 1 and len(parts[0]) == 20:
|
|
||||||
model_id = parts[0]
|
|
||||||
# Check if identifier is a local filename
|
|
||||||
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
|
|
||||||
filename = identifier
|
|
||||||
else:
|
|
||||||
raise HUBModelError(
|
|
||||||
f"model='{identifier}' could not be parsed. Check format is correct. "
|
|
||||||
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
|
|
||||||
)
|
|
||||||
|
|
||||||
return api_key, model_id, filename
|
return api_key, model_id, filename
|
||||||
|
|
||||||
def _set_train_args(self):
|
def _set_train_args(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue