Update Validator to use model argument (#4480)
This commit is contained in:
parent
615ddc9d97
commit
b2f279ffdd
7 changed files with 15 additions and 14 deletions
|
|
@ -29,7 +29,7 @@ from tqdm import tqdm
|
|||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.files import increment_path
|
||||
from ultralytics.utils.ops import Profile
|
||||
|
|
@ -43,9 +43,9 @@ class BaseValidator:
|
|||
A base class for creating validators.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
pbar (tqdm): Progress bar to update during validation.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary.
|
||||
device (torch.device): Device to use for validation.
|
||||
|
|
@ -76,9 +76,9 @@ class BaseValidator:
|
|||
args (SimpleNamespace): Configuration for the validator.
|
||||
_callbacks (dict): Dictionary to store various callback functions.
|
||||
"""
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.args = args or get_cfg(DEFAULT_CFG)
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
|
|
@ -126,8 +126,7 @@ class BaseValidator:
|
|||
else:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
self.run_callbacks('on_val_start')
|
||||
assert model is not None, 'Either trainer or model is needed for validation'
|
||||
model = AutoBackend(model,
|
||||
model = AutoBackend(model or self.args.model,
|
||||
device=select_device(self.args.device, self.args.batch),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue