Add max_dim==2 argument to check_imgsz() (#789)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
This commit is contained in:
Glenn Jocher 2023-02-04 01:48:44 +04:00 committed by GitHub
parent 5a80ad98db
commit 0d182e80f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 96 additions and 52 deletions

View file

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from pathlib import Path
from typing import List
import sys
from ultralytics import yolo # noqa
@ -9,7 +10,7 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.checks import check_yaml, check_imgsz
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
@ -131,7 +132,7 @@ class YOLO:
Check the 'configuration' section in the documentation for all available options.
Returns:
(dict): The prediction results.
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
overrides = self.overrides.copy()
overrides["conf"] = 0.25
@ -161,6 +162,7 @@ class YOLO:
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = self.ValidatorClass(args=args)
validator(model=self.model)