ultralytics 8.2.66 HUB model autodownload (#14702)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-07-26 00:31:41 +02:00 committed by GitHub
parent 1d5d105c62
commit 9130399974
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 14 additions and 11 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.65" __version__ = "8.2.66"
import os import os

View file

@ -17,6 +17,7 @@ from ultralytics.utils import (
DEFAULT_CFG_DICT, DEFAULT_CFG_DICT,
LOGGER, LOGGER,
RANK, RANK,
SETTINGS,
callbacks, callbacks,
checks, checks,
emojis, emojis,
@ -286,7 +287,7 @@ class Model(nn.Module):
>>> model._load('path/to/weights.pth', task='detect') >>> model._load('path/to/weights.pth', task='detect')
""" """
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
weights = checks.check_file(weights) # automatically download and return local filename weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(weights).suffix == ".pt": if Path(weights).suffix == ".pt":

View file

@ -507,7 +507,7 @@ class BaseTrainer:
self.last.write_bytes(serialized_ckpt) # save last.pt self.last.write_bytes(serialized_ckpt) # save last.pt
if self.best_fitness == self.fitness: if self.best_fitness == self.fitness:
self.best.write_bytes(serialized_ckpt) # save best.pt self.best.write_bytes(serialized_ckpt) # save best.pt
if (self.save_period > 0) and (self.epoch >= 0) and (self.epoch % self.save_period == 0): if (self.save_period > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
def get_dataset(self): def get_dataset(self):

View file

@ -48,6 +48,7 @@ class HUBTrainingSession:
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
self.model = None self.model = None
self.model_url = None self.model_url = None
self.model_file = None
# Parse input # Parse input
api_key, model_id, self.filename = self._parse_identifier(identifier) api_key, model_id, self.filename = self._parse_identifier(identifier)
@ -91,10 +92,13 @@ class HUBTrainingSession:
raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
if self.model.is_trained():
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
self.model_file = self.model.get_weights_url("best")
return
# Set training args and start heartbeats for HUB to monitor agent
self._set_train_args() self._set_train_args()
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits["heartbeat"]) self.model.start_heartbeat(self.rate_limits["heartbeat"])
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
@ -195,8 +199,6 @@ class HUBTrainingSession:
ValueError: If the model is already trained, if required dataset information is missing, or if there are ValueError: If the model is already trained, if required dataset information is missing, or if there are
issues with the provided training arguments. issues with the provided training arguments.
""" """
if self.model.is_trained():
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
if self.model.is_resumable(): if self.model.is_resumable():
# Model has saved weights # Model has saved weights

View file

@ -484,7 +484,7 @@ def check_model_file_from_stem(model="yolov8n"):
return model return model
def check_file(file, suffix="", download=True, hard=True): def check_file(file, suffix="", download=True, download_dir=".", hard=True):
"""Search/download file (if necessary) and return path.""" """Search/download file (if necessary) and return path."""
check_suffix(file, suffix) # optional check_suffix(file, suffix) # optional
file = str(file).strip() # convert to string and strip spaces file = str(file).strip() # convert to string and strip spaces
@ -497,12 +497,12 @@ def check_file(file, suffix="", download=True, hard=True):
return file return file
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
url = file # warning: Pathlib turns :// -> :/ url = file # warning: Pathlib turns :// -> :/
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).exists(): if file.exists():
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else: else:
downloads.safe_download(url=url, file=file, unzip=False) downloads.safe_download(url=url, file=file, unzip=False)
return file return str(file)
else: # search else: # search
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
if not files and hard: if not files and hard: