imgsz warning fix, download function consolidation (#681)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: HaeJin Lee <seareale@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher 2023-01-29 02:31:37 +01:00 committed by GitHub
parent 0609561549
commit 899abe9f82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 171 additions and 147 deletions

View file

@ -17,9 +17,10 @@ import pkg_resources as pkg
import psutil
import torch
from IPython import display
from matplotlib import font_manager
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads,
emojis, is_colab, is_docker, is_jupyter)
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
is_colab, is_docker, is_jupyter)
def is_ascii(s) -> bool:
@ -57,15 +58,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
# Convert image size to list if it is an integer
if isinstance(imgsz, int):
imgsz = [imgsz]
imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
# Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
# Print warning message if image size was updated
if sz != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@ -104,26 +104,33 @@ def check_version(current: str = "0.0.0",
return result
def check_font(font: str = FONT, progress: bool = False) -> None:
def check_font(font='Arial.ttf'):
"""
Download font file to the user's configuration directory if it does not already exist.
Find font locally or download to user's configuration directory if it does not already exist.
Args:
font (str): Path to font file.
progress (bool): If True, display a progress bar during the download.
font (str): Path or name of font.
Returns:
None
file (Path): Resolved font file path.
"""
font = Path(font)
name = Path(font).name
# Destination path for the font file
file = USER_CONFIG_DIR / font.name
# Check USER_CONFIG_DIR
file = USER_CONFIG_DIR / name
if file.exists():
return file
# Check if font file exists at the source or destination path
if not font.exists() and not file.exists():
# Download font file
downloads.safe_download(file=file, url=f'https://ultralytics.com/assets/{font.name}', progress=progress)
# Check system fonts
matches = [s for s in font_manager.findSystemFonts() if font in s]
if any(matches):
return matches[0]
# Download to USER_CONFIG_DIR if missing
url = f'https://ultralytics.com/assets/{name}'
if downloads.is_url(url):
downloads.safe_download(url=url, file=file)
return file
def check_online() -> bool:
@ -213,7 +220,7 @@ def check_file(file, suffix=''):
if Path(file).is_file():
LOGGER.info(f'Found {url} locally at {file}') # file already exists
else:
downloads.safe_download(file=file, url=url)
downloads.safe_download(url=url, file=file)
return file
else: # search
files = []