standalone val (#56)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia 2022-11-30 15:04:44 +05:30 committed by GitHub
parent 3a241e4cea
commit 5a52e7663a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 161 additions and 31 deletions

View file

@ -25,7 +25,7 @@ import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import print_args
from ultralytics.yolo.utils.checks import check_file, print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
@ -299,13 +299,16 @@ class BaseTrainer:
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized
"""
return data["train"], data["val"]
return data["train"], data.get("val") or data.get("test")
def get_model(self, model: Union[str, Path]):
"""
load/create/download model for any task
"""
pretrained = not str(model).endswith(".yaml")
pretrained = True
if str(model).endswith(".yaml"):
model = check_file(model)
pretrained = False
return self.load_model(model_cfg=None if pretrained else model,
weights=get_model(model) if pretrained else None,
data=self.data) # model
@ -376,7 +379,7 @@ class BaseTrainer:
"""
To set or update model parameters before training.
"""
pass
self.model.names = self.data["names"]
def build_targets(self, preds, targets):
pass