ultralytics 8.1.41 DDP resume untrained-checkpoint fix (#9453)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Glenn Jocher 2024-04-01 19:46:04 +02:00 committed by GitHub
parent 2cee8893d9
commit 959acf67db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 12 additions and 15 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.40" __version__ = "8.1.41"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

View file

@ -212,7 +212,7 @@ class BaseTrainer:
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
dist.init_process_group( dist.init_process_group(
"nccl" if dist.is_nccl_available() else "gloo", backend="nccl" if dist.is_nccl_available() else "gloo",
timeout=timedelta(seconds=10800), # 3 hours timeout=timedelta(seconds=10800), # 3 hours
rank=RANK, rank=RANK,
world_size=world_size, world_size=world_size,
@ -648,8 +648,8 @@ class BaseTrainer:
resume = True resume = True
self.args = get_cfg(ckpt_args) self.args = get_cfg(ckpt_args)
self.args.model = str(last) # reinstate model self.args.model = self.args.resume = str(last) # reinstate model
for k in "imgsz", "batch": # allow arg updates to reduce memory on resume if crashed due to CUDA OOM for k in "imgsz", "batch", "device": # allow arg updates to reduce memory or update device on resume
if k in overrides: if k in overrides:
setattr(self.args, k, overrides[k]) setattr(self.args, k, overrides[k])
@ -662,7 +662,7 @@ class BaseTrainer:
def resume_training(self, ckpt): def resume_training(self, ckpt):
"""Resume YOLO training from given epoch and best fitness.""" """Resume YOLO training from given epoch and best fitness."""
if ckpt is None: if ckpt is None or not self.resume:
return return
best_fitness = 0.0 best_fitness = 0.0
start_epoch = ckpt.get("epoch", -1) + 1 start_epoch = ckpt.get("epoch", -1) + 1
@ -672,14 +672,11 @@ class BaseTrainer:
if self.ema and ckpt.get("ema"): if self.ema and ckpt.get("ema"):
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
self.ema.updates = ckpt["updates"] self.ema.updates = ckpt["updates"]
if self.resume: assert start_epoch > 0, (
assert start_epoch > 0, ( f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" )
) LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
LOGGER.info(
f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
)
if self.epochs < start_epoch: if self.epochs < start_epoch:
LOGGER.info( LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."

View file

@ -391,7 +391,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
try: try:
t = time.time() t = time.time()
assert is_online(), "AutoUpdate skipped (offline)" assert is_online(), "AutoUpdate skipped (offline)"
with Retry(times=1, delay=1): # retry once on failure after 1 second with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode()) LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
dt = time.time() - t dt = time.time() - t
LOGGER.info( LOGGER.info(

View file

@ -513,7 +513,7 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
""" """
for state in state_dict["state"].values(): for state in state_dict["state"].values():
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor) and v.dtype is torch.float32: if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
state[k] = v.half() state[k] = v.half()
return state_dict return state_dict