standalone val (#56)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3a241e4cea
commit
5a52e7663a
16 changed files with 161 additions and 31 deletions
|
|
@ -25,7 +25,7 @@ import ultralytics.yolo.utils as utils
|
|||
import ultralytics.yolo.utils.callbacks as callbacks
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
||||
from ultralytics.yolo.utils.checks import print_args
|
||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
||||
|
|
@ -299,13 +299,16 @@ class BaseTrainer:
|
|||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized
|
||||
"""
|
||||
return data["train"], data["val"]
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def get_model(self, model: Union[str, Path]):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
pretrained = not str(model).endswith(".yaml")
|
||||
pretrained = True
|
||||
if str(model).endswith(".yaml"):
|
||||
model = check_file(model)
|
||||
pretrained = False
|
||||
return self.load_model(model_cfg=None if pretrained else model,
|
||||
weights=get_model(model) if pretrained else None,
|
||||
data=self.data) # model
|
||||
|
|
@ -376,7 +379,7 @@ class BaseTrainer:
|
|||
"""
|
||||
To set or update model parameters before training.
|
||||
"""
|
||||
pass
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue