Update Validator to use model argument (#4480)

This commit is contained in:
Glenn Jocher 2023-08-21 19:21:55 +02:00 committed by GitHub
parent 615ddc9d97
commit b2f279ffdd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 15 additions and 14 deletions

View file

@ -29,7 +29,7 @@ from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.files import increment_path
from ultralytics.utils.ops import Profile
@ -43,9 +43,9 @@ class BaseValidator:
A base class for creating validators.
Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
@ -76,9 +76,9 @@ class BaseValidator:
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
"""
self.args = get_cfg(overrides=args)
self.dataloader = dataloader
self.pbar = pbar
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None
self.data = None
self.device = None
@ -126,8 +126,7 @@ class BaseValidator:
else:
callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start')
assert model is not None, 'Either trainer or model is needed for validation'
model = AutoBackend(model,
model = AutoBackend(model or self.args.model,
device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn,
data=self.args.data,