Improvements (#142)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
dcd8ef68e6
commit
55bdca6768
6 changed files with 149 additions and 45 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import platform
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
|
|
@ -13,71 +14,141 @@ import torch
|
|||
|
||||
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
|
||||
is_docker, is_jupyter_notebook)
|
||||
from ultralytics.yolo.utils.ops import make_divisible
|
||||
|
||||
|
||||
def is_ascii(s=''):
|
||||
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
||||
s = str(s) # convert list, tuple, None, etc. to str
|
||||
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
||||
def is_ascii(s) -> bool:
|
||||
"""
|
||||
Check if a string is composed of only ASCII characters.
|
||||
|
||||
Args:
|
||||
s (str): String to be checked.
|
||||
|
||||
Returns:
|
||||
bool: True if the string is composed only of ASCII characters, False otherwise.
|
||||
"""
|
||||
# Convert list, tuple, None, etc. to string
|
||||
s = str(s)
|
||||
|
||||
# Check if the string is composed of only ASCII characters
|
||||
return all(ord(c) < 128 for c in s)
|
||||
|
||||
|
||||
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
||||
# Verify image size is a multiple of stride s in each dimension
|
||||
"""
|
||||
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
||||
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
|
||||
|
||||
Args:
|
||||
imgsz (int or List[int]): Image size.
|
||||
stride (int): Stride value.
|
||||
min_dim (int): Minimum number of dimensions.
|
||||
floor (int): Minimum allowed value for image size.
|
||||
|
||||
Returns:
|
||||
List[int]: Updated image size.
|
||||
"""
|
||||
# Convert stride to integer if it is a tensor
|
||||
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
||||
if isinstance(imgsz, int): # integer i.e. imgsz=640
|
||||
sz = max(make_divisible(imgsz, stride), floor)
|
||||
else: # list i.e. imgsz=[640, 480]
|
||||
imgsz = list(imgsz) # convert to list if tuple
|
||||
sz = [max(make_divisible(x, stride), floor) for x in imgsz]
|
||||
|
||||
# Convert image size to list if it is an integer
|
||||
if isinstance(imgsz, int):
|
||||
imgsz = [imgsz]
|
||||
|
||||
# Make image size a multiple of the stride
|
||||
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||
|
||||
# Print warning message if image size was updated
|
||||
if sz != imgsz:
|
||||
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||
|
||||
# Check dims
|
||||
if min_dim == 2:
|
||||
if isinstance(imgsz, int):
|
||||
sz = [sz, sz]
|
||||
elif len(sz) == 1:
|
||||
sz = [sz[0], sz[0]]
|
||||
# Add missing dimensions if necessary
|
||||
if min_dim == 2 and len(sz) == 1:
|
||||
sz = [sz[0], sz[0]]
|
||||
|
||||
return sz
|
||||
|
||||
|
||||
def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
|
||||
# Check version vs. required version
|
||||
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
||||
def check_version(current: str = "0.0.0",
|
||||
minimum: str = "0.0.0",
|
||||
name: str = "version ",
|
||||
pinned: bool = False,
|
||||
hard: bool = False,
|
||||
verbose: bool = False) -> bool:
|
||||
"""
|
||||
Check current version against the required minimum version.
|
||||
|
||||
Args:
|
||||
current (str): Current version.
|
||||
minimum (str): Required minimum version.
|
||||
name (str): Name to be used in warning message.
|
||||
pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
|
||||
hard (bool): If True, raise an AssertionError if the minimum version is not met.
|
||||
verbose (bool): If True, print warning message if minimum version is not met.
|
||||
|
||||
Returns:
|
||||
bool: True if minimum version is met, False otherwise.
|
||||
"""
|
||||
from pkg_resources import parse_version
|
||||
current, minimum = (parse_version(x) for x in (current, minimum))
|
||||
result = (current == minimum) if pinned else (current >= minimum) # bool
|
||||
s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" # string
|
||||
warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"
|
||||
if hard:
|
||||
assert result, emojis(s) # assert min requirements met
|
||||
assert result, emojis(warning_message) # assert min requirements met
|
||||
if verbose and not result:
|
||||
LOGGER.warning(s)
|
||||
LOGGER.warning(warning_message)
|
||||
return result
|
||||
|
||||
|
||||
def check_font(font=FONT, progress=False):
|
||||
# Download font to CONFIG_DIR if necessary
|
||||
def check_font(font: str = FONT, progress: bool = False) -> None:
|
||||
"""
|
||||
Download font file to the user's configuration directory if it does not already exist.
|
||||
|
||||
Args:
|
||||
font (str): Path to font file.
|
||||
progress (bool): If True, display a progress bar during the download.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
font = Path(font)
|
||||
|
||||
# Destination path for the font file
|
||||
file = USER_CONFIG_DIR / font.name
|
||||
|
||||
# Check if font file exists at the source or destination path
|
||||
if not font.exists() and not file.exists():
|
||||
# Download font file
|
||||
url = f'https://ultralytics.com/assets/{font.name}'
|
||||
LOGGER.info(f'Downloading {url} to {file}...')
|
||||
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
||||
|
||||
|
||||
def check_online():
|
||||
# Check internet connectivity
|
||||
def check_online() -> bool:
|
||||
"""
|
||||
Check internet connectivity by attempting to connect to a known online host.
|
||||
|
||||
Returns:
|
||||
bool: True if connection is successful, False otherwise.
|
||||
"""
|
||||
import socket
|
||||
try:
|
||||
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
|
||||
# Check host accessibility by attempting to establish a connection
|
||||
socket.create_connection(("1.1.1.1", 443), timeout=5)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def check_python(minimum='3.7.0'):
|
||||
# Check current python version vs. required python version
|
||||
def check_python(minimum: str = '3.7.0') -> bool:
|
||||
"""
|
||||
Check current python version against the required minimum version.
|
||||
|
||||
Args:
|
||||
minimum (str): Required minimum version of python.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue