diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 8c16411d..3155de0d 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.40"
+__version__ = "8.1.41"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 96931fab..f92e815e 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -212,7 +212,7 @@ class BaseTrainer:
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-            "nccl" if dist.is_nccl_available() else "gloo",
+            backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
             world_size=world_size,
@@ -648,8 +648,8 @@ class BaseTrainer:
                 resume = True
                 self.args = get_cfg(ckpt_args)
-                self.args.model = str(last)  # reinstate model
-                for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+                self.args.model = self.args.resume = str(last)  # reinstate model
+                for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                     if k in overrides:
                         setattr(self.args, k, overrides[k])
@@ -662,7 +662,7 @@ class BaseTrainer:
 
     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
-        if ckpt is None:
+        if ckpt is None or not self.resume:
             return
         best_fitness = 0.0
         start_epoch = ckpt.get("epoch", -1) + 1
@@ -672,14 +672,11 @@ class BaseTrainer:
         if self.ema and ckpt.get("ema"):
             self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
             self.ema.updates = ckpt["updates"]
-        if self.resume:
-            assert start_epoch > 0, (
-                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
-                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
-            )
-            LOGGER.info(
-                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
-            )
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
         if self.epochs < start_epoch:
             LOGGER.info(
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
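For context (not part of the patch): with `self.args.resume` reinstated and `device` added to the resumable overrides, an interrupted run can be continued on a different GPU (or with a smaller `imgsz`/`batch`). A minimal usage sketch, assuming the standard `YOLO` Python API; the checkpoint path and device index are illustrative.

```python
from ultralytics import YOLO

# Resume an interrupted run from its last checkpoint (path is illustrative).
# `imgsz`, `batch`, and now `device` may be overridden on resume, e.g. to
# continue training on a different GPU or with a smaller batch after an OOM.
model = YOLO("runs/detect/train/weights/last.pt")
model.train(resume=True, device=0)
```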
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index fbd00f76..97ab5b49 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -391,7 +391,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
             try:
                 t = time.time()
                 assert is_online(), "AutoUpdate skipped (offline)"
-                with Retry(times=1, delay=1):  # retry once on failure after 1 second
+                with Retry(times=2, delay=1):  # run up to 2 times with 1-second retry delay
                     LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
                 dt = time.time() - t
                 LOGGER.info(
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index d5d91e13..96154d60 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -513,7 +513,7 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
     """
     for state in state_dict["state"].values():
         for k, v in state.items():
-            if isinstance(v, torch.Tensor) and v.dtype is torch.float32:
+            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                 state[k] = v.half()
 
     return state_dict
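For context (not part of the patch): the `convert_optimizer_state_dict_to_fp16` change keeps the optimizer's `step` entry untouched while still halving the momentum tensors to shrink checkpoints. A rough sketch of how the behaviour can be checked, assuming a recent PyTorch where Adam stores `step` as a scalar tensor; the tiny model is illustrative.

```python
import torch

from ultralytics.utils.torch_utils import convert_optimizer_state_dict_to_fp16

# Tiny illustrative model/optimizer; one step populates the Adam state.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

sd = convert_optimizer_state_dict_to_fp16(optimizer.state_dict())
for state in sd["state"].values():
    # exp_avg / exp_avg_sq should now be torch.float16, while 'step' keeps its
    # original dtype so resumed training and LR schedules stay exact.
    print({k: (v.dtype if isinstance(v, torch.Tensor) else type(v)) for k, v in state.items()})
```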