ultralytics 8.1.36 improve train stop robustness epoch + 1 >= self.epochs (#9384)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-03-28 17:28:20 +01:00 committed by GitHub
parent 3aeb058e82
commit ed2250cf1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 4 additions and 3 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.35" __version__ = "8.1.36"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

View file

@ -421,7 +421,7 @@ class BaseTrainer:
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.run_callbacks("on_train_epoch_end") self.run_callbacks("on_train_epoch_end")
if RANK in (-1, 0): if RANK in (-1, 0):
final_epoch = epoch + 1 == self.epochs final_epoch = epoch + 1 >= self.epochs
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation # Validation

View file

@ -140,7 +140,8 @@ class AutoBackend(nn.Module):
# In-memory PyTorch model # In-memory PyTorch model
if nn_module: if nn_module:
model = weights.to(device) model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model if fuse:
model = model.fuse(verbose=verbose)
if hasattr(model, "kpt_shape"): if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride stride = max(int(model.stride.max()), 32) # model stride