YOLOv5 updates (#90)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-12-25 14:33:18 +01:00 committed by GitHub
parent ebd3cfb2fd
commit 98815d560f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 281 additions and 161 deletions

View file

@ -1,4 +1,3 @@
import logging
from pathlib import Path
import torch
@ -9,10 +8,9 @@ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import LOGGER, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device
from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device, smart_inference_mode
class BaseValidator:
@ -32,8 +30,9 @@ class BaseValidator:
self.training = True
self.speed = None
self.save_dir = save_dir if save_dir is not None else \
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
"""
Supports validation of a pre-trained model if passed or a model being trained
@ -76,35 +75,34 @@ class BaseValidator:
dt = Profile(), Profile(), Profile(), Profile()
n_batches = len(self.dataloader)
desc = self.get_desc()
# NOTE: keeping this `not self.training` in tqdm will eliminate pbar after finishing segmantation evaluation during training,
# so I removed it, not sure if this will affect classification task cause I saw we use this arg in yolov5/classify/val.py.
# NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
# which may affect classification task since this arg is in yolov5/classify/val.py.
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
self.init_metrics(de_parallel(model))
with torch.no_grad():
for batch_i, batch in enumerate(bar):
self.batch_i = batch_i
# pre-process
with dt[0]:
batch = self.preprocess(batch)
for batch_i, batch in enumerate(bar):
self.batch_i = batch_i
# pre-process
with dt[0]:
batch = self.preprocess(batch)
# inference
with dt[1]:
preds = model(batch["img"])
# inference
with dt[1]:
preds = model(batch["img"])
# loss
with dt[2]:
if self.training:
self.loss += trainer.criterion(preds, batch)[1]
# loss
with dt[2]:
if self.training:
self.loss += trainer.criterion(preds, batch)[1]
# pre-process predictions
with dt[3]:
preds = self.postprocess(preds)
# pre-process predictions
with dt[3]:
preds = self.postprocess(preds)
self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
stats = self.get_stats()
self.check_stats(stats)
@ -113,22 +111,21 @@ class BaseValidator:
# calculate speed only once when training
if not self.training or trainer.epoch == 0:
t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
self.speed = t
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
if not self.training: # print only at inference
self.logger.info(
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' % t)
if not self.training: # print only at inference
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
self.speed)
if self.training:
model.float()
# TODO: implement save json
return stats | trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val") \
if self.training else stats
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} \
if self.training else stats
def get_dataloader(self, dataset_path, batch_size):
raise Exception("get_dataloder function not implemented for this validator")
raise NotImplementedError("get_dataloader function not implemented for this validator")
def preprocess(self, batch):
return batch