ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-01-10 03:16:08 +01:00 committed by GitHub
parent e795277391
commit fe27db2f6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
139 changed files with 6870 additions and 5125 deletions

View file

@ -25,23 +25,22 @@ from tqdm import tqdm as tqdm_original
from ultralytics import __version__
# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO
ASSETS = ROOT / 'assets' # default images
DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
ASSETS = ROOT / "assets" # default images
DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format
LOGGING_NAME = 'ultralytics'
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans
HELP_MSG = \
"""
AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode
VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
LOGGING_NAME = "ultralytics"
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
HELP_MSG = """
Usage examples for running YOLOv8:
1. Install the ultralytics package:
@ -99,12 +98,12 @@ HELP_MSG = \
"""
# Settings
torch.set_printoptions(linewidth=320, precision=4, profile='default')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
torch.set_printoptions(linewidth=320, precision=4, profile="default")
np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab
os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab
class TQDM(tqdm_original):
@ -119,8 +118,8 @@ class TQDM(tqdm_original):
def __init__(self, *args, **kwargs):
"""Initialize custom Ultralytics tqdm class with different default arguments."""
# Set new default values (these can still be overridden when calling TQDM)
kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed
kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed
super().__init__(*args, **kwargs)
@ -134,14 +133,14 @@ class SimpleClass:
attr = []
for a in dir(self):
v = getattr(self, a)
if not callable(v) and not a.startswith('_'):
if not callable(v) and not a.startswith("_"):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
else:
s = f'{a}: {repr(v)}'
s = f"{a}: {repr(v)}"
attr.append(s)
return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)
def __repr__(self):
"""Return a machine-readable string representation of the object."""
@ -164,24 +163,26 @@ class IterableSimpleNamespace(SimpleNamespace):
def __str__(self):
"""Return a human-readable string representation of the object."""
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
return "\n".join(f"{k}={v}" for k, v in vars(self).items())
def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
raise AttributeError(f"""
raise AttributeError(
f"""
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
{DEFAULT_CFG_PATH} with the latest version from
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
""")
"""
)
def get(self, key, default=None):
"""Return the value of the specified key if it exists; otherwise, return the default value."""
return getattr(self, key, default)
def plt_settings(rcparams=None, backend='Agg'):
def plt_settings(rcparams=None, backend="Agg"):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.
@ -199,7 +200,7 @@ def plt_settings(rcparams=None, backend='Agg'):
"""
if rcparams is None:
rcparams = {'font.size': 11}
rcparams = {"font.size": 11}
def decorator(func):
"""Decorator to apply temporary rc parameters and backend to a function."""
@ -208,14 +209,14 @@ def plt_settings(rcparams=None, backend='Agg'):
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
original_backend = plt.get_backend()
if backend != original_backend:
plt.close('all') # auto-close()ing of figures upon backend switching is deprecated since 3.8
plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8
plt.switch_backend(backend)
with plt.rc_context(rcparams):
result = func(*args, **kwargs)
if backend != original_backend:
plt.close('all')
plt.close("all")
plt.switch_backend(original_backend)
return result
@ -229,26 +230,26 @@ def set_logging(name=LOGGING_NAME, verbose=True):
level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
# Configure the console (stdout) encoding to UTF-8
formatter = logging.Formatter('%(message)s') # Default formatter
if WINDOWS and sys.stdout.encoding != 'utf-8':
formatter = logging.Formatter("%(message)s") # Default formatter
if WINDOWS and sys.stdout.encoding != "utf-8":
try:
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8')
elif hasattr(sys.stdout, 'buffer'):
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
elif hasattr(sys.stdout, "buffer"):
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
else:
sys.stdout.encoding = 'utf-8'
sys.stdout.encoding = "utf-8"
except Exception as e:
print(f'Creating custom formatter for non UTF-8 environments due to {e}')
print(f"Creating custom formatter for non UTF-8 environments due to {e}")
class CustomFormatter(logging.Formatter):
def format(self, record):
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
return emojis(super().format(record))
formatter = CustomFormatter('%(message)s') # Use CustomFormatter to eliminate UTF-8 output as last recourse
formatter = CustomFormatter("%(message)s") # Use CustomFormatter to eliminate UTF-8 output as last recourse
# Create and configure the StreamHandler
stream_handler = logging.StreamHandler(sys.stdout)
@ -264,13 +265,13 @@ def set_logging(name=LOGGING_NAME, verbose=True):
# Set logger
LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)
for logger in 'sentry_sdk', 'urllib3.connectionpool':
for logger in "sentry_sdk", "urllib3.connectionpool":
logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
def emojis(string=''):
def emojis(string=""):
"""Return platform-dependent emoji-safe version of string."""
return string.encode().decode('ascii', 'ignore') if WINDOWS else string
return string.encode().decode("ascii", "ignore") if WINDOWS else string
class ThreadingLocked:
@ -310,7 +311,7 @@ class ThreadingLocked:
return decorated
def yaml_save(file='data.yaml', data=None, header=''):
def yaml_save(file="data.yaml", data=None, header=""):
"""
Save YAML data to a file.
@ -336,13 +337,13 @@ def yaml_save(file='data.yaml', data=None, header=''):
data[k] = str(v)
# Dump data to file in YAML format
with open(file, 'w', errors='ignore', encoding='utf-8') as f:
with open(file, "w", errors="ignore", encoding="utf-8") as f:
if header:
f.write(header)
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
def yaml_load(file='data.yaml', append_filename=False):
def yaml_load(file="data.yaml", append_filename=False):
"""
Load YAML data from a file.
@ -353,18 +354,18 @@ def yaml_load(file='data.yaml', append_filename=False):
Returns:
(dict): YAML data and file name.
"""
assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()'
with open(file, errors='ignore', encoding='utf-8') as f:
assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
with open(file, errors="ignore", encoding="utf-8") as f:
s = f.read() # string
# Remove special characters
if not s.isprintable():
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)
# Add YAML filename to dict and return
data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files)
if append_filename:
data['yaml_file'] = str(file)
data["yaml_file"] = str(file)
return data
@ -386,7 +387,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
# Default configuration
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
for k, v in DEFAULT_CFG_DICT.items():
if isinstance(v, str) and v.lower() == 'none':
if isinstance(v, str) and v.lower() == "none":
DEFAULT_CFG_DICT[k] = None
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
@ -400,8 +401,8 @@ def is_ubuntu() -> bool:
(bool): True if OS is Ubuntu, False otherwise.
"""
with contextlib.suppress(FileNotFoundError):
with open('/etc/os-release') as f:
return 'ID=ubuntu' in f.read()
with open("/etc/os-release") as f:
return "ID=ubuntu" in f.read()
return False
@ -412,7 +413,7 @@ def is_colab():
Returns:
(bool): True if running inside a Colab notebook, False otherwise.
"""
return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
def is_kaggle():
@ -422,7 +423,7 @@ def is_kaggle():
Returns:
(bool): True if running inside a Kaggle kernel, False otherwise.
"""
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
def is_jupyter():
@ -434,6 +435,7 @@ def is_jupyter():
"""
with contextlib.suppress(Exception):
from IPython import get_ipython
return get_ipython() is not None
return False
@ -445,10 +447,10 @@ def is_docker() -> bool:
Returns:
(bool): True if the script is running inside a Docker container, False otherwise.
"""
file = Path('/proc/self/cgroup')
file = Path("/proc/self/cgroup")
if file.exists():
with open(file) as f:
return 'docker' in f.read()
return "docker" in f.read()
else:
return False
@ -462,7 +464,7 @@ def is_online() -> bool:
"""
import socket
for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
try:
test_connection = socket.create_connection(address=(host, 53), timeout=2)
except (socket.timeout, socket.gaierror, OSError):
@ -516,7 +518,7 @@ def is_pytest_running():
Returns:
(bool): True if pytest is running, False otherwise.
"""
return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem)
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)
def is_github_action_running() -> bool:
@ -526,7 +528,7 @@ def is_github_action_running() -> bool:
Returns:
(bool): True if the current environment is a GitHub Actions runner, False otherwise.
"""
return 'GITHUB_ACTIONS' in os.environ and 'GITHUB_WORKFLOW' in os.environ and 'RUNNER_OS' in os.environ
return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
def is_git_dir():
@ -549,7 +551,7 @@ def get_git_dir():
(Path | None): Git root directory if found or None if not found.
"""
for d in Path(__file__).parents:
if (d / '.git').is_dir():
if (d / ".git").is_dir():
return d
@ -562,7 +564,7 @@ def get_git_origin_url():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
return origin.decode().strip()
@ -575,7 +577,7 @@ def get_git_branch():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
@ -602,11 +604,11 @@ def get_ubuntu_version():
"""
if is_ubuntu():
with contextlib.suppress(FileNotFoundError, AttributeError):
with open('/etc/os-release') as f:
with open("/etc/os-release") as f:
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
def get_user_config_dir(sub_dir='Ultralytics'):
def get_user_config_dir(sub_dir="Ultralytics"):
"""
Get the user config directory.
@ -618,19 +620,21 @@ def get_user_config_dir(sub_dir='Ultralytics'):
"""
# Return the appropriate config directory for each operating system
if WINDOWS:
path = Path.home() / 'AppData' / 'Roaming' / sub_dir
path = Path.home() / "AppData" / "Roaming" / sub_dir
elif MACOS: # macOS
path = Path.home() / 'Library' / 'Application Support' / sub_dir
path = Path.home() / "Library" / "Application Support" / sub_dir
elif LINUX:
path = Path.home() / '.config' / sub_dir
path = Path.home() / ".config" / sub_dir
else:
raise ValueError(f'Unsupported operating system: {platform.system()}')
raise ValueError(f"Unsupported operating system: {platform.system()}")
# GCP and AWS lambda fix, only /tmp is writeable
if not is_dir_writeable(path.parent):
LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.')
path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir
LOGGER.warning(
f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
"Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
)
path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir
# Create the subdirectory if it does not exist
path.mkdir(parents=True, exist_ok=True)
@ -638,8 +642,8 @@ def get_user_config_dir(sub_dir='Ultralytics'):
return path
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir
SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir
SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
def colorstr(*input):
@ -670,28 +674,29 @@ def colorstr(*input):
>>> colorstr('blue', 'bold', 'hello world')
>>> '\033[34m\033[1mhello world\033[0m'
"""
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
*args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
colors = {
'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'underline': '\033[4m'}
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
"black": "\033[30m", # basic colors
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
"bright_black": "\033[90m", # bright colors
"bright_red": "\033[91m",
"bright_green": "\033[92m",
"bright_yellow": "\033[93m",
"bright_blue": "\033[94m",
"bright_magenta": "\033[95m",
"bright_cyan": "\033[96m",
"bright_white": "\033[97m",
"end": "\033[0m", # misc
"bold": "\033[1m",
"underline": "\033[4m",
}
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
def remove_colorstr(input_string):
@ -708,8 +713,8 @@ def remove_colorstr(input_string):
>>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
>>> 'hello world'
"""
ansi_escape = re.compile(r'\x1B\[[0-9;]*[A-Za-z]')
return ansi_escape.sub('', input_string)
ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
return ansi_escape.sub("", input_string)
class TryExcept(contextlib.ContextDecorator):
@ -719,7 +724,7 @@ class TryExcept(contextlib.ContextDecorator):
Use as @TryExcept() decorator or 'with TryExcept():' context manager.
"""
def __init__(self, msg='', verbose=True):
def __init__(self, msg="", verbose=True):
"""Initialize TryExcept class with optional message and verbosity settings."""
self.msg = msg
self.verbose = verbose
@ -744,7 +749,7 @@ def threaded(func):
def wrapper(*args, **kwargs):
"""Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
if kwargs.pop('threaded', True): # run in thread
if kwargs.pop("threaded", True): # run in thread
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
@ -786,27 +791,28 @@ def set_sentry():
Returns:
dict: The modified event or None if the event should not be sent to Sentry.
"""
if 'exc_info' in hint:
exc_type, exc_value, tb = hint['exc_info']
if exc_type in (KeyboardInterrupt, FileNotFoundError) \
or 'out of memory' in str(exc_value):
if "exc_info" in hint:
exc_type, exc_value, tb = hint["exc_info"]
if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
return None # do not send event
event['tags'] = {
'sys_argv': sys.argv[0],
'sys_argv_name': Path(sys.argv[0]).name,
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
'os': ENVIRONMENT}
event["tags"] = {
"sys_argv": sys.argv[0],
"sys_argv_name": Path(sys.argv[0]).name,
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"os": ENVIRONMENT,
}
return event
if SETTINGS['sync'] and \
RANK in (-1, 0) and \
Path(sys.argv[0]).name == 'yolo' and \
not TESTS_RUNNING and \
ONLINE and \
is_pip_package() and \
not is_git_dir():
if (
SETTINGS["sync"]
and RANK in (-1, 0)
and Path(sys.argv[0]).name == "yolo"
and not TESTS_RUNNING
and ONLINE
and is_pip_package()
and not is_git_dir()
):
# If sentry_sdk package is not installed then return and do not use Sentry
try:
import sentry_sdk # noqa
@ -814,14 +820,15 @@ def set_sentry():
return
sentry_sdk.init(
dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",
debug=False,
traces_sample_rate=1.0,
release=__version__,
environment='production', # 'dev' or 'production'
environment="production", # 'dev' or 'production'
before_send=before_send,
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash
ignore_errors=[KeyboardInterrupt, FileNotFoundError],
)
sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash
class SettingsManager(dict):
@ -833,7 +840,7 @@ class SettingsManager(dict):
version (str): Settings version. In case of local version mismatch, new default settings will be saved.
"""
def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
file.
"""
@ -850,23 +857,24 @@ class SettingsManager(dict):
self.file = Path(file)
self.version = version
self.defaults = {
'settings_version': version,
'datasets_dir': str(datasets_root / 'datasets'),
'weights_dir': str(root / 'weights'),
'runs_dir': str(root / 'runs'),
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
'sync': True,
'api_key': '',
'openai_api_key': '',
'clearml': True, # integrations
'comet': True,
'dvc': True,
'hub': True,
'mlflow': True,
'neptune': True,
'raytune': True,
'tensorboard': True,
'wandb': True}
"settings_version": version,
"datasets_dir": str(datasets_root / "datasets"),
"weights_dir": str(root / "weights"),
"runs_dir": str(root / "runs"),
"uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
"sync": True,
"api_key": "",
"openai_api_key": "",
"clearml": True, # integrations
"comet": True,
"dvc": True,
"hub": True,
"mlflow": True,
"neptune": True,
"raytune": True,
"tensorboard": True,
"wandb": True,
}
super().__init__(copy.deepcopy(self.defaults))
@ -877,13 +885,14 @@ class SettingsManager(dict):
self.load()
correct_keys = self.keys() == self.defaults.keys()
correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))
correct_version = check_version(self['settings_version'], self.version)
correct_version = check_version(self["settings_version"], self.version)
if not (correct_keys and correct_types and correct_version):
LOGGER.warning(
'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem '
'with your settings or a recent ultralytics package update. '
"WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
"with your settings or a recent ultralytics package update. "
f"\nView settings with 'yolo settings' or at '{self.file}'"
"\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.")
"\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'."
)
self.reset()
def load(self):
@ -910,14 +919,16 @@ def deprecation_warn(arg, new_arg, version=None):
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
if not version:
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
f"Please use '{new_arg}' instead.")
LOGGER.warning(
f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
f"Please use '{new_arg}' instead."
)
def clean_url(url):
"""Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
url = Path(url).as_posix().replace(':/', '://') # Pathlib turns :// -> :/, as_posix() for Windows
return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows
return urllib.parse.unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth
def url2file(url):
@ -928,13 +939,22 @@ def url2file(url):
# Run below code on utils init ------------------------------------------------------------------------------------
# Check first-install steps
PREFIX = colorstr('Ultralytics: ')
PREFIX = colorstr("Ultralytics: ")
SETTINGS = SettingsManager() # initialize settings
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
WEIGHTS_DIR = Path(SETTINGS['weights_dir']) # global weights directory
RUNS_DIR = Path(SETTINGS['runs_dir']) # global runs directory
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
'Docker' if is_docker() else platform.system()
DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory
WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory
RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory
ENVIRONMENT = (
"Colab"
if is_colab()
else "Kaggle"
if is_kaggle()
else "Jupyter"
if is_jupyter()
else "Docker"
if is_docker()
else platform.system()
)
TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry()

View file

@ -42,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
"""
# Check device
prefix = colorstr('AutoBatch: ')
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
if device.type == "cpu":
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
return batch_size
if torch.backends.cudnn.benchmark:
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
return batch_size
# Inspect CUDA memory
@ -60,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
f = t - (r + a) # GiB free
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16]
@ -70,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
p = np.polyfit(batch_sizes[: len(y)], y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
if None in results: # some sizes failed
i = results.index(None) # first fail index
@ -78,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
b = batch_size
LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
return b
except Exception as e:
LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
return batch_size

View file

@ -42,13 +42,9 @@ from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device
def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
data=None,
imgsz=160,
half=False,
int8=False,
device='cpu',
verbose=False):
def benchmark(
model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
):
"""
Benchmark a YOLO model across different formats for speed and accuracy.
@ -76,6 +72,7 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
"""
import pandas as pd
pd.options.display.max_columns = 10
pd.options.display.width = 120
device = select_device(device, verbose=False)
@ -85,67 +82,62 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
y = []
t0 = time.time()
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
emoji, filename = '', None # export defaults
emoji, filename = "", None # export defaults
try:
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
assert i != 9 or LINUX, "Edge TPU export only supported on Linux"
if i == 10:
assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
assert MACOS or LINUX, "TF.js export only supported on macOS and Linux"
elif i == 11:
assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10'
if 'cpu' in device.type:
assert cpu, 'inference not supported on CPU'
if 'cuda' in device.type:
assert gpu, 'inference not supported on GPU'
assert sys.version_info < (3, 11), "PaddlePaddle export only supported on Python<=3.10"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:
assert gpu, "inference not supported on GPU"
# Export
if format == '-':
if format == "-":
filename = model.ckpt_path or model.cfg
exported_model = model # PyTorch format
else:
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
exported_model = YOLO(filename, task=model.task)
assert suffix in str(filename), 'export failed'
emoji = '' # indicates export succeeded
assert suffix in str(filename), "export failed"
emoji = "" # indicates export succeeded
# Predict
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
exported_model.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
# Validate
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
results = exported_model.val(data=data,
batch=1,
imgsz=imgsz,
plots=False,
device=device,
half=half,
int8=int8,
verbose=False)
metric, speed = results.results_dict[key], results.speed['inference']
y.append([name, '', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
results = exported_model.val(
data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
)
metric, speed = results.results_dict[key], results.speed["inference"]
y.append([name, "", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
except Exception as e:
if verbose:
assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}'
LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
# Print results
check_yolo(device=device) # print system info
df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])
name = Path(model.ckpt_path).name
s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
LOGGER.info(s)
with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
f.write(s)
if verbose and isinstance(verbose, float):
metrics = df[key].array # values to compare to floor
floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}'
assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
return df
@ -175,15 +167,17 @@ class ProfileModels:
```
"""
def __init__(self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None):
def __init__(
self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None,
):
"""
Initialize the ProfileModels class for profiling models.
@ -204,37 +198,32 @@ class ProfileModels:
self.imgsz = imgsz
self.half = half
self.trt = trt # run TensorRT profiling
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
def profile(self):
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
files = self.get_files()
if not files:
print('No matching *.pt or *.onnx files found.')
print("No matching *.pt or *.onnx files found.")
return
table_rows = []
output = []
for file in files:
engine_file = file.with_suffix('.engine')
if file.suffix in ('.pt', '.yaml', '.yml'):
engine_file = file.with_suffix(".engine")
if file.suffix in (".pt", ".yaml", ".yml"):
model = YOLO(str(file))
model.fuse() # to report correct params and GFLOPs in model.info()
model_info = model.info()
if self.trt and self.device.type != 'cpu' and not engine_file.is_file():
engine_file = model.export(format='engine',
half=self.half,
imgsz=self.imgsz,
device=self.device,
verbose=False)
onnx_file = model.export(format='onnx',
half=self.half,
imgsz=self.imgsz,
simplify=True,
device=self.device,
verbose=False)
elif file.suffix == '.onnx':
if self.trt and self.device.type != "cpu" and not engine_file.is_file():
engine_file = model.export(
format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
)
onnx_file = model.export(
format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
)
elif file.suffix == ".onnx":
model_info = self.get_onnx_model_info(file)
onnx_file = file
else:
@ -254,14 +243,14 @@ class ProfileModels:
for path in self.paths:
path = Path(path)
if path.is_dir():
extensions = ['*.pt', '*.onnx', '*.yaml']
extensions = ["*.pt", "*.onnx", "*.yaml"]
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing
elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing
files.append(str(path))
else:
files.extend(glob.glob(str(path)))
print(f'Profiling: {sorted(files)}')
print(f"Profiling: {sorted(files)}")
return [Path(file) for file in sorted(files)]
def get_onnx_model_info(self, onnx_file: str):
@ -306,7 +295,7 @@ class ProfileModels:
run_times = []
for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False)
run_times.append(results[0].speed['inference']) # Convert to milliseconds
run_times.append(results[0].speed["inference"]) # Convert to milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
return np.mean(run_times), np.std(run_times)
@ -315,31 +304,31 @@ class ProfileModels:
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
check_requirements('onnxruntime')
check_requirements("onnxruntime")
import onnxruntime as ort
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 8 # Limit the number of threads
sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
input_tensor = sess.get_inputs()[0]
input_type = input_tensor.type
# Mapping ONNX datatype to numpy datatype
if 'float16' in input_type:
if "float16" in input_type:
input_dtype = np.float16
elif 'float' in input_type:
elif "float" in input_type:
input_dtype = np.float32
elif 'double' in input_type:
elif "double" in input_type:
input_dtype = np.float64
elif 'int64' in input_type:
elif "int64" in input_type:
input_dtype = np.int64
elif 'int32' in input_type:
elif "int32" in input_type:
input_dtype = np.int32
else:
raise ValueError(f'Unsupported ONNX datatype {input_type}')
raise ValueError(f"Unsupported ONNX datatype {input_type}")
input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
input_name = input_tensor.name
@ -369,25 +358,26 @@ class ProfileModels:
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info
return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info
return {
'model/name': model_name,
'model/parameters': params,
'model/GFLOPs': round(flops, 3),
'model/speed_ONNX(ms)': round(t_onnx[0], 3),
'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
"model/name": model_name,
"model/parameters": params,
"model/GFLOPs": round(flops, 3),
"model/speed_ONNX(ms)": round(t_onnx[0], 3),
"model/speed_TensorRT(ms)": round(t_engine[0], 3),
}
def print_table(self, table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data."""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
header = f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|"
print(f'\n\n{header}')
print(f"\n\n{header}")
print(separator)
for row in table_rows:
print(row)

View file

@ -2,4 +2,4 @@
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

View file

@ -143,37 +143,35 @@ def on_export_end(exporter):
default_callbacks = {
# Run in trainer
'on_pretrain_routine_start': [on_pretrain_routine_start],
'on_pretrain_routine_end': [on_pretrain_routine_end],
'on_train_start': [on_train_start],
'on_train_epoch_start': [on_train_epoch_start],
'on_train_batch_start': [on_train_batch_start],
'optimizer_step': [optimizer_step],
'on_before_zero_grad': [on_before_zero_grad],
'on_train_batch_end': [on_train_batch_end],
'on_train_epoch_end': [on_train_epoch_end],
'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
'on_model_save': [on_model_save],
'on_train_end': [on_train_end],
'on_params_update': [on_params_update],
'teardown': [teardown],
"on_pretrain_routine_start": [on_pretrain_routine_start],
"on_pretrain_routine_end": [on_pretrain_routine_end],
"on_train_start": [on_train_start],
"on_train_epoch_start": [on_train_epoch_start],
"on_train_batch_start": [on_train_batch_start],
"optimizer_step": [optimizer_step],
"on_before_zero_grad": [on_before_zero_grad],
"on_train_batch_end": [on_train_batch_end],
"on_train_epoch_end": [on_train_epoch_end],
"on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
"on_model_save": [on_model_save],
"on_train_end": [on_train_end],
"on_params_update": [on_params_update],
"teardown": [teardown],
# Run in validator
'on_val_start': [on_val_start],
'on_val_batch_start': [on_val_batch_start],
'on_val_batch_end': [on_val_batch_end],
'on_val_end': [on_val_end],
"on_val_start": [on_val_start],
"on_val_batch_start": [on_val_batch_start],
"on_val_batch_end": [on_val_batch_end],
"on_val_end": [on_val_end],
# Run in predictor
'on_predict_start': [on_predict_start],
'on_predict_batch_start': [on_predict_batch_start],
'on_predict_postprocess_end': [on_predict_postprocess_end],
'on_predict_batch_end': [on_predict_batch_end],
'on_predict_end': [on_predict_end],
"on_predict_start": [on_predict_start],
"on_predict_batch_start": [on_predict_batch_start],
"on_predict_postprocess_end": [on_predict_postprocess_end],
"on_predict_batch_end": [on_predict_batch_end],
"on_predict_end": [on_predict_end],
# Run in exporter
'on_export_start': [on_export_start],
'on_export_end': [on_export_end]}
"on_export_start": [on_export_start],
"on_export_end": [on_export_end],
}
def get_default_callbacks():
@ -197,10 +195,11 @@ def add_integration_callbacks(instance):
# Load HUB callbacks
from .hub import callbacks as hub_cb
callbacks_list = [hub_cb]
# Load training callbacks
if 'Trainer' in instance.__class__.__name__:
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
@ -209,6 +208,7 @@ def add_integration_callbacks(instance):
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary

View file

@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['clearml'] is True # verify integration is enabled
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, '__version__') # verify package is not directory
assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title='Debug Samples') -> None:
def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r'_batch(\d+)', f.name)
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(title=title,
series=f.name.replace(it.group(), ''),
local_path=str(f),
iteration=iteration)
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)
def _log_plot(title, plot_path) -> None:
@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None:
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(title=title,
series='',
figure=fig,
report_interactive=False)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)
def on_pretrain_routine_start(trainer):
@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer):
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
task_name=trainer.args.name,
tags=['YOLOv8'],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={
'pytorch': False,
'matplotlib': False})
LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
'please add clearml-init and connect your arguments before initializing YOLO.')
task.connect(vars(trainer.args), name='General')
task = Task.init(
project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=["YOLOv8"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
@ -88,26 +88,26 @@ def on_train_epoch_end(trainer):
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.label_loss_items(trainer.tloss, prefix='train').items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar('lr', k, v, iteration=trainer.epoch)
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(title='Epoch Time',
series='Epoch Time',
value=trainer.epoch_time,
iteration=trainer.epoch)
task.get_logger().report_scalar(
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
)
for k, v in trainer.metrics.items():
task.get_logger().report_scalar('val', k, v, iteration=trainer.epoch)
task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)
@ -116,7 +116,7 @@ def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer):
@ -124,8 +124,11 @@ def on_train_end(trainer):
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
@ -136,9 +139,14 @@ def on_train_end(trainer):
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if clearml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)

View file

@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['comet'] is True # verify integration is enabled
assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
assert hasattr(comet_ml, '__version__') # verify package is not directory
assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ['detect']
COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
_comet_image_prediction_count = 0
@ -27,43 +27,43 @@ except (ImportError, AssertionError):
def _get_comet_mode():
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv('COMET_MODE', 'online')
return os.getenv("COMET_MODE", "online")
def _get_comet_model_name():
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
return os.getenv("COMET_MODEL_NAME", "YOLOv8")
def _get_eval_batch_logging_interval():
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log():
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score):
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale
def _should_log_confusion_matrix():
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions():
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == 'offline':
if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)
return comet_ml.Experiment(project_name=project_name)
@ -75,18 +75,21 @@ def _create_experiment(args):
return
try:
comet_mode = _get_comet_mode()
_project_name = os.getenv('COMET_PROJECT_NAME', args.project)
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
experiment.log_others({
'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
'log_image_predictions': _should_log_image_predictions(),
'max_image_predictions': _get_max_image_predictions_to_log(), })
experiment.log_other('Created from', 'yolov8')
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "yolov8")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer):
@ -134,29 +137,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
indices = batch['batch_idx'] == img_idx
bboxes = batch['bboxes'][indices]
indices = batch["batch_idx"] == img_idx
bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
return None
cls_labels = batch['cls'][indices].squeeze(1).tolist()
cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]
original_image_shape = batch['ori_shape'][img_idx]
resized_image_shape = batch['resized_shape'][img_idx]
ratio_pad = batch['ratio_pad'][img_idx]
original_image_shape = batch["ori_shape"][img_idx]
resized_image_shape = batch["resized_shape"][img_idx]
ratio_pad = batch["ratio_pad"][img_idx]
data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
data.append({
'boxes': [box],
'label': f'gt_{label}',
'score': _scale_confidence_score(1.0), })
data.append(
{
"boxes": [box],
"label": f"gt_{label}",
"score": _scale_confidence_score(1.0),
}
)
return {'name': 'ground_truth', 'data': data}
return {"name": "ground_truth", "data": data}
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
@ -166,31 +172,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None
data = []
for prediction in predictions:
boxes = prediction['bbox']
score = _scale_confidence_score(prediction['score'])
cls_label = prediction['category_id']
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
data.append({"boxes": [boxes], "label": cls_label, "score": score})
return {'name': 'prediction', 'data': data}
return {"name": "prediction", "data": data}
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
class_label_map)
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
class_label_map)
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
img_idx, image_path, batch, class_label_map
)
prediction_annotations = _format_prediction_annotations_for_detection(
image_path, prediction_metadata_map, class_label_map
)
annotations = [
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
]
return [annotations] if annotations else None
@ -198,8 +207,8 @@ def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction['image_id'], [])
pred_metadata_map[prediction['image_id']].append(prediction)
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
@ -207,7 +216,7 @@ def _create_prediction_metadata_map(model_predictions):
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data['names'].values()) + ['background']
names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat,
labels=names,
@ -251,7 +260,7 @@ def _log_image_predictions(experiment, validator, curr_step):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
image_paths = batch['im_file']
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
@ -275,10 +284,10 @@ def _log_image_predictions(experiment, validator, curr_step):
def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None)
label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)
@ -288,7 +297,7 @@ def _log_model(experiment, trainer):
experiment.log_model(
model_name,
file_or_folder=str(trainer.best),
file_name='best.pt',
file_name="best.pt",
overwrite=True,
)
@ -296,7 +305,7 @@ def _log_model(experiment, trainer):
def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment()
is_alive = getattr(experiment, 'alive', False)
is_alive = getattr(experiment, "alive", False)
if not experiment or not is_alive:
_create_experiment(trainer.args)
@ -308,17 +317,17 @@ def on_train_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
experiment.log_metrics(
trainer.label_loss_items(trainer.tloss, prefix='train'),
trainer.label_loss_items(trainer.tloss, prefix="train"),
step=curr_step,
epoch=curr_epoch,
)
if curr_epoch == 1:
_log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
def on_fit_epoch_end(trainer):
@ -328,14 +337,15 @@ def on_fit_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
save_assets = metadata['save_assets']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
@ -355,8 +365,8 @@ def on_train_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
@ -371,8 +381,13 @@ def on_train_end(trainer):
_comet_image_prediction_count = 0
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if comet_ml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if comet_ml
else {}
)

View file

@ -4,9 +4,10 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['dvc'] is True # verify integration is enabled
assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
assert checks.check_version('dvclive', '2.11.0', verbose=True)
assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
@ -24,24 +25,24 @@ except (ImportError, AssertionError, TypeError):
dvclive = None
def _log_images(path, prefix=''):
def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name
# Group images by batch to enable sliders in UI
if m := re.search(r'_batch(\d+)', name):
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
live.log_image(os.path.join(prefix, name), path)
def _log_plots(plots, prefix=''):
def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
@ -53,15 +54,15 @@ def _log_confusion_matrix(validator):
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
if validator.confusion_matrix.task == 'detect':
names += ['background']
if validator.confusion_matrix.task == "detect":
names += ["background"]
for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)
live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
@ -71,12 +72,12 @@ def on_pretrain_routine_start(trainer):
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
"""Logs plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, 'train')
_log_plots(trainer.plots, "train")
def on_train_start(trainer):
@ -95,17 +96,18 @@ def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, 'train')
_log_plots(trainer.validator.plots, 'val')
_log_plots(trainer.plots, "train")
_log_plots(trainer.validator.plots, "val")
live.next_step()
_training_epoch = False
@ -115,24 +117,29 @@ def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, 'val')
_log_plots(trainer.validator.plots, 'val')
_log_plots(trainer.plots, "val")
_log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)
if trainer.best.exists():
live.log_artifact(trainer.best, copy=True, type='model')
live.log_artifact(trainer.best, copy=True, type="model")
live.end()
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_train_start': on_train_start,
'on_train_epoch_start': on_train_epoch_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if dvclive else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_start": on_train_start,
"on_train_epoch_start": on_train_epoch_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if dvclive
else {}
)

View file

@ -11,60 +11,62 @@ from ultralytics.utils import LOGGER, SETTINGS
def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Start timer for upload rate limit
session.timers = {
'metrics': time(),
'ckpt': time(), } # start timer on session.rate_limit
"metrics": time(),
"ckpt": time(),
} # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix='train'),
**trainer.metrics, }
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers['metrics'] = time() # reset timer
session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}')
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers['ckpt'] = time() # reset timer
session.timers["ckpt"] = time() # reset timer
def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload final model and metrics with exponential standoff
LOGGER.info(f'{PREFIX}Syncing final model...')
LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get('metrics/mAP50-95(B)', 0),
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}Done ✅\n'
f'{PREFIX}View model at {session.model_url} 🚀')
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
@ -87,12 +89,17 @@ def on_export_start(exporter):
events(exporter.args)
callbacks = ({
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
'on_train_end': on_train_end,
'on_train_start': on_train_start,
'on_val_start': on_val_start,
'on_predict_start': on_predict_start,
'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {}) # verify enabled
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS["hub"] is True
else {}
) # verify enabled

View file

@ -26,15 +26,15 @@ from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorst
try:
import os
assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '') # do not log pytest
assert SETTINGS['mlflow'] is True # verify integration is enabled
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, '__version__') # verify package is not directory
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
PREFIX = colorstr('MLflow: ')
SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()}
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError):
mlflow = None
@ -61,33 +61,33 @@ def on_pretrain_routine_end(trainer):
"""
global mlflow
uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow')
LOGGER.debug(f'{PREFIX} tracking uri: {uri}')
uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)
# Set experiment and run names
experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)
mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}')
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n'
f'{PREFIX}WARNING ⚠️ Not tracking this run')
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')),
step=trainer.epoch)
mlflow.log_metrics(
metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), step=trainer.epoch
)
mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch)
@ -101,16 +101,23 @@ def on_train_end(trainer):
"""Log model artifacts at the end of the training."""
if mlflow:
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob('*'): # log all other files in save_dir
if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}:
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
mlflow.end_run()
LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n'
f"{PREFIX}disable with 'yolo settings mlflow=False'")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
f"{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if mlflow else {}
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if mlflow
else {}
)

View file

@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['neptune'] is True # verify integration is enabled
assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
assert hasattr(neptune, '__version__')
assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
@ -23,11 +23,11 @@ def _log_scalars(scalars, step=0):
run[k].append(value=v, step=step)
def _log_images(imgs_dict, group=''):
def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in imgs_dict.items():
run[f'{group}/{k}'].upload(File(v))
run[f"{group}/{k}"].upload(File(v))
def _log_plot(title, plot_path):
@ -43,34 +43,35 @@ def _log_plot(title, plot_path):
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
run[f'Plots/{title}'].upload(fig)
run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer):
"""Callback function called before the training routine starts."""
try:
global run
run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Callback function called at end of each training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run['Configuration/Model'] = model_info_for_loggers(trainer)
run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
@ -78,7 +79,7 @@ def on_val_end(validator):
"""Callback function called at end of each validation."""
if run:
# Log val_labels and val_pred
_log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer):
@ -86,19 +87,28 @@ def on_train_end(trainer):
if run:
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
trainer.best)))
run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload(
File(str(trainer.best))
)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if neptune else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if neptune
else {}
)

View file

@ -3,7 +3,7 @@
from ultralytics.utils import SETTINGS
try:
assert SETTINGS['raytune'] is True # verify integration is enabled
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics['epoch'] = trainer.epoch
metrics["epoch"] = trainer.epoch
session.report(metrics)
callbacks = {
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)

View file

@ -7,7 +7,7 @@ try:
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['tensorboard'] is True # verify integration is enabled
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
except (ImportError, AssertionError, TypeError):
@ -34,10 +34,10 @@ def _log_tensorboard_graph(trainer):
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
@ -46,10 +46,10 @@ def on_pretrain_routine_start(trainer):
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
prefix = colorstr('TensorBoard: ')
prefix = colorstr("TensorBoard: ")
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
@ -60,7 +60,7 @@ def on_train_start(trainer):
def on_train_epoch_end(trainer):
"""Logs scalar statistics at the end of a training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
@ -69,8 +69,13 @@ def on_fit_epoch_end(trainer):
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_start': on_train_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_start": on_train_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_epoch_end": on_train_epoch_end,
}
if SummaryWriter
else {}
)

View file

@ -5,10 +5,10 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['wandb'] is True # verify integration is enabled
assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
assert hasattr(wb, '__version__') # verify package is not directory
assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
@ -19,7 +19,7 @@ except (ImportError, AssertionError):
wb = None
def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'):
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
@ -37,24 +37,25 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
"""
df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3)
fields = {'x': 'x', 'y': 'y', 'class': 'class'}
string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title}
return wb.plot_table('wandb/area-under-curve/v0',
wb.Table(dataframe=df),
fields=fields,
string_fields=string_fields)
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
fields = {"x": "x", "y": "y", "class": "class"}
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
return wb.plot_table(
"wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
)
def _plot_curve(x,
y,
names=None,
id='precision-recall',
title='Precision Recall Curve',
x_title='Recall',
y_title='Precision',
num_x=100,
only_mean=False):
def _plot_curve(
x,
y,
names=None,
id="precision-recall",
title="Precision Recall Curve",
x_title="Recall",
y_title="Precision",
num_x=100,
only_mean=False,
):
"""
Log a metric curve visualization.
@ -88,7 +89,7 @@ def _plot_curve(x,
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
classes = ['mean'] * len(x_log)
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
@ -99,7 +100,7 @@ def _plot_curve(x,
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
@ -107,7 +108,7 @@ def _log_plots(plots, step):
def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
def on_fit_epoch_end(trainer):
@ -121,7 +122,7 @@ def on_fit_epoch_end(trainer):
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
@ -131,17 +132,17 @@ def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art, aliases=['best'])
wb.run.log_artifact(art, aliases=["best"])
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
id=f'curves/{curve_name}',
id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
@ -149,8 +150,13 @@ def on_train_end(trainer):
wb.run.finish() # required or run continues on dashboard
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if wb else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if wb
else {}
)

View file

@ -21,12 +21,33 @@ import requests
import torch
from matplotlib import font_manager
from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, SimpleNamespace,
ThreadingLocked, TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker,
is_github_action_running, is_jupyter, is_kaggle, is_online, is_pip_package, url2file)
from ultralytics.utils import (
ASSETS,
AUTOINSTALL,
LINUX,
LOGGER,
ONLINE,
ROOT,
USER_CONFIG_DIR,
SimpleNamespace,
ThreadingLocked,
TryExcept,
clean_url,
colorstr,
downloads,
emojis,
is_colab,
is_docker,
is_github_action_running,
is_jupyter,
is_kaggle,
is_online,
is_pip_package,
url2file,
)
def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
"""
Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
@ -46,23 +67,23 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
"""
if package:
requires = [x for x in metadata.distribution(package).requires if 'extra == ' not in x]
requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
else:
requires = Path(file_path).read_text().splitlines()
requirements = []
for line in requires:
line = line.strip()
if line and not line.startswith('#'):
line = line.split('#')[0].strip() # ignore inline comments
match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line)
if line and not line.startswith("#"):
line = line.split("#")[0].strip() # ignore inline comments
match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
if match:
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ''))
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
return requirements
def parse_version(version='0.0.0') -> tuple:
def parse_version(version="0.0.0") -> tuple:
"""
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
function replaces deprecated 'pkg_resources.parse_version(v)'.
@ -74,9 +95,9 @@ def parse_version(version='0.0.0') -> tuple:
(tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1)
"""
try:
return tuple(map(int, re.findall(r'\d+', version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}')
LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
return 0, 0, 0
@ -121,15 +142,19 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
elif isinstance(imgsz, (list, tuple)):
imgsz = list(imgsz)
else:
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
raise TypeError(
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
)
# Apply max_dim
if len(imgsz) > max_dim:
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
msg = (
"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
)
if max_dim != 1:
raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
@ -137,7 +162,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
# Print warning message if image size was updated
if sz != imgsz:
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@ -145,12 +170,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
return sz
def check_version(current: str = '0.0.0',
required: str = '0.0.0',
name: str = 'version',
hard: bool = False,
verbose: bool = False,
msg: str = '') -> bool:
def check_version(
current: str = "0.0.0",
required: str = "0.0.0",
name: str = "version",
hard: bool = False,
verbose: bool = False,
msg: str = "",
) -> bool:
"""
Check current version against the required version or range.
@ -181,7 +208,7 @@ def check_version(current: str = '0.0.0',
```
"""
if not current: # if current is '' or None
LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.')
LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
return True
elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
try:
@ -189,34 +216,34 @@ def check_version(current: str = '0.0.0',
current = metadata.version(current) # get version string from package name
except metadata.PackageNotFoundError:
if hard:
raise ModuleNotFoundError(emojis(f'WARNING ⚠️ {current} package is required but not installed'))
raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed"))
else:
return False
if not required: # if required is '' or None
return True
op = ''
version = ''
op = ""
version = ""
result = True
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
for r in required.strip(',').split(','):
op, version = re.match(r'([^0-9]*)([\d.]+)', r).groups() # split '>=22.04' -> ('>=', '22.04')
for r in required.strip(",").split(","):
op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
if op == '==' and c != v:
if op == "==" and c != v:
result = False
elif op == '!=' and c == v:
elif op == "!=" and c == v:
result = False
elif op in ('>=', '') and not (c >= v): # if no constraint passed assume '>=required'
elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
result = False
elif op == '<=' and not (c <= v):
elif op == "<=" and not (c <= v):
result = False
elif op == '>' and not (c > v):
elif op == ">" and not (c > v):
result = False
elif op == '<' and not (c < v):
elif op == "<" and not (c < v):
result = False
if not result:
warning = f'WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}'
warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
if hard:
raise ModuleNotFoundError(emojis(warning)) # assert version requirements met
if verbose:
@ -224,7 +251,7 @@ def check_version(current: str = '0.0.0',
return result
def check_latest_pypi_version(package_name='ultralytics'):
def check_latest_pypi_version(package_name="ultralytics"):
"""
Returns the latest version of a PyPI package without downloading or installing it.
@ -236,9 +263,9 @@ def check_latest_pypi_version(package_name='ultralytics'):
"""
with contextlib.suppress(Exception):
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()['info']['version']
return response.json()["info"]["version"]
def check_pip_update_available():
@ -251,16 +278,19 @@ def check_pip_update_available():
if ONLINE and is_pip_package():
with contextlib.suppress(Exception):
from ultralytics import __version__
latest = check_latest_pypi_version()
if check_version(__version__, f'<{latest}'): # check if current version is < latest version
LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
f"Update with 'pip install -U ultralytics'")
if check_version(__version__, f"<{latest}"): # check if current version is < latest version
LOGGER.info(
f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
f"Update with 'pip install -U ultralytics'"
)
return True
return False
@ThreadingLocked()
def check_font(font='Arial.ttf'):
def check_font(font="Arial.ttf"):
"""
Find font locally or download to user's configuration directory if it does not already exist.
@ -283,13 +313,13 @@ def check_font(font='Arial.ttf'):
return matches[0]
# Download to USER_CONFIG_DIR if missing
url = f'https://ultralytics.com/assets/{name}'
url = f"https://ultralytics.com/assets/{name}"
if downloads.is_url(url):
downloads.safe_download(url=url, file=file)
return file
def check_python(minimum: str = '3.8.0') -> bool:
def check_python(minimum: str = "3.8.0") -> bool:
"""
Check current python version against the required minimum version.
@ -299,11 +329,11 @@ def check_python(minimum: str = '3.8.0') -> bool:
Returns:
None
"""
return check_version(platform.python_version(), minimum, name='Python ', hard=True)
return check_version(platform.python_version(), minimum, name="Python ", hard=True)
@TryExcept()
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
"""
Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
@ -329,41 +359,42 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
```
"""
prefix = colorstr('red', 'bold', 'requirements:')
prefix = colorstr("red", "bold", "requirements:")
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f'{prefix} {file} not found, check failed.'
requirements = [f'{x.name}{x.specifier}' for x in parse_requirements(file) if x.name not in exclude]
assert file.exists(), f"{prefix} {file} not found, check failed."
requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
elif isinstance(requirements, str):
requirements = [requirements]
pkgs = []
for r in requirements:
r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', r_stripped)
name, required = match[1], match[2].strip() if match[2] else ''
r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
name, required = match[1], match[2].strip() if match[2] else ""
try:
assert check_version(metadata.version(name), required) # exception if requirements not met
except (AssertionError, metadata.PackageNotFoundError):
pkgs.append(r)
s = ' '.join(f'"{x}"' for x in pkgs) # console string
s = " ".join(f'"{x}"' for x in pkgs) # console string
if s:
if install and AUTOINSTALL: # check environment variable
n = len(pkgs) # number of packages updates
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
t = time.time()
assert is_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
assert is_online(), "AutoUpdate skipped (offline)"
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
dt = time.time() - t
LOGGER.info(
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
)
except Exception as e:
LOGGER.warning(f'{prefix}{e}')
LOGGER.warning(f"{prefix}{e}")
return False
else:
return False
@ -386,76 +417,82 @@ def check_torchvision():
import torchvision
# Compatibility table
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
# Extract only the major and minor versions
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(v_torchvision != v for v in compatible_versions):
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
'For a full compatibility table see https://github.com/pytorch/vision#installation')
print(
f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
"For a full compatibility table see https://github.com/pytorch/vision#installation"
)
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
"""Check file(s) for acceptable suffix."""
if file and suffix:
if isinstance(suffix, str):
suffix = (suffix, )
suffix = (suffix,)
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower().strip() # file suffix
if len(s):
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"
def check_yolov5u_filename(file: str, verbose: bool = True):
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
if 'yolov3' in file or 'yolov5' in file:
if 'u.yaml' in file:
file = file.replace('u.yaml', '.yaml') # i.e. yolov5nu.yaml -> yolov5n.yaml
elif '.pt' in file and 'u' not in file:
if "yolov3" in file or "yolov5" in file:
if "u.yaml" in file:
file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
elif ".pt" in file and "u" not in file:
original_file = file
file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file and verbose:
LOGGER.info(
f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
)
return file
def check_model_file_from_stem(model='yolov8n'):
def check_model_file_from_stem(model="yolov8n"):
"""Return a model filename from a valid model stem."""
if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
return Path(model).with_suffix('.pt') # add suffix, i.e. yolov8n -> yolov8n.pt
return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt
else:
return model
def check_file(file, suffix='', download=True, hard=True):
def check_file(file, suffix="", download=True, hard=True):
"""Search/download file (if necessary) and return path."""
check_suffix(file, suffix) # optional
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
if (not file or ('://' not in file and Path(file).exists()) or # '://' check required in Windows Python<3.10
file.lower().startswith('grpc://')): # file exists or gRPC Triton images
if (
not file
or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
or file.lower().startswith("grpc://")
): # file exists or gRPC Triton images
return file
elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')): # download
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
url = file # warning: Pathlib turns :// -> :/
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).exists():
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
downloads.safe_download(url=url, file=file, unzip=False)
return file
else: # search
files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True) # find file
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
@ -463,7 +500,7 @@ def check_file(file, suffix='', download=True, hard=True):
return files[0] if len(files) else [] # return file
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
"""Search/download YAML file (if necessary) and return path, checking suffix."""
return check_file(file, suffix, hard=hard)
@ -482,51 +519,52 @@ def check_is_path_safe(basedir, path):
base_dir_resolved = Path(basedir).resolve()
path_resolved = Path(path).resolve()
return path_resolved.is_file() and path_resolved.parts[:len(base_dir_resolved.parts)] == base_dir_resolved.parts
return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
def check_imshow(warn=False):
"""Check if environment supports image displays."""
try:
if LINUX:
assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle()
cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
if warn:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
return False
def check_yolo(verbose=True, device=''):
def check_yolo(verbose=True, device=""):
"""Return a human-readable YOLO software and hardware summary."""
import psutil
from ultralytics.utils.torch_utils import select_device
if is_jupyter():
if check_requirements('wandb', install=False):
os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with infinite hang
if check_requirements("wandb", install=False):
os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang
if is_colab():
shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
if verbose:
# System info
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage('/')
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
total, used, free = shutil.disk_usage("/")
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
with contextlib.suppress(Exception): # clear display if ipython is installed
from IPython import display
display.clear_output()
else:
s = ''
s = ""
select_device(device=device, newline=False)
LOGGER.info(f'Setup complete ✅ {s}')
LOGGER.info(f"Setup complete ✅ {s}")
def collect_system_info():
@ -537,32 +575,36 @@ def collect_system_info():
from ultralytics.utils import ENVIRONMENT, is_git_dir
from ultralytics.utils.torch_utils import get_cpu_info
ram_info = psutil.virtual_memory().total / (1024 ** 3) # Convert bytes to GB
ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
check_yolo()
LOGGER.info(f"\n{'OS':<20}{platform.platform()}\n"
f"{'Environment':<20}{ENVIRONMENT}\n"
f"{'Python':<20}{sys.version.split()[0]}\n"
f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
f"{'RAM':<20}{ram_info:.2f} GB\n"
f"{'CPU':<20}{get_cpu_info()}\n"
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n")
LOGGER.info(
f"\n{'OS':<20}{platform.platform()}\n"
f"{'Environment':<20}{ENVIRONMENT}\n"
f"{'Python':<20}{sys.version.split()[0]}\n"
f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
f"{'RAM':<20}{ram_info:.2f} GB\n"
f"{'CPU':<20}{get_cpu_info()}\n"
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
)
for r in parse_requirements(package='ultralytics'):
for r in parse_requirements(package="ultralytics"):
try:
current = metadata.version(r.name)
is_met = '' if check_version(current, str(r.specifier), hard=True) else ''
is_met = "" if check_version(current, str(r.specifier), hard=True) else ""
except metadata.PackageNotFoundError:
current = '(not installed)'
is_met = ''
LOGGER.info(f'{r.name:<20}{is_met}{current}{r.specifier}')
current = "(not installed)"
is_met = ""
LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")
if is_github_action_running():
LOGGER.info(f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n")
LOGGER.info(
f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
)
def check_amp(model):
@ -587,7 +629,7 @@ def check_amp(model):
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
"""
device = next(model.parameters()).device # get model device
if device.type in ('cpu', 'mps'):
if device.type in ("cpu", "mps"):
return False # AMP only used on CUDA devices
def amp_allclose(m, im):
@ -598,22 +640,27 @@ def check_amp(model):
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
im = ASSETS / 'bus.jpg' # image to check
prefix = colorstr('AMP: ')
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
im = ASSETS / "bus.jpg" # image to check
prefix = colorstr("AMP: ")
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
assert amp_allclose(YOLO('yolov8n.pt'), im)
LOGGER.info(f'{prefix}checks passed ✅')
assert amp_allclose(YOLO("yolov8n.pt"), im)
LOGGER.info(f"{prefix}checks passed ✅")
except ConnectionError:
LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
except (AttributeError, ModuleNotFoundError):
LOGGER.warning(f'{prefix}checks skipped ⚠️. '
f'Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
LOGGER.warning(
f"{prefix}checks skipped ⚠️. "
f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
)
except AssertionError:
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
LOGGER.warning(
f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
)
return False
return True
@ -621,8 +668,8 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
with contextlib.suppress(Exception):
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
return ''
return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
@ -630,7 +677,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
def strip_auth(v):
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
@ -638,11 +685,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
except ValueError:
file = Path(file).stem
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
def cuda_device_count() -> int:
@ -654,11 +701,12 @@ def cuda_device_count() -> int:
"""
try:
# Run the nvidia-smi command and capture its output
output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
encoding='utf-8')
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
)
# Take the first line and strip any leading/trailing white space
first_line = output.strip().split('\n')[0]
first_line = output.strip().split("\n")[0]
return int(first_line)
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):

View file

@ -18,13 +18,13 @@ def find_free_network_port() -> int:
`MASTER_PORT` environment variable.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
s.bind(("127.0.0.1", 0))
return s.getsockname()[1] # port
def generate_ddp_file(trainer):
"""Generates a DDP file and returns its file name."""
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
@ -39,13 +39,15 @@ if __name__ == "__main__":
trainer = {name}(cfg=cfg, overrides=overrides)
results = trainer.train()
"""
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='_temp_',
suffix=f'{id(trainer)}.py',
mode='w+',
encoding='utf-8',
dir=USER_CONFIG_DIR / 'DDP',
delete=False) as file:
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(
prefix="_temp_",
suffix=f"{id(trainer)}.py",
mode="w+",
encoding="utf-8",
dir=USER_CONFIG_DIR / "DDP",
delete=False,
) as file:
file.write(content)
return file.name
@ -53,16 +55,17 @@ if __name__ == "__main__":
def generate_ddp_command(world_size, trainer):
"""Generates and returns command for distributed training."""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
file = generate_ddp_file(trainer)
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
port = find_free_network_port()
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
return cmd, file
def ddp_cleanup(trainer, file):
"""Delete temp file if created."""
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
os.remove(file)

View file

@ -15,15 +15,17 @@ import torch
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = 'ultralytics/assets'
GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose', '-obb')] + \
[f'yolov5{k}{resolution}u.pt' for k in 'nsmlx' for resolution in ('', '6')] + \
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
[f'sam_{k}.pt' for k in 'bl'] + \
[f'FastSAM-{k}.pt' for k in 'sx'] + \
[f'rtdetr-{k}.pt' for k in 'lx'] + \
['mobile_sam.pt']
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = (
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]
+ [f"FastSAM-{k}.pt" for k in "sx"]
+ [f"rtdetr-{k}.pt" for k in "lx"]
+ ["mobile_sam.pt"]
)
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
@ -56,7 +58,7 @@ def is_url(url, check=True):
return False
def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
"""
Deletes all ".DS_store" files under a specified directory.
@ -77,12 +79,12 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
"""
for file in files_to_delete:
matches = list(Path(path).rglob(file))
LOGGER.info(f'Deleting {file} files: {matches}')
LOGGER.info(f"Deleting {file} files: {matches}")
for f in matches:
f.unlink()
def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
"""
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
named after the directory and placed alongside it.
@ -111,17 +113,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
# Unzip with progress bar
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix('.zip')
files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix(".zip")
compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, 'w', compression) as f:
for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
with ZipFile(zip_file, "w", compression) as f:
for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
f.write(file, file.relative_to(directory))
return zip_file # return path to zip file
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
"""
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
@ -161,7 +163,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
top_level_dirs = {Path(f).parts[0] for f in files}
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith('/')):
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
# Zip has multiple files at top level
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
else:
@ -172,20 +174,20 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
# Check if destination directory already exists and contains files
if path.exists() and any(path.iterdir()) and not exist_ok:
# If it exists and is not empty, return the path without unzipping
LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.')
LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
return path
for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
if '..' in Path(f).parts:
LOGGER.warning(f'Potentially insecure file path: {f}, skipping extraction.')
if ".." in Path(f).parts:
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
continue
zipObj.extract(f, extract_path)
return path # return unzip dir
def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", sf=1.5, hard=True):
"""
Check if there is sufficient disk space to download and store a file.
@ -199,20 +201,23 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
"""
try:
r = requests.head(url) # response
assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}' # check response
assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
except Exception:
return True # requests issue, default to True
# Check file size
gib = 1 << 30 # bytes per GiB
data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd())) # bytes
if data * sf < free:
return True # sufficient space
# Insufficient space
text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
text = (
f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
f"Please free {data * sf - free:.1f} GB additional disk space and try again."
)
if hard:
raise MemoryError(text)
LOGGER.warning(text)
@ -238,36 +243,41 @@ def get_google_drive_file_info(link):
url, filename = get_google_drive_file_info(link)
```
"""
file_id = link.split('/d/')[1].split('/view')[0]
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
file_id = link.split("/d/")[1].split("/view")[0]
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
filename = None
# Start session
with requests.Session() as session:
response = session.get(drive_url, stream=True)
if 'quota exceeded' in str(response.content.lower()):
if "quota exceeded" in str(response.content.lower()):
raise ConnectionError(
emojis(f'❌ Google Drive file download quota exceeded. '
f'Please try again later or download this file manually at {link}.'))
emojis(
f"❌ Google Drive file download quota exceeded. "
f"Please try again later or download this file manually at {link}."
)
)
for k, v in response.cookies.items():
if k.startswith('download_warning'):
drive_url += f'&confirm={v}' # v is token
cd = response.headers.get('content-disposition')
if k.startswith("download_warning"):
drive_url += f"&confirm={v}" # v is token
cd = response.headers.get("content-disposition")
if cd:
filename = re.findall('filename="(.+)"', cd)[0]
return drive_url, filename
def safe_download(url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1E0,
exist_ok=False,
progress=True):
def safe_download(
url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1e0,
exist_ok=False,
progress=True,
):
"""
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
@ -294,36 +304,38 @@ def safe_download(url,
path = safe_download(link)
```
"""
gdrive = url.startswith('https://drive.google.com/') # check if the URL is a Google Drive link
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
if gdrive:
url, file = get_google_drive_file_info(url)
f = Path(dir or '.') / (file or url2file(url)) # URL converted to filename
if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
elif not f.is_file(): # URL and file do not exist
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
LOGGER.info(f'{desc}...')
LOGGER.info(f"{desc}...")
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
check_disk_space(url)
for i in range(retry + 1):
try:
if curl or i > 0: # curl download with retry, continue
s = 'sS' * (not progress) # silent
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
assert r == 0, f'Curl return value {r}'
s = "sS" * (not progress) # silent
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
assert r == 0, f"Curl return value {r}"
else: # urllib download
method = 'torch'
if method == 'torch':
method = "torch"
if method == "torch":
torch.hub.download_url_to_file(url, f, progress=progress)
else:
with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
desc=desc,
disable=not progress,
unit='B',
unit_scale=True,
unit_divisor=1024) as pbar:
with open(f, 'wb') as f_opened:
with request.urlopen(url) as response, TQDM(
total=int(response.getheader("Content-Length", 0)),
desc=desc,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
with open(f, "wb") as f_opened:
for data in response:
f_opened.write(data)
pbar.update(len(data))
@ -334,26 +346,26 @@ def safe_download(url,
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not is_online():
raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
elif i >= retry:
raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
from zipfile import is_zipfile
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
if is_zipfile(f):
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
elif f.suffix in ('.tar', '.gz'):
LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True)
elif f.suffix in (".tar", ".gz"):
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
if delete:
f.unlink() # remove zip
return unzip_dir
def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
"""
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
function fetches the latest release assets.
@ -372,20 +384,20 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
```
"""
if version != 'latest':
version = f'tags/{version}' # i.e. tags/v6.2
url = f'https://api.github.com/repos/{repo}/releases/{version}'
if version != "latest":
version = f"tags/{version}" # i.e. tags/v6.2
url = f"https://api.github.com/repos/{repo}/releases/{version}"
r = requests.get(url) # github api
if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
r = requests.get(url) # try again
if r.status_code != 200:
LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
return '', []
LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
return "", []
data = r.json()
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **kwargs):
def attempt_download_asset(file, repo="ultralytics/assets", release="v0.0.0", **kwargs):
"""
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
locally first, then tries to download it from the specified GitHub repository release.
@ -409,32 +421,32 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **
# YOLOv3/5u updates
file = str(file)
file = checks.check_yolov5u_filename(file)
file = Path(file.strip().replace("'", ''))
file = Path(file.strip().replace("'", ""))
if file.exists():
return str(file)
elif (SETTINGS['weights_dir'] / file).exists():
return str(SETTINGS['weights_dir'] / file)
elif (SETTINGS["weights_dir"] / file).exists():
return str(SETTINGS["weights_dir"] / file)
else:
# URL specified
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
download_url = f'https://github.com/{repo}/releases/download'
if str(file).startswith(('http:/', 'https:/')): # download
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
download_url = f"https://github.com/{repo}/releases/download"
if str(file).startswith(("http:/", "https:/")): # download
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
if Path(file).is_file():
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
safe_download(url=url, file=file, min_bytes=1E5, **kwargs)
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
safe_download(url=f'{download_url}/{release}/{name}', file=file, min_bytes=1E5, **kwargs)
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
else:
tag, assets = get_github_assets(repo, release)
if not assets:
tag, assets = get_github_assets(repo) # latest release
if name in assets:
safe_download(url=f'{download_url}/{tag}/{name}', file=file, min_bytes=1E5, **kwargs)
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
return str(file)
@ -464,14 +476,18 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=
if threads > 1:
with ThreadPool(threads) as pool:
pool.map(
lambda x: safe_download(url=x[0],
dir=x[1],
unzip=unzip,
delete=delete,
curl=curl,
retry=retry,
exist_ok=exist_ok,
progress=threads <= 1), zip(url, repeat(dir)))
lambda x: safe_download(
url=x[0],
dir=x[1],
unzip=unzip,
delete=delete,
curl=curl,
retry=retry,
exist_ok=exist_ok,
progress=threads <= 1,
),
zip(url, repeat(dir)),
)
pool.close()
pool.join()
else:

View file

@ -17,6 +17,6 @@ class HUBModelError(Exception):
The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
"""
def __init__(self, message='Model not found. Please check model URL and try again.'):
def __init__(self, message="Model not found. Please check model URL and try again."):
"""Create an exception for when a model is not found."""
super().__init__(emojis(message))

View file

@ -50,13 +50,13 @@ def spaces_in_path(path):
"""
# If path has spaces, replace them with underscores
if ' ' in str(path):
if " " in str(path):
string = isinstance(path, str) # input type
path = Path(path)
# Create a temporary directory and construct the new path
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
# Copy file/directory
if path.is_dir():
@ -82,7 +82,7 @@ def spaces_in_path(path):
yield path
def increment_path(path, exist_ok=False, sep='', mkdir=False):
def increment_path(path, exist_ok=False, sep="", mkdir=False):
"""
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
@ -102,11 +102,11 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
# Method 1
for n in range(2, 9999):
p = f'{path}{sep}{n}{suffix}' # increment path
p = f"{path}{sep}{n}{suffix}" # increment path
if not os.path.exists(p):
break
path = Path(p)
@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
def file_age(path=__file__):
"""Return days since last file update."""
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
"""Return human-readable file modification date, i.e. '2021-3-26'."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'
return f"{t.year}-{t.month}-{t.day}"
def file_size(path):
@ -137,11 +137,11 @@ def file_size(path):
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
return 0.0
def get_latest_run(search_dir='.'):
def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ""

View file

@ -26,9 +26,9 @@ to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(YOLO format)
# `ltwh` means left top and width, height(COCO format)
_formats = ['xyxy', 'xywh', 'ltwh']
_formats = ["xyxy", "xywh", "ltwh"]
__all__ = 'Bboxes', # tuple or list
__all__ = ("Bboxes",) # tuple or list
class Bboxes:
@ -46,9 +46,9 @@ class Bboxes:
This class does not handle normalization or denormalization of bounding boxes.
"""
def __init__(self, bboxes, format='xyxy') -> None:
def __init__(self, bboxes, format="xyxy") -> None:
"""Initializes the Bboxes class with bounding box data in a specified format."""
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
assert bboxes.shape[1] == 4
@ -58,21 +58,21 @@ class Bboxes:
def convert(self, format):
"""Converts bounding box format from one type to another."""
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
if self.format == format:
return
elif self.format == 'xyxy':
func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
elif self.format == 'xywh':
func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
elif self.format == "xyxy":
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
elif self.format == "xywh":
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
else:
func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
self.bboxes = func(self.bboxes)
self.format = format
def areas(self):
"""Return box areas."""
self.convert('xyxy')
self.convert("xyxy")
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
# def denormalize(self, w, h):
@ -124,7 +124,7 @@ class Bboxes:
return len(self.bboxes)
@classmethod
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
"""
Concatenate a list of Bboxes objects into a single Bboxes object.
@ -148,7 +148,7 @@ class Bboxes:
return boxes_list[0]
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
def __getitem__(self, index) -> 'Bboxes':
def __getitem__(self, index) -> "Bboxes":
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
@ -169,7 +169,7 @@ class Bboxes:
if isinstance(index, int):
return Bboxes(self.bboxes[index].view(1, -1))
b = self.bboxes[index]
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
return Bboxes(b)
@ -205,7 +205,7 @@ class Instances:
This class does not perform input validation, and it assumes the inputs are well-formed.
"""
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
"""
Args:
bboxes (ndarray): bboxes with shape [N, 4].
@ -263,7 +263,7 @@ class Instances:
def add_padding(self, padw, padh):
"""Handle rect and mosaic situation."""
assert not self.normalized, 'you should add padding with absolute coordinates.'
assert not self.normalized, "you should add padding with absolute coordinates."
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw
self.segments[..., 1] += padh
@ -271,7 +271,7 @@ class Instances:
self.keypoints[..., 0] += padw
self.keypoints[..., 1] += padh
def __getitem__(self, index) -> 'Instances':
def __getitem__(self, index) -> "Instances":
"""
Retrieve a specific instance or a set of instances using indexing.
@ -301,7 +301,7 @@ class Instances:
def flipud(self, h):
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
if self._bboxes.format == 'xyxy':
if self._bboxes.format == "xyxy":
y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy()
self.bboxes[:, 1] = h - y2
@ -314,7 +314,7 @@ class Instances:
def fliplr(self, w):
"""Reverses the order of the bounding boxes and segments horizontally."""
if self._bboxes.format == 'xyxy':
if self._bboxes.format == "xyxy":
x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy()
self.bboxes[:, 0] = w - x2
@ -328,10 +328,10 @@ class Instances:
def clip(self, w, h):
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
ori_format = self._bboxes.format
self.convert_bbox(format='xyxy')
self.convert_bbox(format="xyxy")
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
if ori_format != 'xyxy':
if ori_format != "xyxy":
self.convert_bbox(format=ori_format)
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
@ -367,7 +367,7 @@ class Instances:
return len(self.bboxes)
@classmethod
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
"""
Concatenates a list of Instances objects into a single Instances object.

View file

@ -28,22 +28,27 @@ class VarifocalLoss(nn.Module):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
weight).mean(1).sum()
loss = (
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
.mean(1)
.sum()
)
return loss
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
def __init__(
self,
):
"""Initializer for FocalLoss class with no parameters."""
super().__init__()
@staticmethod
def forward(pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
@ -91,8 +96,10 @@ class BboxLoss(nn.Module):
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
return (
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
).mean(-1, keepdim=True)
class RotatedBboxLoss(BboxLoss):
@ -145,7 +152,7 @@ class v8DetectionLoss:
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction='none')
self.bce = nn.BCEWithLogitsLoss(reduction="none")
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
@ -190,7 +197,8 @@ class v8DetectionLoss:
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
@ -201,7 +209,7 @@ class v8DetectionLoss:
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
@ -210,8 +218,13 @@ class v8DetectionLoss:
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
@ -222,8 +235,9 @@ class v8DetectionLoss:
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
loss[0], loss[2] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
@ -246,7 +260,8 @@ class v8SegmentationLoss(v8DetectionLoss):
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@ -259,24 +274,31 @@ class v8SegmentationLoss(v8DetectionLoss):
# Targets
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e
raise TypeError(
"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
@ -286,15 +308,23 @@ class v8SegmentationLoss(v8DetectionLoss):
if fg_mask.sum():
# Bbox loss
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
target_scores, target_scores_sum, fg_mask)
loss[0], loss[3] = self.bbox_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
# Masks loss
masks = batch['masks'].to(self.device).float()
masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
loss[1] = self.calculate_segmentation_loss(fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto,
pred_masks, imgsz, self.overlap)
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
@ -308,8 +338,9 @@ class v8SegmentationLoss(v8DetectionLoss):
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
@staticmethod
def single_mask_loss(gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor,
area: torch.Tensor) -> torch.Tensor:
def single_mask_loss(
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
"""
Compute the instance segmentation loss for a single image.
@ -327,8 +358,8 @@ class v8SegmentationLoss(v8DetectionLoss):
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
predicted masks from the prototype masks and predicted mask coefficients.
"""
pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
def calculate_segmentation_loss(
@ -387,8 +418,9 @@ class v8SegmentationLoss(v8DetectionLoss):
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
loss += self.single_mask_loss(gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i],
marea_i[fg_mask_i])
loss += self.single_mask_loss(
gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
)
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
@ -415,7 +447,8 @@ class v8PoseLoss(v8DetectionLoss):
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@ -428,8 +461,8 @@ class v8PoseLoss(v8DetectionLoss):
# Targets
batch_size = pred_scores.shape[0]
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
@ -439,8 +472,13 @@ class v8PoseLoss(v8DetectionLoss):
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
@ -451,14 +489,16 @@ class v8PoseLoss(v8DetectionLoss):
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
keypoints = batch['keypoints'].to(self.device).float().clone()
loss[0], loss[4] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
keypoints = batch["keypoints"].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
loss[1], loss[2] = self.calculate_keypoints_loss(fg_mask, target_gt_idx, keypoints, batch_idx,
stride_tensor, target_bboxes, pred_kpts)
loss[1], loss[2] = self.calculate_keypoints_loss(
fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose # pose gain
@ -477,8 +517,9 @@ class v8PoseLoss(v8DetectionLoss):
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
def calculate_keypoints_loss(self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes,
pred_kpts):
def calculate_keypoints_loss(
self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
):
"""
Calculate the keypoints loss for the model.
@ -507,21 +548,23 @@ class v8PoseLoss(v8DetectionLoss):
max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
# Create a tensor to hold batched keypoints
batched_keypoints = torch.zeros((batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]),
device=keypoints.device)
batched_keypoints = torch.zeros(
(batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
)
# TODO: any idea how to vectorize this?
# Fill batched_keypoints with keypoints based on batch_idx
for i in range(batch_size):
keypoints_i = keypoints[batch_idx == i]
batched_keypoints[i, :keypoints_i.shape[0]] = keypoints_i
batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
# Expand dimensions of target_gt_idx to match the shape of batched_keypoints
target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
# Use target_gt_idx_expanded to select keypoints from batched_keypoints
selected_keypoints = batched_keypoints.gather(
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]))
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
)
# Divide coordinates by stride
selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
@ -547,13 +590,12 @@ class v8ClassificationLoss:
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
loss_items = loss.detach()
return loss, loss_items
class v8OBBLoss(v8DetectionLoss):
def __init__(self, model): # model must be de-paralleled
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
@ -583,7 +625,8 @@ class v8OBBLoss(v8DetectionLoss):
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@ -596,19 +639,21 @@ class v8OBBLoss(v8DetectionLoss):
# targets
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1)
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n'
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e
raise TypeError(
"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
@ -616,10 +661,14 @@ class v8OBBLoss(v8DetectionLoss):
bboxes_for_assigner = pred_bboxes.clone().detach()
# Only the first four elements need to be scaled
bboxes_for_assigner[..., :4] *= stride_tensor
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(),
bboxes_for_assigner.type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes,
mask_gt)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(),
bboxes_for_assigner.type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
@ -630,8 +679,9 @@ class v8OBBLoss(v8DetectionLoss):
# Bbox loss
if fg_mask.sum():
target_bboxes[..., :4] /= stride_tensor
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
loss[0], loss[2] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
else:
loss[0] += (pred_angle * 0).sum()

View file

@ -11,7 +11,10 @@ import torch
from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
OKS_SIGMA = (
np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
/ 10.0
)
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
@ -33,8 +36,9 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
# Intersection area
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
(np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
).clip(0)
# Box2 area
area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
@ -99,8 +103,9 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# Intersection area
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
).clamp_(0)
# Union Area
union = w1 * h1 + w2 * h2 - inter + eps
@ -111,10 +116,10 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
c2 = cw**2 + ch**2 + eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
@ -202,12 +207,19 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
a1, b1, c1 = _get_covariance_matrix(obb1)
a2, b2, c2 = _get_covariance_matrix(obb2)
t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
t1 = (
((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
/ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
) * 0.25
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
(4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
(a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
t3 = (
torch.log(
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
/ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
+ eps
)
* 0.5
)
bd = t1 + t2 + t3
bd = torch.clamp(bd, eps, 100.0)
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
@ -215,7 +227,7 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
if CIoU: # only include the wh aspect ratio part
w1, h1 = obb1[..., 2:4].split(1, dim=-1)
w2, h2 = obb2[..., 2:4].split(1, dim=-1)
v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - v * alpha # CIoU
@ -239,12 +251,19 @@ def batch_probiou(obb1, obb2, eps=1e-7):
a1, b1, c1 = _get_covariance_matrix(obb1)
a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
t1 = (
((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
/ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
) * 0.25
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
(4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
(a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
t3 = (
torch.log(
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
/ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
+ eps
)
* 0.5
)
bd = t1 + t2 + t3
bd = torch.clamp(bd, eps, 100.0)
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
@ -279,10 +298,10 @@ class ConfusionMatrix:
iou_thres (float): The Intersection over Union threshold.
"""
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
"""Initialize attributes for the YOLO model."""
self.task = task
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
self.nc = nc # number of classes
self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
self.iou_thres = iou_thres
@ -361,11 +380,11 @@ class ConfusionMatrix:
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect
return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")
@plt_settings()
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
"""
Plot the confusion matrix using seaborn and save it to a file.
@ -377,30 +396,31 @@ class ConfusionMatrix:
"""
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (list(names) + ['background']) if labels else 'auto'
ticklabels = (list(names) + ["background"]) if labels else "auto"
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
'size': 8},
cmap='Blues',
fmt='.2f' if normalize else '.0f',
square=True,
vmin=0.0,
xticklabels=ticklabels,
yticklabels=ticklabels).set_facecolor((1, 1, 1))
title = 'Confusion Matrix' + ' Normalized' * normalize
ax.set_xlabel('True')
ax.set_ylabel('Predicted')
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(
array,
ax=ax,
annot=nc < 30,
annot_kws={"size": 8},
cmap="Blues",
fmt=".2f" if normalize else ".0f",
square=True,
vmin=0.0,
xticklabels=ticklabels,
yticklabels=ticklabels,
).set_facecolor((1, 1, 1))
title = "Confusion Matrix" + " Normalized" * normalize
ax.set_xlabel("True")
ax.set_ylabel("Predicted")
ax.set_title(title)
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
fig.savefig(plot_fname, dpi=250)
@ -411,7 +431,7 @@ class ConfusionMatrix:
def print(self):
"""Print the confusion matrix to the console."""
for i in range(self.nc + 1):
LOGGER.info(' '.join(map(str, self.matrix[i])))
LOGGER.info(" ".join(map(str, self.matrix[i])))
def smooth(y, f=0.05):
@ -419,28 +439,28 @@ def smooth(y, f=0.05):
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
p = np.ones(nf // 2) # ones padding
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
"""Plots a precision-recall curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
ax.set_title('Precision-Recall Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title("Precision-Recall Curve")
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
@ -448,24 +468,24 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
"""Plots a metric-confidence curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
else:
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
y = smooth(py.mean(0), 0.05)
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
ax.set_title(f'{ylabel}-Confidence Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f"{ylabel}-Confidence Curve")
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
@ -494,8 +514,8 @@ def compute_ap(recall, precision):
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
# Integrate area under curve
method = 'interp' # methods: 'continuous', 'interp'
if method == 'interp':
method = "interp" # methods: 'continuous', 'interp'
if method == "interp":
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
else: # 'continuous'
@ -505,16 +525,9 @@ def compute_ap(recall, precision):
return ap, mpre, mrec
def ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=False,
on_plot=None,
save_dir=Path(),
names=(),
eps=1e-16,
prefix=''):
def ap_per_class(
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
):
"""
Computes the average precision per class for object detection evaluation.
@ -591,10 +604,10 @@ def ap_per_class(tp,
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict
if plot:
plot_pr_curve(x, prec_values, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
plot_mc_curve(x, f1_curve, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
plot_mc_curve(x, p_curve, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
plot_mc_curve(x, r_curve, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values
@ -746,8 +759,18 @@ class Metric(SimpleClass):
Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
on the values provided in the `results` tuple.
"""
(self.p, self.r, self.f1, self.all_ap, self.ap_class_index, self.p_curve, self.r_curve, self.f1_curve, self.px,
self.prec_values) = results
(
self.p,
self.r,
self.f1,
self.all_ap,
self.ap_class_index,
self.p_curve,
self.r_curve,
self.f1_curve,
self.px,
self.prec_values,
) = results
@property
def curves(self):
@ -757,8 +780,12 @@ class Metric(SimpleClass):
@property
def curves_results(self):
"""Returns a list of curves for accessing specific metrics curves."""
return [[self.px, self.prec_values, 'Recall', 'Precision'], [self.px, self.f1_curve, 'Confidence', 'F1'],
[self.px, self.p_curve, 'Confidence', 'Precision'], [self.px, self.r_curve, 'Confidence', 'Recall']]
return [
[self.px, self.prec_values, "Recall", "Precision"],
[self.px, self.f1_curve, "Confidence", "F1"],
[self.px, self.p_curve, "Confidence", "Precision"],
[self.px, self.r_curve, "Confidence", "Recall"],
]
class DetMetrics(SimpleClass):
@ -793,33 +820,35 @@ class DetMetrics(SimpleClass):
curves_results: TODO
"""
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.task = 'detect'
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "detect"
def process(self, tp, conf, pred_cls, target_cls):
"""Process predicted results for object detection and update metrics."""
results = ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
on_plot=self.on_plot)[2:]
results = ap_per_class(
tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
on_plot=self.on_plot,
)[2:]
self.box.nc = len(self.names)
self.box.update(results)
@property
def keys(self):
"""Returns a list of keys for accessing specific metrics."""
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
@ -847,12 +876,12 @@ class DetMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
return ['Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)']
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
@property
def curves_results(self):
@ -889,7 +918,7 @@ class SegmentMetrics(SimpleClass):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
@ -897,8 +926,8 @@ class SegmentMetrics(SimpleClass):
self.names = names
self.box = Metric()
self.seg = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.task = 'segment'
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "segment"
def process(self, tp, tp_m, conf, pred_cls, target_cls):
"""
@ -912,26 +941,30 @@ class SegmentMetrics(SimpleClass):
target_cls (list): List of target classes.
"""
results_mask = ap_per_class(tp_m,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Mask')[2:]
results_mask = ap_per_class(
tp_m,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Mask",
)[2:]
self.seg.nc = len(self.names)
self.seg.update(results_mask)
results_box = ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Box')[2:]
results_box = ap_per_class(
tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Box",
)[2:]
self.box.nc = len(self.names)
self.box.update(results_box)
@ -939,8 +972,15 @@ class SegmentMetrics(SimpleClass):
def keys(self):
"""Returns a list of keys for accessing metrics."""
return [
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP50(B)",
"metrics/mAP50-95(B)",
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP50(M)",
"metrics/mAP50-95(M)",
]
def mean_results(self):
"""Return the mean metrics for bounding box and segmentation results."""
@ -968,14 +1008,21 @@ class SegmentMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns results of object detection model for evaluation."""
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
return [
'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)',
'Precision-Recall(M)', 'F1-Confidence(M)', 'Precision-Confidence(M)', 'Recall-Confidence(M)']
"Precision-Recall(B)",
"F1-Confidence(B)",
"Precision-Confidence(B)",
"Recall-Confidence(B)",
"Precision-Recall(M)",
"F1-Confidence(M)",
"Precision-Confidence(M)",
"Recall-Confidence(M)",
]
@property
def curves_results(self):
@ -1012,7 +1059,7 @@ class PoseMetrics(SegmentMetrics):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
super().__init__(save_dir, plot, names)
self.save_dir = save_dir
@ -1021,8 +1068,8 @@ class PoseMetrics(SegmentMetrics):
self.names = names
self.box = Metric()
self.pose = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.task = 'pose'
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "pose"
def process(self, tp, tp_p, conf, pred_cls, target_cls):
"""
@ -1036,26 +1083,30 @@ class PoseMetrics(SegmentMetrics):
target_cls (list): List of target classes.
"""
results_pose = ap_per_class(tp_p,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Pose')[2:]
results_pose = ap_per_class(
tp_p,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Pose",
)[2:]
self.pose.nc = len(self.names)
self.pose.update(results_pose)
results_box = ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Box')[2:]
results_box = ap_per_class(
tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Box",
)[2:]
self.box.nc = len(self.names)
self.box.update(results_box)
@ -1063,8 +1114,15 @@ class PoseMetrics(SegmentMetrics):
def keys(self):
"""Returns list of evaluation metric keys."""
return [
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP50(B)",
"metrics/mAP50-95(B)",
"metrics/precision(P)",
"metrics/recall(P)",
"metrics/mAP50(P)",
"metrics/mAP50-95(P)",
]
def mean_results(self):
"""Return the mean results of box and pose."""
@ -1088,8 +1146,15 @@ class PoseMetrics(SegmentMetrics):
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
return [
'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)',
'Precision-Recall(P)', 'F1-Confidence(P)', 'Precision-Confidence(P)', 'Recall-Confidence(P)']
"Precision-Recall(B)",
"F1-Confidence(B)",
"Precision-Confidence(B)",
"Recall-Confidence(B)",
"Precision-Recall(P)",
"F1-Confidence(P)",
"Precision-Confidence(P)",
"Recall-Confidence(P)",
]
@property
def curves_results(self):
@ -1119,8 +1184,8 @@ class ClassifyMetrics(SimpleClass):
"""Initialize a ClassifyMetrics instance."""
self.top1 = 0
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.task = 'classify'
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "classify"
def process(self, targets, pred):
"""Target classes and predicted classes."""
@ -1137,12 +1202,12 @@ class ClassifyMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns a dictionary with model's performance metrics and fitness score."""
return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
@property
def keys(self):
"""Returns a list of keys for the results_dict property."""
return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
@property
def curves(self):
@ -1156,32 +1221,33 @@ class ClassifyMetrics(SimpleClass):
class OBBMetrics(SimpleClass):
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
def process(self, tp, conf, pred_cls, target_cls):
"""Process predicted results for object detection and update metrics."""
results = ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
on_plot=self.on_plot)[2:]
results = ap_per_class(
tp,
conf,
pred_cls,
target_cls,
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
on_plot=self.on_plot,
)[2:]
self.box.nc = len(self.names)
self.box.update(results)
@property
def keys(self):
"""Returns a list of keys for accessing specific metrics."""
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
@ -1209,7 +1275,7 @@ class OBBMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):

View file

@ -52,7 +52,7 @@ class Profile(contextlib.ContextDecorator):
def __str__(self):
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
return f'Elapsed time is {self.t} s'
return f"Elapsed time is {self.t} s"
def time(self):
"""Get current time."""
@ -76,9 +76,13 @@ def segment2box(segment, width=640, height=640):
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
x, y = segment.T # segment xy
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x, y, = x[inside], y[inside]
return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
4, dtype=segment.dtype) # xyxy
x = x[inside]
y = y[inside]
return (
np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
if any(x)
else np.zeros(4, dtype=segment.dtype)
) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
@ -101,8 +105,10 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xyw
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
(img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding
pad = (
round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
) # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
@ -145,7 +151,7 @@ def nms_rotated(boxes, scores, threshold=0.45):
Returns:
"""
if len(boxes) == 0:
return np.empty((0, ), dtype=np.int8)
return np.empty((0,), dtype=np.int8)
sorted_idx = torch.argsort(scores, descending=True)
boxes = boxes[sorted_idx]
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
@ -199,8 +205,8 @@ def non_max_suppression(
"""
# Checks
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
@ -284,7 +290,7 @@ def non_max_suppression(
output[xi] = x[i]
if (time.time() - t) > time_limit:
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return output
@ -378,7 +384,7 @@ def xyxy2xywh(x):
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
"""
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
@ -398,7 +404,7 @@ def xywh2xyxy(x):
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
dw = x[..., 2] / 2 # half-width
dh = x[..., 3] / 2 # half-height
@ -423,7 +429,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
"""
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
@ -449,7 +455,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
"""
if clip:
x = clip_boxes(x, (h - eps, w - eps))
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
@ -526,8 +532,11 @@ def xyxyxyxy2xywhr(corners):
# especially some objects are cut off by augmentations in dataloader.
(x, y), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([x, y, w, h, angle / 180 * np.pi])
rboxes = torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) if is_torch else np.asarray(
rboxes, dtype=points.dtype)
rboxes = (
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
if is_torch
else np.asarray(rboxes, dtype=points.dtype)
)
return rboxes
@ -546,7 +555,7 @@ def xywhr2xyxyxyxy(center):
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
ctr = center[..., :2]
w, h, angle = (center[..., i:i + 1] for i in range(2, 5))
w, h, angle = (center[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
@ -607,8 +616,9 @@ def resample_segments(segments, n=1000):
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n)
xp = np.arange(len(s))
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
dtype=np.float32).reshape(2, -1).T # segment xy
segments[i] = (
np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
) # segment xy
return segments
@ -647,7 +657,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
@ -680,7 +690,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.5)
@ -724,7 +734,7 @@ def scale_masks(masks, shape, padding=True):
bottom, right = (int(round(mh - pad[1] + 0.1)), int(round(mw - pad[0] + 0.1)))
masks = masks[..., top:bottom, left:right]
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
return masks
@ -763,7 +773,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
return coords
def masks2segments(masks, strategy='largest'):
def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
@ -775,16 +785,16 @@ def masks2segments(masks, strategy='largest'):
segments (List): list of segment masks
"""
segments = []
for x in masks.int().cpu().numpy().astype('uint8'):
for x in masks.int().cpu().numpy().astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if c:
if strategy == 'concat': # concatenate all segments
if strategy == "concat": # concatenate all segments
c = np.concatenate([x.reshape(-1, 2) for x in c])
elif strategy == 'largest': # select largest segment
elif strategy == "largest": # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
segments.append(c.astype('float32'))
segments.append(c.astype("float32"))
return segments
@ -811,4 +821,4 @@ def clean_str(s):
Returns:
(str): a string with special characters replaced by an underscore _
"""
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)

View file

@ -52,7 +52,7 @@ def imshow(winname: str, mat: np.ndarray):
winname (str): Name of the window.
mat (np.ndarray): Image to be shown.
"""
_imshow(winname.encode('unicode_escape').decode(), mat)
_imshow(winname.encode("unicode_escape").decode(), mat)
# PyTorch functions ----------------------------------------------------------------------------------------------------
@ -72,6 +72,6 @@ def torch_save(*args, **kwargs):
except ImportError:
import pickle
if 'pickle_module' not in kwargs:
kwargs['pickle_module'] = pickle # noqa
if "pickle_module" not in kwargs:
kwargs["pickle_module"] = pickle # noqa
return _torch_save(*args, **kwargs)

View file

@ -33,15 +33,55 @@ class Colors:
def __init__(self):
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
hexs = (
"FF3838",
"FF9D97",
"FF701F",
"FFB21D",
"CFD231",
"48F90A",
"92CC17",
"3DDB86",
"1A9334",
"00D4BB",
"2C99A8",
"00C2FF",
"344593",
"6473FF",
"0018EC",
"8438FF",
"520085",
"CB38FF",
"FF95C8",
"FF37C7",
)
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
self.n = len(self.palette)
self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
[153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
[255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
[51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
dtype=np.uint8)
self.pose_palette = np.array(
[
[255, 128, 0],
[255, 153, 51],
[255, 178, 102],
[230, 230, 0],
[255, 153, 255],
[153, 204, 255],
[255, 102, 255],
[255, 51, 255],
[102, 178, 255],
[51, 153, 255],
[255, 153, 153],
[255, 102, 102],
[255, 51, 51],
[153, 255, 153],
[102, 255, 102],
[51, 255, 51],
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[255, 255, 255],
],
dtype=np.uint8,
)
def __call__(self, i, bgr=False):
"""Converts hex color codes to RGB values."""
@ -51,7 +91,7 @@ class Colors:
@staticmethod
def hex2rgb(h):
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
colors = Colors() # create instance for 'from utils.plots import colors'
@ -71,9 +111,9 @@ class Annotator:
kpt_color (List[int]): Color palette for keypoints.
"""
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
self.pil = pil or non_ascii
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
@ -81,26 +121,45 @@ class Annotator:
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im)
try:
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
font = check_font("Arial.Unicode.ttf" if non_ascii else font)
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
self.font = ImageFont.truetype(str(font), size)
except Exception:
self.font = ImageFont.load_default()
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
if check_version(pil_version, '9.2.0'):
if check_version(pil_version, "9.2.0"):
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
else: # use cv2
self.im = im if im.flags.writeable else im.copy()
self.tf = max(self.lw - 1, 1) # font thickness
self.sf = self.lw / 3 # font scale
# Pose
self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
[8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
self.skeleton = [
[16, 14],
[14, 12],
[17, 15],
[15, 13],
[12, 13],
[6, 12],
[7, 13],
[6, 7],
[6, 8],
[7, 9],
[8, 10],
[9, 11],
[2, 3],
[1, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
[5, 7],
]
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""Add one xyxy box to image with label."""
if isinstance(box, torch.Tensor):
box = box.tolist()
@ -134,13 +193,16 @@ class Annotator:
outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(self.im,
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
0,
self.sf,
txt_color,
thickness=self.tf,
lineType=cv2.LINE_AA)
cv2.putText(
self.im,
label,
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
0,
self.sf,
txt_color,
thickness=self.tf,
lineType=cv2.LINE_AA,
)
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
"""
@ -171,7 +233,7 @@ class Annotator:
im_gpu = im_gpu.flip(dims=[0]) # flip channel
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
im_mask = (im_gpu * 255)
im_mask = im_gpu * 255
im_mask_np = im_mask.byte().cpu().numpy()
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
if self.pil:
@ -230,9 +292,9 @@ class Annotator:
"""Add rectangle to image (PIL-only)."""
self.draw.rectangle(xy, fill, outline, width)
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
"""Adds text to an image using PIL or cv2."""
if anchor == 'bottom': # start y from font bottom
if anchor == "bottom": # start y from font bottom
w, h = self.font.getsize(text) # text width, height
xy[1] += 1 - h
if self.pil:
@ -241,8 +303,8 @@ class Annotator:
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255)
if '\n' in text:
lines = text.split('\n')
if "\n" in text:
lines = text.split("\n")
_, h = self.font.getsize(text)
for line in lines:
self.draw.text(xy, line, fill=txt_color, font=self.font)
@ -314,15 +376,12 @@ class Annotator:
text_y = t_size_in[1]
# Create a rounded rectangle for in_count
cv2.rectangle(self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color,
-1)
cv2.putText(self.im,
str(counts), (text_x, text_y + t_size_in[1]),
0,
tl / 2,
txt_color,
self.tf,
lineType=cv2.LINE_AA)
cv2.rectangle(
self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
)
cv2.putText(
self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
)
@staticmethod
def estimate_pose_angle(a, b, c):
@ -375,7 +434,7 @@ class Annotator:
center_kpt (int): centroid pose index for workout monitoring
line_thickness (int): thickness for text display
"""
angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}')
angle_text, count_text, stage_text = (f" {angle_text:.2f}", "Steps : " + f"{count_text}", f" {stage_text}")
font_scale = 0.6 + (line_thickness / 10.0)
# Draw angle
@ -383,21 +442,37 @@ class Annotator:
angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
cv2.rectangle(self.im, angle_background_position, (angle_background_position[0] + angle_background_size[0],
angle_background_position[1] + angle_background_size[1]),
(255, 255, 255), -1)
cv2.rectangle(
self.im,
angle_background_position,
(
angle_background_position[0] + angle_background_size[0],
angle_background_position[1] + angle_background_size[1],
),
(255, 255, 255),
-1,
)
cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)
# Draw Counts
(count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
count_background_position = (angle_background_position[0],
angle_background_position[1] + angle_background_size[1] + 5)
count_background_position = (
angle_background_position[0],
angle_background_position[1] + angle_background_size[1] + 5,
)
count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2))
cv2.rectangle(self.im, count_background_position, (count_background_position[0] + count_background_size[0],
count_background_position[1] + count_background_size[1]),
(255, 255, 255), -1)
cv2.rectangle(
self.im,
count_background_position,
(
count_background_position[0] + count_background_size[0],
count_background_position[1] + count_background_size[1],
),
(255, 255, 255),
-1,
)
cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)
# Draw Stage
@ -406,9 +481,16 @@ class Annotator:
stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
stage_background_size = (stage_text_width + 10, stage_text_height + 10)
cv2.rectangle(self.im, stage_background_position, (stage_background_position[0] + stage_background_size[0],
stage_background_position[1] + stage_background_size[1]),
(255, 255, 255), -1)
cv2.rectangle(
self.im,
stage_background_position,
(
stage_background_position[0] + stage_background_size[0],
stage_background_position[1] + stage_background_size[1],
),
(255, 255, 255),
-1,
)
cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
@ -423,14 +505,20 @@ class Annotator:
"""
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
label = f'Track ID: {track_label}' if track_label else det_label
label = f"Track ID: {track_label}" if track_label else det_label
text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
(int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1)
cv2.rectangle(
self.im,
(int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
(int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
mask_color,
-1,
)
cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255),
2)
cv2.putText(
self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
)
def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
"""
@ -452,24 +540,24 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
"""Plot training labels including class histograms and box statistics."""
import pandas as pd
import seaborn as sn
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
warnings.filterwarnings("ignore", category=FutureWarning)
# Plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
nc = int(cls.max() + 1) # number of classes
boxes = boxes[:1000000] # limit to 1M boxes
x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
# Seaborn correlogram
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
plt.close()
# Matplotlib labels
@ -477,14 +565,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
for i in range(nc):
y[2].patches[i].set_color([x / 255 for x in colors(i)])
ax[0].set_ylabel('instances')
ax[0].set_ylabel("instances")
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
ax[0].set_xlabel("classes")
sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
# Rectangles
boxes[:, 0:2] = 0.5 # center
@ -493,20 +581,20 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
for cls, box in zip(cls[:500], boxes[:500]):
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
ax[1].axis('off')
ax[1].axis("off")
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
for s in ["top", "right", "left", "bottom"]:
ax[a].spines[s].set_visible(False)
fname = save_dir / 'labels.jpg'
fname = save_dir / "labels.jpg"
plt.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
"""
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
@ -545,29 +633,31 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
xyxy = ops.xywh2xyxy(b).long()
xyxy = ops.clip_boxes(xyxy, im.shape)
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
if save:
file.parent.mkdir(parents=True, exist_ok=True) # make directory
f = str(increment_path(file).with_suffix('.jpg'))
f = str(increment_path(file).with_suffix(".jpg"))
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
return crop
@threaded
def plot_images(images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
confs=None,
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname='images.jpg',
names=None,
on_plot=None,
max_subplots=16,
save=True):
def plot_images(
images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
confs=None,
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname="images.jpg",
names=None,
on_plot=None,
max_subplots=16,
save=True,
):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -585,7 +675,7 @@ def plot_images(images,
max_size = 1920 # max image size
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
ns = np.ceil(bs**0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
@ -593,7 +683,7 @@ def plot_images(images,
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
scale = max_size / ns / max(h, w)
@ -612,7 +702,7 @@ def plot_images(images,
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i
classes = cls[idx].astype('int')
classes = cls[idx].astype("int")
labels = confs is None
if len(bboxes):
@ -633,14 +723,14 @@ def plot_images(images,
color = colors(c)
c = names.get(c, c) if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
annotator.box_label(box, label, color=color, rotated=is_obb)
elif len(classes):
for c in classes:
color = colors(c)
c = names.get(c, c) if names else c
annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
# Plot keypoints
if len(kpts):
@ -680,7 +770,9 @@ def plot_images(images,
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
im[y : y + h, x : x + w, :][mask] = (
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
annotator.fromarray(im)
if save:
annotator.im.save(fname) # save
@ -691,7 +783,7 @@ def plot_images(images,
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
"""
Plot training results from a results CSV file. The function supports various types of data including segmentation,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@ -714,6 +806,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
"""
import pandas as pd
from scipy.ndimage import gaussian_filter1d
save_dir = Path(file).parent if file else Path(dir)
if classify:
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
@ -728,32 +821,32 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
ax = ax.ravel()
files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
files = list(save_dir.glob("results*.csv"))
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
for f in files:
try:
data = pd.read_csv(f)
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate(index):
y = data.values[:, j].astype('float')
y = data.values[:, j].astype("float")
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
ax[i].set_title(s[j], fontsize=12)
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
ax[1].legend()
fname = save_dir / 'results.png'
fname = save_dir / "results.png"
fig.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
"""
Plots a scatter plot with points colored based on a 2D histogram.
@ -774,14 +867,18 @@ def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none
# Calculate 2D histogram and corresponding colors
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
colors = [
hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
hist[
min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
]
for i in range(len(v))
]
# Scatter plot
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
def plot_tune_results(csv_file='tune_results.csv'):
def plot_tune_results(csv_file="tune_results.csv"):
"""
Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
@ -810,33 +907,33 @@ def plot_tune_results(csv_file='tune_results.csv'):
v = x[:, i + num_metrics_columns]
mu = v[j] # best single result
plt.subplot(n, n, i + 1)
plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
plt.plot(mu, fitness.max(), 'k+', markersize=15)
plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8
plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
plt.plot(mu, fitness.max(), "k+", markersize=15)
plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
if i % n != 0:
plt.yticks([])
file = csv_file.with_name('tune_scatter_plots.png') # filename
file = csv_file.with_name("tune_scatter_plots.png") # filename
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f'Saved {file}')
LOGGER.info(f"Saved {file}")
# Fitness vs iteration
x = range(1, len(fitness) + 1)
plt.figure(figsize=(10, 6), tight_layout=True)
plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line
plt.title('Fitness vs Iteration')
plt.xlabel('Iteration')
plt.ylabel('Fitness')
plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
plt.title("Fitness vs Iteration")
plt.xlabel("Iteration")
plt.ylabel("Fitness")
plt.grid(True)
plt.legend()
file = csv_file.with_name('tune_fitness.png') # filename
file = csv_file.with_name("tune_fitness.png") # filename
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f'Saved {file}')
LOGGER.info(f"Saved {file}")
def output_to_target(output, max_det=300):
@ -861,7 +958,7 @@ def output_to_rotated_target(output, max_det=300):
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
"""
Visualize feature maps of a given model module during inference.
@ -872,7 +969,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
"""
for m in ['Detect', 'Pose', 'Segment']:
for m in ["Detect", "Pose", "Segment"]:
if m in module_type:
return
batch, channels, height, width = x.shape # batch, channels, height, width
@ -886,9 +983,9 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
plt.subplots_adjust(wspace=0.05, hspace=0.05)
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')
ax[i].axis("off")
LOGGER.info(f'Saving {f}... ({n}/{channels})')
plt.savefig(f, dpi=300, bbox_inches='tight')
LOGGER.info(f"Saving {f}... ({n}/{channels})")
plt.savefig(f, dpi=300, bbox_inches="tight")
plt.close()
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save

View file

@ -7,7 +7,7 @@ from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
TORCH_1_10 = check_version(torch.__version__, "1.10.0")
class TaskAlignedAssigner(nn.Module):
@ -61,12 +61,17 @@ class TaskAlignedAssigner(nn.Module):
if self.n_max_boxes == 0:
device = gt_bboxes.device
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device))
return (
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
)
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
mask_gt)
mask_pos, align_metric, overlaps = self.get_pos_mask(
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
)
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
@ -148,7 +153,7 @@ class TaskAlignedAssigner(nn.Module):
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
# Filter invalid bboxes
count_tensor.masked_fill_(count_tensor > 1, 0)
@ -192,9 +197,11 @@ class TaskAlignedAssigner(nn.Module):
target_labels.clamp_(0)
# 10x faster than F.one_hot()
target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device) # (b, h*w, 80)
target_scores = torch.zeros(
(target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device,
) # (b, h*w, 80)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
@ -252,7 +259,6 @@ class TaskAlignedAssigner(nn.Module):
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""Iou calculation for rotated bounding boxes."""
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
@ -295,7 +301,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
_, _, h, w = feats[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)

View file

@ -25,11 +25,11 @@ try:
except ImportError:
thop = None
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_2_0 = check_version(torch.__version__, '2.0.0')
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
TORCHVISION_0_11 = check_version(torchvision.__version__, '0.11.0')
TORCHVISION_0_13 = check_version(torchvision.__version__, '0.13.0')
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
@contextmanager
@ -60,13 +60,13 @@ def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo
k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available)
k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown')
return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '')
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
def select_device(device='', batch=0, newline=False, verbose=True):
def select_device(device="", batch=0, newline=False, verbose=True):
"""
Selects the appropriate PyTorch device based on the provided arguments.
@ -103,49 +103,57 @@ def select_device(device='', batch=0, newline=False, verbose=True):
if isinstance(device, torch.device):
return device
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "
device = str(device).lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
cpu = device == 'cpu'
mps = device in ('mps', 'mps:0') # Apple Metal Performance Shaders (MPS)
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
cpu = device == "cpu"
mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
if cpu or mps:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
if device == 'cuda':
device = '0'
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
if device == "cuda":
device = "0"
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", ""))):
LOGGER.info(s)
install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
raise ValueError(f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f'{install}')
install = (
"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
"CUDA devices are seen by torch.\n"
if torch.cuda.device_count() == 0
else ""
)
raise ValueError(
f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f"{install}"
)
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
space = ' ' * (len(s) + 1)
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
space = " " * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
arg = 'cuda:0'
arg = "cuda:0"
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
# Prefer MPS if available
s += f'MPS ({get_cpu_info()})\n'
arg = 'mps'
s += f"MPS ({get_cpu_info()})\n"
arg = "mps"
else: # revert to CPU
s += f'CPU ({get_cpu_info()})\n'
arg = 'cpu'
s += f"CPU ({get_cpu_info()})\n"
arg = "cpu"
if verbose:
LOGGER.info(s if newline else s.rstrip())
@ -161,14 +169,20 @@ def time_sync():
def fuse_conv_and_bn(conv, bn):
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
fusedconv = nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True).requires_grad_(False).to(conv.weight.device)
fusedconv = (
nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
@ -185,15 +199,21 @@ def fuse_conv_and_bn(conv, bn):
def fuse_deconv_and_bn(deconv, bn):
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
stride=deconv.stride,
padding=deconv.padding,
output_padding=deconv.output_padding,
dilation=deconv.dilation,
groups=deconv.groups,
bias=True).requires_grad_(False).to(deconv.weight.device)
fuseddconv = (
nn.ConvTranspose2d(
deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
stride=deconv.stride,
padding=deconv.padding,
output_padding=deconv.output_padding,
dilation=deconv.dilation,
groups=deconv.groups,
bias=True,
)
.requires_grad_(False)
.to(deconv.weight.device)
)
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
@ -221,18 +241,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
n_l = len(list(model.modules())) # number of layers
if detailed:
LOGGER.info(
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
)
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
name = name.replace("module_list.", "")
LOGGER.info(
"%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
% (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
)
flops = get_flops(model, imgsz)
fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
fs = f', {flops:.1f} GFLOPs' if flops else ''
yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
fs = f", {flops:.1f} GFLOPs" if flops else ""
yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
return n_l, n_p, n_g, flops
@ -262,13 +285,15 @@ def model_info_for_loggers(trainer):
"""
if trainer.args.profile: # profile ONNX and TensorRT times
from ultralytics.utils.benchmarks import ProfileModels
results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
results.pop('model/name')
results.pop("model/name")
else: # only return PyTorch times from most recent validation
results = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3)}
results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
"model/parameters": get_num_params(trainer.model),
"model/GFLOPs": round(get_flops(trainer.model), 3),
}
results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
return results
@ -284,14 +309,14 @@ def get_flops(model, imgsz=640):
imgsz = [imgsz, imgsz] # expand if int/float
try:
# Use stride size for input tensor
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # stride GFLOPs
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
except Exception:
# Use actual image size for input tensor (i.e. required for RTDETR models)
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # imgsz GFLOPs
return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
except Exception:
return 0.0
@ -301,11 +326,11 @@ def get_flops_with_torch_profiler(model, imgsz=640):
if TORCH_2_0:
model = de_parallel(model)
p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
with torch.profiler.profile(with_flops=True) as prof:
model(im)
flops = sum(x.flops for x in prof.key_averages()) / 1E9
flops = sum(x.flops for x in prof.key_averages()) / 1e9
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
return flops
@ -333,7 +358,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
return img
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
if not same_shape: # pad/crop img
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
@ -349,7 +374,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()):
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
if (len(include) and k not in include) or k.startswith("_") or k in exclude:
continue
else:
setattr(a, k, v)
@ -357,7 +382,7 @@ def copy_attr(a, b, include=(), exclude=()):
def get_latest_opset():
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
def intersect_dicts(da, db, exclude=()):
@ -392,10 +417,10 @@ def init_seeds(seed=0, deterministic=False):
if TORCH_2_0:
torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
torch.backends.cudnn.deterministic = True
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["PYTHONHASHSEED"] = str(seed)
else:
LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
else:
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
@ -430,13 +455,13 @@ class ModelEMA:
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
"""Updates attributes and saves stripped model with optimizer removed."""
if self.enabled:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
@ -456,26 +481,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
strip_optimizer(f)
```
"""
x = torch.load(f, map_location=torch.device('cpu'))
if 'model' not in x:
LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
x = torch.load(f, map_location=torch.device("cpu"))
if "model" not in x:
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
return
if hasattr(x['model'], 'args'):
x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
if hasattr(x["model"], "args"):
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
if x.get("ema"):
x["model"] = x["ema"] # replace model with ema
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
x["epoch"] = -1
x["model"].half() # to FP16
for p in x["model"].parameters():
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # file size
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
@ -496,18 +521,20 @@ def profile(input, ops, n=10, device=None):
results = []
if not isinstance(device, torch.device):
device = select_device(device)
LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
f"{'input':>24s}{'output':>24s}")
LOGGER.info(
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
f"{'input':>24s}{'output':>24s}"
)
for x in input if isinstance(input, list) else [input]:
x = x.to(device)
x.requires_grad = True
for m in ops if isinstance(ops, list) else [ops]:
m = m.to(device) if hasattr(m, 'to') else m # device
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
m = m.to(device) if hasattr(m, "to") else m # device
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try:
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
except Exception:
flops = 0
@ -521,13 +548,13 @@ def profile(input, ops, n=10, device=None):
t[2] = time_sync()
except Exception: # no backward method
# print(e) # for debug
t[2] = float('nan')
t[2] = float("nan")
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
results.append([p, flops, mem, tf, tb, s_in, s_out])
except Exception as e:
LOGGER.info(e)
@ -548,7 +575,7 @@ class EarlyStopping:
"""
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
self.possible_stop = False # possible stop may occur next epoch
def __call__(self, epoch, fitness):
@ -572,8 +599,10 @@ class EarlyStopping:
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
stop = delta >= self.patience # stop training if patience exceeded
if stop:
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
LOGGER.info(
f"Stopping training early as no improvement observed in last {self.patience} epochs. "
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
)
return stop

View file

@ -22,7 +22,7 @@ class TritonRemoteModel:
output_names (List[str]): The names of the model outputs.
"""
def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
"""
Initialize the TritonRemoteModel.
@ -36,7 +36,7 @@ class TritonRemoteModel:
"""
if not endpoint and not scheme: # Parse all args from URL string
splits = urlsplit(url)
endpoint = splits.path.strip('/').split('/')[0]
endpoint = splits.path.strip("/").split("/")[0]
scheme = splits.scheme
url = splits.netloc
@ -44,26 +44,28 @@ class TritonRemoteModel:
self.url = url
# Choose the Triton client based on the communication scheme
if scheme == 'http':
if scheme == "http":
import tritonclient.http as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint)
else:
import tritonclient.grpc as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint, as_json=True)['config']
config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
# Sort output names alphabetically, i.e. 'output0', 'output1', etc.
config['output'] = sorted(config['output'], key=lambda x: x.get('name'))
config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
# Define model attributes
type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
self.InferRequestedOutput = client.InferRequestedOutput
self.InferInput = client.InferInput
self.input_formats = [x['data_type'] for x in config['input']]
self.input_formats = [x["data_type"] for x in config["input"]]
self.np_input_formats = [type_map[x] for x in self.input_formats]
self.input_names = [x['name'] for x in config['input']]
self.output_names = [x['name'] for x in config['output']]
self.input_names = [x["name"] for x in config["input"]]
self.output_names = [x["name"] for x in config["output"]]
def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
"""
@ -80,7 +82,7 @@ class TritonRemoteModel:
for i, x in enumerate(inputs):
if x.dtype != self.np_input_formats[i]:
x = x.astype(self.np_input_formats[i])
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
infer_input.set_data_from_numpy(x)
infer_inputs.append(infer_input)

View file

@ -6,12 +6,9 @@ from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
def run_ray_tune(model,
space: dict = None,
grace_period: int = 10,
gpu_per_trial: int = None,
max_samples: int = 10,
**train_args):
def run_ray_tune(
model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
"""
Runs hyperparameter tuning using Ray Tune.
@ -38,12 +35,12 @@ def run_ray_tune(model,
```
"""
LOGGER.info('💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune')
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}
try:
subprocess.run('pip install ray[tune]'.split(), check=True)
subprocess.run("pip install ray[tune]".split(), check=True)
import ray
from ray import tune
@ -56,33 +53,34 @@ def run_ray_tune(model,
try:
import wandb
assert hasattr(wandb, '__version__')
assert hasattr(wandb, "__version__")
except (ImportError, AssertionError):
wandb = False
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
'lr0': tune.uniform(1e-5, 1e-1),
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
'box': tune.uniform(0.02, 0.2), # box loss gain
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
"lr0": tune.uniform(1e-5, 1e-1),
"lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
"momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
"weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
"warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
"warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
"box": tune.uniform(0.02, 0.2), # box loss gain
"cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
"hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
"hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
"hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
"degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
"translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
"scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
"shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
"perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
"flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
"fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
"mosaic": tune.uniform(0.0, 1.0), # image mixup (probability)
"mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
"copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
}
# Put the model in ray store
task = model.task
@ -107,35 +105,39 @@ def run_ray_tune(model,
# Get search space
if not space:
space = default_space
LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")
# Get dataset
data = train_args.get('data', TASK2DATA[task])
space['data'] = data
if 'data' not in train_args:
data = train_args.get("data", TASK2DATA[task])
space["data"] = data
if "data" not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=TASK2METRIC[task],
mode='max',
max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
grace_period=grace_period,
reduction_factor=3)
asha_scheduler = ASHAScheduler(
time_attr="epoch",
metric=TASK2METRIC[task],
mode="max",
max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
grace_period=grace_period,
reduction_factor=3,
)
# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
# Create the Ray Tune hyperparameter search tuner
tune_dir = get_save_dir(DEFAULT_CFG, name='tune').resolve() # must be absolute dir
tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir
tune_dir.mkdir(parents=True, exist_ok=True)
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir))
tuner = tune.Tuner(
trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
)
# Run the hyperparameter search
tuner.fit()