Fix learning rate gap on resume (#9468)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: EunChan Kim <eunchan@hanyang.ac.kr>
Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com>
Co-authored-by: RizwanMunawar <chr043416@gmail.com>
Co-authored-by: gs80140 <gs80140@users.noreply.github.com>
This commit is contained in:
Laughing 2024-04-02 17:55:11 +08:00 committed by GitHub
parent e5f4f5c8b9
commit 1e547e60a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 13 additions and 11 deletions

View file

@ -331,6 +331,10 @@ class BaseTrainer:
while True:
self.epoch = epoch
self.run_callbacks("on_train_epoch_start")
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
self.scheduler.step()
self.model.train()
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
@ -426,15 +430,12 @@ class BaseTrainer:
t = time.time()
self.epoch_time = t - self.epoch_time_start
self.epoch_time_start = t
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
self._setup_scheduler()
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.scheduler.step()
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
self._setup_scheduler()
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.run_callbacks("on_fit_epoch_end")
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors