diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 7126c40c..57d6ec66 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.65" +__version__ = "8.2.66" import os diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 9dc3ae56..1a855081 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -17,6 +17,7 @@ from ultralytics.utils import ( DEFAULT_CFG_DICT, LOGGER, RANK, + SETTINGS, callbacks, checks, emojis, @@ -286,7 +287,7 @@ class Model(nn.Module): >>> model._load('path/to/weights.pth', task='detect') """ if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): - weights = checks.check_file(weights) # automatically download and return local filename + weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt if Path(weights).suffix == ".pt": diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 9d71810b..48e95679 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -507,7 +507,7 @@ class BaseTrainer: self.last.write_bytes(serialized_ckpt) # save last.pt if self.best_fitness == self.fitness: self.best.write_bytes(serialized_ckpt) # save best.pt - if (self.save_period > 0) and (self.epoch >= 0) and (self.epoch % self.save_period == 0): + if (self.save_period > 0) and (self.epoch % self.save_period == 0): (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' def get_dataset(self): diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index ddd4d8c1..1423f5f4 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -48,6 +48,7 @@ class HUBTrainingSession: self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py self.model = None self.model_url = None + self.model_file = None # Parse input api_key, model_id, self.filename = self._parse_identifier(identifier) @@ -91,10 +92,13 @@ class HUBTrainingSession: raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" + if self.model.is_trained(): + print(emojis(f"Loading trained HUB model {self.model_url} 🚀")) + self.model_file = self.model.get_weights_url("best") + return + # Set training args and start heartbeats for HUB to monitor agent self._set_train_args() - - # Start heartbeats for HUB to monitor agent self.model.start_heartbeat(self.rate_limits["heartbeat"]) LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") @@ -195,8 +199,6 @@ class HUBTrainingSession: ValueError: If the model is already trained, if required dataset information is missing, or if there are issues with the provided training arguments. """ - if self.model.is_trained(): - raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) if self.model.is_resumable(): # Model has saved weights diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index d94e157f..b9bcef3f 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -484,7 +484,7 @@ def check_model_file_from_stem(model="yolov8n"): return model -def check_file(file, suffix="", download=True, hard=True): +def check_file(file, suffix="", download=True, download_dir=".", hard=True): """Search/download file (if necessary) and return path.""" check_suffix(file, suffix) # optional file = str(file).strip() # convert to string and strip spaces @@ -497,12 +497,12 @@ def check_file(file, suffix="", download=True, hard=True): return file elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download url = file # warning: Pathlib turns :// -> :/ - file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth - if Path(file).exists(): + file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth + if file.exists(): LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists else: downloads.safe_download(url=url, file=file, unzip=False) - return file + return str(file) else: # search files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file if not files and hard: