Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent d9a0fba251
commit 21b701c4ea
22 changed files with 338 additions and 251 deletions
@@ -31,7 +31,8 @@ from ultralytics.yolo.utils.autobatch import check_train_batch_size
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path
-from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
+from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
+                                                strip_optimizer)
 
 
 class BaseTrainer:
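The torch_utils import now also pulls in EarlyStopping (wrapped in parentheses to stay within the line-length limit). The rest of the diff assumes an interface of a callable that tracks the best fitness plus a possible_stop flag raised one epoch before the patience window runs out. A minimal sketch of that pattern, not the actual ultralytics implementation:

    # Minimal sketch of a patience-based early stopper; the real
    # EarlyStopping in torch_utils may differ in details.
    class EarlyStopping:

        def __init__(self, patience=50):
            self.best_fitness = 0.0  # best fitness observed so far
            self.best_epoch = 0  # epoch where the best fitness occurred
            self.patience = patience  # epochs to wait without improvement
            self.possible_stop = False  # True one epoch before a likely stop

        def __call__(self, epoch, fitness):
            if fitness >= self.best_fitness:  # >= so zero-fitness warmup epochs do not stop training
                self.best_epoch, self.best_fitness = epoch, fitness
            delta = epoch - self.best_epoch  # epochs without improvement
            self.possible_stop = delta >= (self.patience - 1)  # flag a final val for the next epoch
            return delta >= self.patience  # stop training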
@@ -71,15 +72,15 @@ class BaseTrainer:
         csv (Path): Path to results CSV file.
     """
 
-    def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
+    def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
         """
         Initializes the BaseTrainer class.
 
         Args:
-            config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
-        self.args = get_cfg(config, overrides)
+        self.args = get_cfg(cfg, overrides)
         self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
         self.check_resume()
         self.console = LOGGER
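Besides the config -> cfg rename, note that get_cfg merges user overrides into the default configuration and returns a namespace stored on self.args. A rough sketch of that merge under the assumption of plain-dict inputs (the real get_cfg also validates keys and value types):

    from types import SimpleNamespace


    def get_cfg_sketch(cfg, overrides=None):
        """Merge overrides into a copy of the default config (illustrative only)."""
        merged = {**cfg, **(overrides or {})}  # override values win on key collisions
        return SimpleNamespace(**merged)


    args = get_cfg_sketch({'epochs': 100, 'patience': 50}, overrides={'epochs': 3})
    print(args.epochs, args.patience)  # -> 3 50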
@@ -225,6 +226,7 @@ class BaseTrainer:
         self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
 
         # dataloaders
         batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
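Context for the scheduler lines above: lf decays the LR multiplier linearly from 1.0 at epoch 0 to args.lrf at the final epoch, and LambdaLR scales each base learning rate by lf(last_epoch). A self-contained check with a dummy optimizer (epochs and lrf are illustrative values, not the project defaults):

    import torch
    from torch.optim import lr_scheduler

    epochs, lrf = 100, 0.01  # illustrative values
    optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf  # same linear form as the diff
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for _ in range(epochs):
        optimizer.step()  # step the optimizer first to avoid the PyTorch ordering warning
        scheduler.step()
    print(optimizer.param_groups[0]['lr'])  # 0.1 * lf(100) = 0.1 * lrf = 0.001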
@@ -333,10 +335,12 @@ class BaseTrainer:
             # Validation
             self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
-            final_epoch = (epoch + 1 == self.epochs)
+            final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
 
             if self.args.val or final_epoch:
                 self.metrics, self.fitness = self.validate()
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+                self.stop = self.stopper(epoch + 1, self.fitness)
 
             # Save model
             if self.args.save or (epoch + 1 == self.epochs):
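Two behavioural changes here: final_epoch now also fires when possible_stop signals an imminent stop, guaranteeing a validation pass (and metric logging) on what may become the last epoch, and the stop flag is refreshed from the stopper whenever validation runs. Using the EarlyStopping sketch above, the control flow reduces to the following (fitness values are fabricated for illustration):

    stopper, stop = EarlyStopping(patience=3), False
    fitness_per_epoch = [0.2, 0.5, 0.4, 0.4, 0.4]  # best at epoch 2, no improvement after

    for epoch, fitness in enumerate(fitness_per_epoch, start=1):
        final = stopper.possible_stop  # set True during the previous call when a stop is imminent
        stop = stopper(epoch, fitness)
        print(epoch, final, stop)
        if stop:
            break  # epoch 5: three epochs without improvement past epoch 2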
@@ -347,7 +351,15 @@ class BaseTrainer:
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
             self.run_callbacks("on_fit_epoch_end")
-            # TODO: termination condition
+
+            # Early Stopping
+            if RANK != -1:  # if DDP training
+                broadcast_list = [self.stop if RANK == 0 else None]
+                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                if RANK != 0:
+                    self.stop = broadcast_list[0]
+            if self.stop:
+                break  # must break all DDP ranks
 
         if rank in {-1, 0}:
             # Do final val with best.pt
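Under DDP only rank 0 runs validation and computes fitness, so its stop decision has to reach every rank before any of them may break out of the epoch loop; torch.distributed.broadcast_object_list replicates it. The same pattern isolated as a helper, assuming an already-initialized process group (e.g. launched via torchrun):

    import torch.distributed as dist


    def sync_stop(stop, rank):
        """Replicate rank 0's stop decision to all ranks (requires an initialized process group)."""
        broadcast_list = [stop if rank == 0 else None]
        dist.broadcast_object_list(broadcast_list, src=0)  # every rank receives rank 0's value
        return broadcast_list[0]  # identical on all ranks, so all epoch loops break together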