ultralytics 8.2.66 HUB model autodownload (#14702)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
1d5d105c62
commit
9130399974
5 changed files with 14 additions and 11 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.65"
|
__version__ = "8.2.66"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from ultralytics.utils import (
|
||||||
DEFAULT_CFG_DICT,
|
DEFAULT_CFG_DICT,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RANK,
|
RANK,
|
||||||
|
SETTINGS,
|
||||||
callbacks,
|
callbacks,
|
||||||
checks,
|
checks,
|
||||||
emojis,
|
emojis,
|
||||||
|
|
@ -286,7 +287,7 @@ class Model(nn.Module):
|
||||||
>>> model._load('path/to/weights.pth', task='detect')
|
>>> model._load('path/to/weights.pth', task='detect')
|
||||||
"""
|
"""
|
||||||
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
||||||
weights = checks.check_file(weights) # automatically download and return local filename
|
weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
|
||||||
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
|
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||||
|
|
||||||
if Path(weights).suffix == ".pt":
|
if Path(weights).suffix == ".pt":
|
||||||
|
|
|
||||||
|
|
@ -507,7 +507,7 @@ class BaseTrainer:
|
||||||
self.last.write_bytes(serialized_ckpt) # save last.pt
|
self.last.write_bytes(serialized_ckpt) # save last.pt
|
||||||
if self.best_fitness == self.fitness:
|
if self.best_fitness == self.fitness:
|
||||||
self.best.write_bytes(serialized_ckpt) # save best.pt
|
self.best.write_bytes(serialized_ckpt) # save best.pt
|
||||||
if (self.save_period > 0) and (self.epoch >= 0) and (self.epoch % self.save_period == 0):
|
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
||||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||||
|
|
||||||
def get_dataset(self):
|
def get_dataset(self):
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,7 @@ class HUBTrainingSession:
|
||||||
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
||||||
self.model = None
|
self.model = None
|
||||||
self.model_url = None
|
self.model_url = None
|
||||||
|
self.model_file = None
|
||||||
|
|
||||||
# Parse input
|
# Parse input
|
||||||
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
||||||
|
|
@ -91,10 +92,13 @@ class HUBTrainingSession:
|
||||||
raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
|
raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
|
||||||
|
|
||||||
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
||||||
|
if self.model.is_trained():
|
||||||
|
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
|
||||||
|
self.model_file = self.model.get_weights_url("best")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Set training args and start heartbeats for HUB to monitor agent
|
||||||
self._set_train_args()
|
self._set_train_args()
|
||||||
|
|
||||||
# Start heartbeats for HUB to monitor agent
|
|
||||||
self.model.start_heartbeat(self.rate_limits["heartbeat"])
|
self.model.start_heartbeat(self.rate_limits["heartbeat"])
|
||||||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||||
|
|
||||||
|
|
@ -195,8 +199,6 @@ class HUBTrainingSession:
|
||||||
ValueError: If the model is already trained, if required dataset information is missing, or if there are
|
ValueError: If the model is already trained, if required dataset information is missing, or if there are
|
||||||
issues with the provided training arguments.
|
issues with the provided training arguments.
|
||||||
"""
|
"""
|
||||||
if self.model.is_trained():
|
|
||||||
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
|
|
||||||
|
|
||||||
if self.model.is_resumable():
|
if self.model.is_resumable():
|
||||||
# Model has saved weights
|
# Model has saved weights
|
||||||
|
|
|
||||||
|
|
@ -484,7 +484,7 @@ def check_model_file_from_stem(model="yolov8n"):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def check_file(file, suffix="", download=True, hard=True):
|
def check_file(file, suffix="", download=True, download_dir=".", hard=True):
|
||||||
"""Search/download file (if necessary) and return path."""
|
"""Search/download file (if necessary) and return path."""
|
||||||
check_suffix(file, suffix) # optional
|
check_suffix(file, suffix) # optional
|
||||||
file = str(file).strip() # convert to string and strip spaces
|
file = str(file).strip() # convert to string and strip spaces
|
||||||
|
|
@ -497,12 +497,12 @@ def check_file(file, suffix="", download=True, hard=True):
|
||||||
return file
|
return file
|
||||||
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
|
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
|
||||||
url = file # warning: Pathlib turns :// -> :/
|
url = file # warning: Pathlib turns :// -> :/
|
||||||
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
|
file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
|
||||||
if Path(file).exists():
|
if file.exists():
|
||||||
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
|
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
|
||||||
else:
|
else:
|
||||||
downloads.safe_download(url=url, file=file, unzip=False)
|
downloads.safe_download(url=url, file=file, unzip=False)
|
||||||
return file
|
return str(file)
|
||||||
else: # search
|
else: # search
|
||||||
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
|
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
|
||||||
if not files and hard:
|
if not files and hard:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue