ultralytics 8.0.155 allow imgsz and batch resume changes (#4366)

Co-authored-by: Mostafa Nemati <58460889+monemati@users.noreply.github.com>
Co-authored-by: Eduard Voiculescu <eduardvoiculescu95@gmail.com>
This commit is contained in:
Glenn Jocher 2023-08-15 22:02:23 +02:00 committed by GitHub
parent 60cad0c592
commit 9a0555eca4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 84 additions and 105 deletions

View file

@@ -81,7 +81,7 @@ class BaseTrainer:
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
-self.check_resume()
+self.check_resume(overrides)
self.device = select_device(self.args.device, self.args.batch)
self.validator = None
self.model = None
@@ -576,7 +576,7 @@ class BaseTrainer:
self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end')
-def check_resume(self):
+def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly."""
resume = self.args.resume
if resume:
@@ -589,8 +589,13 @@ class BaseTrainer:
if not Path(ckpt_args['data']).exists():
ckpt_args['data'] = self.args.data
+resume = True
self.args = get_cfg(ckpt_args)
-self.args.model, resume = str(last), True # reinstate
+self.args.model = str(last) # reinstate model
+for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+if k in overrides:
+setattr(self.args, k, overrides[k])
except Exception as e:
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
"i.e. 'yolo train resume model=path/to/last.pt'") from e