Fix learning rate gap on resume (#9468)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: EunChan Kim <eunchan@hanyang.ac.kr> Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com> Co-authored-by: RizwanMunawar <chr043416@gmail.com> Co-authored-by: gs80140 <gs80140@users.noreply.github.com>
This commit is contained in:
parent
e5f4f5c8b9
commit
1e547e60a0
2 changed files with 13 additions and 11 deletions
|
|
@ -331,6 +331,10 @@ class BaseTrainer:
|
|||
while True:
|
||||
self.epoch = epoch
|
||||
self.run_callbacks("on_train_epoch_start")
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
self.scheduler.step()
|
||||
|
||||
self.model.train()
|
||||
if RANK != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
|
|
@ -426,15 +430,12 @@ class BaseTrainer:
|
|||
t = time.time()
|
||||
self.epoch_time = t - self.epoch_time_start
|
||||
self.epoch_time_start = t
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
if self.args.time:
|
||||
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
||||
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
||||
self._setup_scheduler()
|
||||
self.scheduler.last_epoch = self.epoch # do not move
|
||||
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
||||
self.scheduler.step()
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, __version__
|
||||
from ultralytics.utils.checks import PYTHON_VERSION, check_version
|
||||
|
||||
try:
|
||||
|
|
@ -614,8 +614,9 @@ class EarlyStopping:
|
|||
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
||||
stop = delta >= self.patience # stop training if patience exceeded
|
||||
if stop:
|
||||
prefix = colorstr("EarlyStopping: ")
|
||||
LOGGER.info(
|
||||
f"Stopping training early as no improvement observed in last {self.patience} epochs. "
|
||||
f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
|
||||
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
|
||||
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
|
||||
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue