Add best.pt val and COCO pycocotools val (#98)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-12-27 04:56:24 +01:00 committed by GitHub
parent a1808eeda4
commit 6f0ba81427
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 159 additions and 115 deletions

View file

@ -22,6 +22,7 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics import __version__
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import check_file, print_args
@ -52,7 +53,8 @@ class BaseTrainer:
self.batch_size = self.args.batch_size
self.epochs = self.args.epochs
self.start_epoch = 0
print_args(dict(self.args))
if RANK == -1:
print_args(dict(self.args))
# Save run settings
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
@ -109,7 +111,6 @@ class BaseTrainer:
world_size = torch.cuda.device_count()
if world_size > 1 and "LOCAL_RANK" not in os.environ:
command = generate_ddp_command(world_size, self)
print('DDP command: ', command)
try:
subprocess.run(command)
except Exception as e:
@ -124,7 +125,7 @@ class BaseTrainer:
# os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
def _setup_train(self, rank, world_size):
@ -259,8 +260,7 @@ class BaseTrainer:
if not self.args.noval or final_epoch:
self.metrics, self.fitness = self.validate()
self.trigger_callbacks('on_val_end')
log_vals = {**self.label_loss_items(self.tloss), **self.metrics, **lr}
self.save_metrics(metrics=log_vals)
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
# save model
if (not self.args.nosave) or (epoch + 1 == self.epochs):
@ -282,7 +282,6 @@ class BaseTrainer:
self.plot_metrics()
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
self.trigger_callbacks('on_train_end')
dist.destroy_process_group() if world_size > 1 else None
torch.cuda.empty_cache()
self.trigger_callbacks('teardown')
@ -295,7 +294,8 @@ class BaseTrainer:
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'train_args': self.args,
'date': datetime.now().isoformat()}
'date': datetime.now().isoformat(),
'version': __version__}
# Save last, best and delete
torch.save(ckpt, self.last)
@ -365,7 +365,7 @@ class BaseTrainer:
if rank in {-1, 0}:
self.console.info(text)
def load_model(self, model_cfg, weights):
def load_model(self, model_cfg=None, weights=None, verbose=True):
raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self):
@ -417,12 +417,14 @@ class BaseTrainer:
pass
def final_eval(self):
# TODO: need standalone evaluator to do this
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
self.console.info(f'\nValidating {f}...')
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.trigger_callbacks('on_val_end')
def check_resume(self):
resume = self.args.resume