standalone val (#56)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3a241e4cea
commit
5a52e7663a
16 changed files with 161 additions and 31 deletions
|
|
@ -5,11 +5,14 @@ import torch
|
|||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils import LOGGER, TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
|
||||
from ultralytics.yolo.utils.torch_utils import check_img_size, de_parallel, select_device
|
||||
|
||||
|
||||
class BaseValidator:
|
||||
|
|
@ -17,17 +20,18 @@ class BaseValidator:
|
|||
Base validator class.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.logger = logger or logging.getLogger()
|
||||
self.logger = logger or LOGGER
|
||||
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
|
||||
self.device = select_device(self.args.device, dataloader.batch_size)
|
||||
self.save_dir = save_dir if save_dir is not None else \
|
||||
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
self.cuda = self.device.type != 'cpu'
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.save_dir = save_dir if save_dir is not None else \
|
||||
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
|
|
@ -36,14 +40,35 @@ class BaseValidator:
|
|||
"""
|
||||
self.training = trainer is not None
|
||||
if self.training:
|
||||
self.device = trainer.device
|
||||
self.data = trainer.data
|
||||
model = trainer.ema.ema or trainer.model
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
model = model.half() if self.args.half else model.float()
|
||||
self.model = model
|
||||
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
else: # TODO: handle this when detectMultiBackend is supported
|
||||
assert model is not None, "Either trainer or model is needed for validation"
|
||||
# model = DetectMultiBacked(model)
|
||||
# TODO: implement init_model_attributes()
|
||||
self.device = select_device(self.args.device, self.args.batch_size)
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
|
||||
self.model = model
|
||||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
||||
imgsz = check_img_size(self.args.img_size, s=stride)
|
||||
if engine:
|
||||
self.args.batch_size = model.batch_size
|
||||
else:
|
||||
self.device = model.device
|
||||
if not (pt or jit):
|
||||
self.args.batch_size = 1 # export.py models default to batch-size 1
|
||||
self.logger.info(
|
||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
if self.args.data.endswith(".yaml"):
|
||||
data = check_dataset_yaml(self.args.data)
|
||||
else:
|
||||
data = check_dataset(self.args.data)
|
||||
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
||||
|
||||
model.eval()
|
||||
|
||||
|
|
@ -101,6 +126,9 @@ class BaseValidator:
|
|||
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
|
||||
if self.training else stats
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
raise Exception("get_dataloder function not implemented for this validator")
|
||||
|
||||
def preprocess(self, batch):
|
||||
return batch
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue