Threadpool fixes and CLI improvements (#550)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Glenn Jocher 2023-01-22 17:08:08 +01:00 committed by GitHub
parent d9a0fba251
commit 21b701c4ea
22 changed files with 338 additions and 251 deletions

ultralytics/yolo/engine/trainer.py

@@ -31,7 +31,8 @@ from ultralytics.yolo.utils.autobatch import check_train_batch_size
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path
-from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
+from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
+                                                strip_optimizer)
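For context, the EarlyStopping helper imported here ships in ultralytics.yolo.utils.torch_utils; its body is not part of this hunk. A minimal sketch consistent with how the rest of this diff uses it (a patience window, a possible_stop flag raised one epoch early, and a callable returning the stop decision) might look like:

class EarlyStopping:
    # Sketch only: the real class lives in ultralytics.yolo.utils.torch_utils;
    # this mirrors the interface the diff relies on (patience, possible_stop, __call__).
    def __init__(self, patience=50):
        self.best_fitness = 0.0  # best fitness observed so far
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait without improvement
        self.possible_stop = False  # True one epoch before a potential stop

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # improvement (ties reset the window)
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)
        return delta >= self.patience  # stop once patience is exhausted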
@@ -71,15 +72,15 @@ class BaseTrainer:
         csv (Path): Path to results CSV file.
     """
 
-    def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
+    def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
         """
         Initializes the BaseTrainer class.
 
         Args:
-            config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
-        self.args = get_cfg(config, overrides)
+        self.args = get_cfg(cfg, overrides)
         self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
         self.check_resume()
         self.console = LOGGER
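The config → cfg rename is API-visible: any caller that passed the argument by keyword must update. A hypothetical call site (DetectionTrainer stands in for a concrete BaseTrainer subclass; the import path and overrides keys are assumptions, not part of this commit):

# illustrative only; names and import paths assumed from this diff's context
trainer = DetectionTrainer(config=DEFAULT_CFG_PATH, overrides={'epochs': 3})  # old keyword: breaks after this change
trainer = DetectionTrainer(cfg=DEFAULT_CFG_PATH, overrides={'epochs': 3})     # new keyword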
@@ -225,6 +226,7 @@ class BaseTrainer:
         self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
 
         # dataloaders
         batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
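The context lines above show the linear LR schedule the stopper now sits next to. A self-contained demonstration of that lambda with LambdaLR (epochs=10, lr0=0.1, lrf=0.01 are demo values, not defaults from this repo):

import torch
from torch.optim import SGD, lr_scheduler

epochs, lr0, lrf = 10, 0.1, 0.01  # assumed demo values
opt = SGD([torch.zeros(1, requires_grad=True)], lr=lr0)
lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf  # same linear decay as the hunk above
sched = lr_scheduler.LambdaLR(opt, lr_lambda=lf)
for _ in range(epochs):
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # ~lr0 * lrf = 0.001 after the final epoch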
@@ -333,10 +335,12 @@ class BaseTrainer:
             # Validation
             self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
-            final_epoch = (epoch + 1 == self.epochs)
+            final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
 
             if self.args.val or final_epoch:
                 self.metrics, self.fitness = self.validate()
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+                self.stop = self.stopper(epoch + 1, self.fitness)
 
             # Save model
             if self.args.save or (epoch + 1 == self.epochs):
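Taken together with the EarlyStopping sketch after the import hunk, the call site behaves roughly like this (fitness values invented for illustration): possible_stop turning True makes final_epoch true, which forces a validation pass even when args.val is off, so the stopper always sees a fresh fitness before stopping.

stopper = EarlyStopping(patience=3)
for epoch, fitness in enumerate([0.50, 0.60, 0.55, 0.50, 0.45], start=1):
    stop = stopper(epoch, fitness)
    print(epoch, stopper.possible_stop, stop)
# epoch 4: possible_stop=True, so the next epoch validates regardless of args.val
# epoch 5: stop=True after 3 epochs with no improvement over the 0.60 from epoch 2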
@@ -347,7 +351,15 @@ class BaseTrainer:
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
             self.run_callbacks("on_fit_epoch_end")
-            # TODO: termination condition
+
+            # Early Stopping
+            if RANK != -1:  # if DDP training
+                broadcast_list = [self.stop if RANK == 0 else None]
+                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                if RANK != 0:
+                    self.stop = broadcast_list[0]
+            if self.stop:
+                break  # must break all DDP ranks
 
         if rank in {-1, 0}:
             # Do final val with best.pt
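The broadcast added above is the standard way to keep DDP ranks in lockstep: rank 0 decides, every other rank receives the decision, and all processes leave the epoch loop together. The same pattern generalizes to any rank-0 decision; a minimal standalone helper (hypothetical name, and it assumes an already-initialized process group, which this commit does elsewhere):

import torch.distributed as dist

def sync_stop(stop: bool, rank: int) -> bool:
    # Hypothetical helper: rank 0 supplies the value, other ranks receive it.
    broadcast_list = [stop if rank == 0 else None]
    dist.broadcast_object_list(broadcast_list, src=0)  # requires an initialized process group
    return broadcast_list[0]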