ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
parent e795277391
commit fe27db2f6e
139 changed files with 6870 additions and 5125 deletions
@@ -25,23 +25,22 @@ from tqdm import tqdm as tqdm_original
from ultralytics import __version__

# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html

# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO
ASSETS = ROOT / 'assets' # default images
DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
ASSETS = ROOT / "assets" # default images
DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format
LOGGING_NAME = 'ultralytics'
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans
HELP_MSG = \
"""
AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode
VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
LOGGING_NAME = "ultralytics"
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
HELP_MSG = """
Usage examples for running YOLOv8:

1. Install the ultralytics package:

@@ -99,12 +98,12 @@ HELP_MSG = \
"""

# Settings
torch.set_printoptions(linewidth=320, precision=4, profile='default')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
torch.set_printoptions(linewidth=320, precision=4, profile="default")
np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab
os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab


class TQDM(tqdm_original):

@@ -119,8 +118,8 @@ class TQDM(tqdm_original):
def __init__(self, *args, **kwargs):
"""Initialize custom Ultralytics tqdm class with different default arguments."""
# Set new default values (these can still be overridden when calling TQDM)
kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed
kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed
super().__init__(*args, **kwargs)

@@ -134,14 +133,14 @@ class SimpleClass:
attr = []
for a in dir(self):
v = getattr(self, a)
if not callable(v) and not a.startswith('_'):
if not callable(v) and not a.startswith("_"):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
else:
s = f'{a}: {repr(v)}'
s = f"{a}: {repr(v)}"
attr.append(s)
return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)

def __repr__(self):
"""Return a machine-readable string representation of the object."""

@@ -164,24 +163,26 @@ class IterableSimpleNamespace(SimpleNamespace):

def __str__(self):
"""Return a human-readable string representation of the object."""
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
return "\n".join(f"{k}={v}" for k, v in vars(self).items())

def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
raise AttributeError(f"""
raise AttributeError(
f"""
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
{DEFAULT_CFG_PATH} with the latest version from
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
""")
"""
)

def get(self, key, default=None):
"""Return the value of the specified key if it exists; otherwise, return the default value."""
return getattr(self, key, default)


def plt_settings(rcparams=None, backend='Agg'):
def plt_settings(rcparams=None, backend="Agg"):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.

@@ -199,7 +200,7 @@ def plt_settings(rcparams=None, backend='Agg'):
"""

if rcparams is None:
rcparams = {'font.size': 11}
rcparams = {"font.size": 11}

def decorator(func):
"""Decorator to apply temporary rc parameters and backend to a function."""

@@ -208,14 +209,14 @@ def plt_settings(rcparams=None, backend='Agg'):
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
original_backend = plt.get_backend()
if backend != original_backend:
plt.close('all') # auto-close()ing of figures upon backend switching is deprecated since 3.8
plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8
plt.switch_backend(backend)

with plt.rc_context(rcparams):
result = func(*args, **kwargs)

if backend != original_backend:
plt.close('all')
plt.close("all")
plt.switch_backend(original_backend)
return result

@@ -229,26 +230,26 @@ def set_logging(name=LOGGING_NAME, verbose=True):
level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings

# Configure the console (stdout) encoding to UTF-8
formatter = logging.Formatter('%(message)s') # Default formatter
if WINDOWS and sys.stdout.encoding != 'utf-8':
formatter = logging.Formatter("%(message)s") # Default formatter
if WINDOWS and sys.stdout.encoding != "utf-8":
try:
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8')
elif hasattr(sys.stdout, 'buffer'):
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
elif hasattr(sys.stdout, "buffer"):
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
else:
sys.stdout.encoding = 'utf-8'
sys.stdout.encoding = "utf-8"
except Exception as e:
print(f'Creating custom formatter for non UTF-8 environments due to {e}')
print(f"Creating custom formatter for non UTF-8 environments due to {e}")

class CustomFormatter(logging.Formatter):

def format(self, record):
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
return emojis(super().format(record))

formatter = CustomFormatter('%(message)s') # Use CustomFormatter to eliminate UTF-8 output as last recourse
formatter = CustomFormatter("%(message)s") # Use CustomFormatter to eliminate UTF-8 output as last recourse

# Create and configure the StreamHandler
stream_handler = logging.StreamHandler(sys.stdout)

@@ -264,13 +265,13 @@ def set_logging(name=LOGGING_NAME, verbose=True):

# Set logger
LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)
for logger in 'sentry_sdk', 'urllib3.connectionpool':
for logger in "sentry_sdk", "urllib3.connectionpool":
logging.getLogger(logger).setLevel(logging.CRITICAL + 1)


def emojis(string=''):
def emojis(string=""):
"""Return platform-dependent emoji-safe version of string."""
return string.encode().decode('ascii', 'ignore') if WINDOWS else string
return string.encode().decode("ascii", "ignore") if WINDOWS else string


class ThreadingLocked:

@@ -310,7 +311,7 @@ class ThreadingLocked:
return decorated


def yaml_save(file='data.yaml', data=None, header=''):
def yaml_save(file="data.yaml", data=None, header=""):
"""
Save YAML data to a file.

@@ -336,13 +337,13 @@ def yaml_save(file='data.yaml', data=None, header=''):
data[k] = str(v)

# Dump data to file in YAML format
with open(file, 'w', errors='ignore', encoding='utf-8') as f:
with open(file, "w", errors="ignore", encoding="utf-8") as f:
if header:
f.write(header)
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)


def yaml_load(file='data.yaml', append_filename=False):
def yaml_load(file="data.yaml", append_filename=False):
"""
Load YAML data from a file.

@@ -353,18 +354,18 @@ def yaml_load(file='data.yaml', append_filename=False):
Returns:
(dict): YAML data and file name.
"""
assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()'
with open(file, errors='ignore', encoding='utf-8') as f:
assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
with open(file, errors="ignore", encoding="utf-8") as f:
s = f.read() # string

# Remove special characters
if not s.isprintable():
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)

# Add YAML filename to dict and return
data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files)
if append_filename:
data['yaml_file'] = str(file)
data["yaml_file"] = str(file)
return data

@ -386,7 +387,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
|||
# Default configuration
|
||||
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
|
||||
for k, v in DEFAULT_CFG_DICT.items():
|
||||
if isinstance(v, str) and v.lower() == 'none':
|
||||
if isinstance(v, str) and v.lower() == "none":
|
||||
DEFAULT_CFG_DICT[k] = None
|
||||
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||||
|
|
@ -400,8 +401,8 @@ def is_ubuntu() -> bool:
|
|||
(bool): True if OS is Ubuntu, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
with open('/etc/os-release') as f:
|
||||
return 'ID=ubuntu' in f.read()
|
||||
with open("/etc/os-release") as f:
|
||||
return "ID=ubuntu" in f.read()
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -412,7 +413,7 @@ def is_colab():
|
|||
Returns:
|
||||
(bool): True if running inside a Colab notebook, False otherwise.
|
||||
"""
|
||||
return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
|
||||
return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
|
||||
|
||||
|
||||
def is_kaggle():
|
||||
|
|
@ -422,7 +423,7 @@ def is_kaggle():
|
|||
Returns:
|
||||
(bool): True if running inside a Kaggle kernel, False otherwise.
|
||||
"""
|
||||
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
||||
return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
|
||||
|
||||
|
||||
def is_jupyter():
|
||||
|
|
@ -434,6 +435,7 @@ def is_jupyter():
|
|||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
from IPython import get_ipython
|
||||
|
||||
return get_ipython() is not None
|
||||
return False
|
||||
|
||||
|
|
@ -445,10 +447,10 @@ def is_docker() -> bool:
|
|||
Returns:
|
||||
(bool): True if the script is running inside a Docker container, False otherwise.
|
||||
"""
|
||||
file = Path('/proc/self/cgroup')
|
||||
file = Path("/proc/self/cgroup")
|
||||
if file.exists():
|
||||
with open(file) as f:
|
||||
return 'docker' in f.read()
|
||||
return "docker" in f.read()
|
||||
else:
|
||||
return False
|
||||
|
||||
|
|
@ -462,7 +464,7 @@ def is_online() -> bool:
|
|||
"""
|
||||
import socket
|
||||
|
||||
for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
|
||||
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
||||
try:
|
||||
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
||||
except (socket.timeout, socket.gaierror, OSError):
|
||||
|
|
@ -516,7 +518,7 @@ def is_pytest_running():
|
|||
Returns:
|
||||
(bool): True if pytest is running, False otherwise.
|
||||
"""
|
||||
return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem)
|
||||
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)
|
||||
|
||||
|
||||
def is_github_action_running() -> bool:
|
||||
|
|
@ -526,7 +528,7 @@ def is_github_action_running() -> bool:
|
|||
Returns:
|
||||
(bool): True if the current environment is a GitHub Actions runner, False otherwise.
|
||||
"""
|
||||
return 'GITHUB_ACTIONS' in os.environ and 'GITHUB_WORKFLOW' in os.environ and 'RUNNER_OS' in os.environ
|
||||
return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
|
||||
|
||||
|
||||
def is_git_dir():
|
||||
|
|
@ -549,7 +551,7 @@ def get_git_dir():
|
|||
(Path | None): Git root directory if found or None if not found.
|
||||
"""
|
||||
for d in Path(__file__).parents:
|
||||
if (d / '.git').is_dir():
|
||||
if (d / ".git").is_dir():
|
||||
return d
|
||||
|
||||
|
||||
|
|
@ -562,7 +564,7 @@ def get_git_origin_url():
|
|||
"""
|
||||
if is_git_dir():
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
|
||||
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
|
||||
return origin.decode().strip()
|
||||
|
||||
|
||||
|
|
@ -575,7 +577,7 @@ def get_git_branch():
|
|||
"""
|
||||
if is_git_dir():
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
return origin.decode().strip()
|
||||
|
||||
|
||||
|
|
@ -602,11 +604,11 @@ def get_ubuntu_version():
|
|||
"""
|
||||
if is_ubuntu():
|
||||
with contextlib.suppress(FileNotFoundError, AttributeError):
|
||||
with open('/etc/os-release') as f:
|
||||
with open("/etc/os-release") as f:
|
||||
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
|
||||
|
||||
|
||||
def get_user_config_dir(sub_dir='Ultralytics'):
|
||||
def get_user_config_dir(sub_dir="Ultralytics"):
|
||||
"""
|
||||
Get the user config directory.
|
||||
|
||||
|
|
@ -618,19 +620,21 @@ def get_user_config_dir(sub_dir='Ultralytics'):
|
|||
"""
|
||||
# Return the appropriate config directory for each operating system
|
||||
if WINDOWS:
|
||||
path = Path.home() / 'AppData' / 'Roaming' / sub_dir
|
||||
path = Path.home() / "AppData" / "Roaming" / sub_dir
|
||||
elif MACOS: # macOS
|
||||
path = Path.home() / 'Library' / 'Application Support' / sub_dir
|
||||
path = Path.home() / "Library" / "Application Support" / sub_dir
|
||||
elif LINUX:
|
||||
path = Path.home() / '.config' / sub_dir
|
||||
path = Path.home() / ".config" / sub_dir
|
||||
else:
|
||||
raise ValueError(f'Unsupported operating system: {platform.system()}')
|
||||
raise ValueError(f"Unsupported operating system: {platform.system()}")
|
||||
|
||||
# GCP and AWS lambda fix, only /tmp is writeable
|
||||
if not is_dir_writeable(path.parent):
|
||||
LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
|
||||
'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.')
|
||||
path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
|
||||
"Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
|
||||
)
|
||||
path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir
|
||||
|
||||
# Create the subdirectory if it does not exist
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -638,8 +642,8 @@ def get_user_config_dir(sub_dir='Ultralytics'):
|
|||
return path
|
||||
|
||||
|
||||
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir
|
||||
SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
|
||||
USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir
|
||||
SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
|
||||
|
||||
|
||||
def colorstr(*input):
|
||||
|
|
@ -670,28 +674,29 @@ def colorstr(*input):
|
|||
>>> colorstr('blue', 'bold', 'hello world')
|
||||
>>> '\033[34m\033[1mhello world\033[0m'
|
||||
"""
|
||||
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
||||
*args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
|
||||
colors = {
|
||||
'black': '\033[30m', # basic colors
|
||||
'red': '\033[31m',
|
||||
'green': '\033[32m',
|
||||
'yellow': '\033[33m',
|
||||
'blue': '\033[34m',
|
||||
'magenta': '\033[35m',
|
||||
'cyan': '\033[36m',
|
||||
'white': '\033[37m',
|
||||
'bright_black': '\033[90m', # bright colors
|
||||
'bright_red': '\033[91m',
|
||||
'bright_green': '\033[92m',
|
||||
'bright_yellow': '\033[93m',
|
||||
'bright_blue': '\033[94m',
|
||||
'bright_magenta': '\033[95m',
|
||||
'bright_cyan': '\033[96m',
|
||||
'bright_white': '\033[97m',
|
||||
'end': '\033[0m', # misc
|
||||
'bold': '\033[1m',
|
||||
'underline': '\033[4m'}
|
||||
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
||||
"black": "\033[30m", # basic colors
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
"bright_black": "\033[90m", # bright colors
|
||||
"bright_red": "\033[91m",
|
||||
"bright_green": "\033[92m",
|
||||
"bright_yellow": "\033[93m",
|
||||
"bright_blue": "\033[94m",
|
||||
"bright_magenta": "\033[95m",
|
||||
"bright_cyan": "\033[96m",
|
||||
"bright_white": "\033[97m",
|
||||
"end": "\033[0m", # misc
|
||||
"bold": "\033[1m",
|
||||
"underline": "\033[4m",
|
||||
}
|
||||
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
||||
|
||||
|
||||
def remove_colorstr(input_string):
|
||||
|
|
@ -708,8 +713,8 @@ def remove_colorstr(input_string):
|
|||
>>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
|
||||
>>> 'hello world'
|
||||
"""
|
||||
ansi_escape = re.compile(r'\x1B\[[0-9;]*[A-Za-z]')
|
||||
return ansi_escape.sub('', input_string)
|
||||
ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
|
||||
return ansi_escape.sub("", input_string)
|
||||
|
||||
|
||||
class TryExcept(contextlib.ContextDecorator):
|
||||
|
|
@ -719,7 +724,7 @@ class TryExcept(contextlib.ContextDecorator):
|
|||
Use as @TryExcept() decorator or 'with TryExcept():' context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, msg='', verbose=True):
|
||||
def __init__(self, msg="", verbose=True):
|
||||
"""Initialize TryExcept class with optional message and verbosity settings."""
|
||||
self.msg = msg
|
||||
self.verbose = verbose
|
||||
|
|
@ -744,7 +749,7 @@ def threaded(func):
|
|||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
|
||||
if kwargs.pop('threaded', True): # run in thread
|
||||
if kwargs.pop("threaded", True): # run in thread
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
|
@ -786,27 +791,28 @@ def set_sentry():
|
|||
Returns:
|
||||
dict: The modified event or None if the event should not be sent to Sentry.
|
||||
"""
|
||||
if 'exc_info' in hint:
|
||||
exc_type, exc_value, tb = hint['exc_info']
|
||||
if exc_type in (KeyboardInterrupt, FileNotFoundError) \
|
||||
or 'out of memory' in str(exc_value):
|
||||
if "exc_info" in hint:
|
||||
exc_type, exc_value, tb = hint["exc_info"]
|
||||
if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
|
||||
return None # do not send event
|
||||
|
||||
event['tags'] = {
|
||||
'sys_argv': sys.argv[0],
|
||||
'sys_argv_name': Path(sys.argv[0]).name,
|
||||
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
||||
'os': ENVIRONMENT}
|
||||
event["tags"] = {
|
||||
"sys_argv": sys.argv[0],
|
||||
"sys_argv_name": Path(sys.argv[0]).name,
|
||||
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
|
||||
"os": ENVIRONMENT,
|
||||
}
|
||||
return event
|
||||
|
||||
if SETTINGS['sync'] and \
|
||||
RANK in (-1, 0) and \
|
||||
Path(sys.argv[0]).name == 'yolo' and \
|
||||
not TESTS_RUNNING and \
|
||||
ONLINE and \
|
||||
is_pip_package() and \
|
||||
not is_git_dir():
|
||||
|
||||
if (
|
||||
SETTINGS["sync"]
|
||||
and RANK in (-1, 0)
|
||||
and Path(sys.argv[0]).name == "yolo"
|
||||
and not TESTS_RUNNING
|
||||
and ONLINE
|
||||
and is_pip_package()
|
||||
and not is_git_dir()
|
||||
):
|
||||
# If sentry_sdk package is not installed then return and do not use Sentry
|
||||
try:
|
||||
import sentry_sdk # noqa
|
||||
|
|
@ -814,14 +820,15 @@ def set_sentry():
|
|||
return
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
|
||||
dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",
|
||||
debug=False,
|
||||
traces_sample_rate=1.0,
|
||||
release=__version__,
|
||||
environment='production', # 'dev' or 'production'
|
||||
environment="production", # 'dev' or 'production'
|
||||
before_send=before_send,
|
||||
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
|
||||
sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash
|
||||
ignore_errors=[KeyboardInterrupt, FileNotFoundError],
|
||||
)
|
||||
sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash
|
||||
|
||||
|
||||
class SettingsManager(dict):
|
||||
|
|
@ -833,7 +840,7 @@ class SettingsManager(dict):
|
|||
version (str): Settings version. In case of local version mismatch, new default settings will be saved.
|
||||
"""
|
||||
|
||||
def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
|
||||
def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
|
||||
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
|
||||
file.
|
||||
"""
|
||||
|
|
@ -850,23 +857,24 @@ class SettingsManager(dict):
|
|||
self.file = Path(file)
|
||||
self.version = version
|
||||
self.defaults = {
|
||||
'settings_version': version,
|
||||
'datasets_dir': str(datasets_root / 'datasets'),
|
||||
'weights_dir': str(root / 'weights'),
|
||||
'runs_dir': str(root / 'runs'),
|
||||
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
|
||||
'sync': True,
|
||||
'api_key': '',
|
||||
'openai_api_key': '',
|
||||
'clearml': True, # integrations
|
||||
'comet': True,
|
||||
'dvc': True,
|
||||
'hub': True,
|
||||
'mlflow': True,
|
||||
'neptune': True,
|
||||
'raytune': True,
|
||||
'tensorboard': True,
|
||||
'wandb': True}
|
||||
"settings_version": version,
|
||||
"datasets_dir": str(datasets_root / "datasets"),
|
||||
"weights_dir": str(root / "weights"),
|
||||
"runs_dir": str(root / "runs"),
|
||||
"uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
|
||||
"sync": True,
|
||||
"api_key": "",
|
||||
"openai_api_key": "",
|
||||
"clearml": True, # integrations
|
||||
"comet": True,
|
||||
"dvc": True,
|
||||
"hub": True,
|
||||
"mlflow": True,
|
||||
"neptune": True,
|
||||
"raytune": True,
|
||||
"tensorboard": True,
|
||||
"wandb": True,
|
||||
}
|
||||
|
||||
super().__init__(copy.deepcopy(self.defaults))
|
||||
|
||||
|
|
@ -877,13 +885,14 @@ class SettingsManager(dict):
|
|||
self.load()
|
||||
correct_keys = self.keys() == self.defaults.keys()
|
||||
correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))
|
||||
correct_version = check_version(self['settings_version'], self.version)
|
||||
correct_version = check_version(self["settings_version"], self.version)
|
||||
if not (correct_keys and correct_types and correct_version):
|
||||
LOGGER.warning(
|
||||
'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem '
|
||||
'with your settings or a recent ultralytics package update. '
|
||||
"WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
|
||||
"with your settings or a recent ultralytics package update. "
|
||||
f"\nView settings with 'yolo settings' or at '{self.file}'"
|
||||
"\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.")
|
||||
"\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'."
|
||||
)
|
||||
self.reset()
|
||||
|
||||
def load(self):
|
||||
|
|
@ -910,14 +919,16 @@ def deprecation_warn(arg, new_arg, version=None):
|
|||
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
|
||||
if not version:
|
||||
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
|
||||
LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
|
||||
f"Please use '{new_arg}' instead.")
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
|
||||
f"Please use '{new_arg}' instead."
|
||||
)
|
||||
|
||||
|
||||
def clean_url(url):
|
||||
"""Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
|
||||
url = Path(url).as_posix().replace(':/', '://') # Pathlib turns :// -> :/, as_posix() for Windows
|
||||
return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
|
||||
url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows
|
||||
return urllib.parse.unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth
|
||||
|
||||
|
||||
def url2file(url):
|
||||
|
|
@ -928,13 +939,22 @@ def url2file(url):
|
|||
# Run below code on utils init ------------------------------------------------------------------------------------
|
||||
|
||||
# Check first-install steps
|
||||
PREFIX = colorstr('Ultralytics: ')
|
||||
PREFIX = colorstr("Ultralytics: ")
|
||||
SETTINGS = SettingsManager() # initialize settings
|
||||
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
|
||||
WEIGHTS_DIR = Path(SETTINGS['weights_dir']) # global weights directory
|
||||
RUNS_DIR = Path(SETTINGS['runs_dir']) # global runs directory
|
||||
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
|
||||
'Docker' if is_docker() else platform.system()
|
||||
DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory
|
||||
WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory
|
||||
RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory
|
||||
ENVIRONMENT = (
|
||||
"Colab"
|
||||
if is_colab()
|
||||
else "Kaggle"
|
||||
if is_kaggle()
|
||||
else "Jupyter"
|
||||
if is_jupyter()
|
||||
else "Docker"
|
||||
if is_docker()
|
||||
else platform.system()
|
||||
)
|
||||
TESTS_RUNNING = is_pytest_running() or is_github_action_running()
|
||||
set_sentry()
|
||||
|
||||
|
|
|
|||
|
|
@ -42,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
|||
"""
|
||||
|
||||
# Check device
|
||||
prefix = colorstr('AutoBatch: ')
|
||||
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
|
||||
prefix = colorstr("AutoBatch: ")
|
||||
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type == 'cpu':
|
||||
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
||||
if device.type == "cpu":
|
||||
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
|
||||
return batch_size
|
||||
if torch.backends.cudnn.benchmark:
|
||||
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
|
||||
LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
|
||||
return batch_size
|
||||
|
||||
# Inspect CUDA memory
|
||||
|
|
@ -60,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
|||
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
|
||||
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
|
||||
f = t - (r + a) # GiB free
|
||||
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
|
||||
LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
|
||||
|
||||
# Profile batch sizes
|
||||
batch_sizes = [1, 2, 4, 8, 16]
|
||||
|
|
@ -70,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
|||
|
||||
# Fit a solution
|
||||
y = [x[2] for x in results if x] # memory [2]
|
||||
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
|
||||
p = np.polyfit(batch_sizes[: len(y)], y, deg=1) # first degree polynomial fit
|
||||
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
|
||||
if None in results: # some sizes failed
|
||||
i = results.index(None) # first fail index
|
||||
|
|
@ -78,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
|||
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
||||
if b < 1 or b > 1024: # b outside of safe range
|
||||
b = batch_size
|
||||
LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
|
||||
LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")
|
||||
|
||||
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
|
||||
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
||||
LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
|
||||
return b
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
|
||||
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
|
||||
return batch_size
|
||||
|
|
|
|||
|
|
@ -42,13 +42,9 @@ from ultralytics.utils.files import file_size
|
|||
from ultralytics.utils.torch_utils import select_device
|
||||
|
||||
|
||||
def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
|
||||
data=None,
|
||||
imgsz=160,
|
||||
half=False,
|
||||
int8=False,
|
||||
device='cpu',
|
||||
verbose=False):
|
||||
def benchmark(
|
||||
model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
|
||||
):
|
||||
"""
|
||||
Benchmark a YOLO model across different formats for speed and accuracy.
|
||||
|
||||
|
|
@ -76,6 +72,7 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
|
|||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
device = select_device(device, verbose=False)
|
||||
|
|
@ -85,67 +82,62 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
|
|||
y = []
|
||||
t0 = time.time()
|
||||
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
|
||||
emoji, filename = '❌', None # export defaults
|
||||
emoji, filename = "❌", None # export defaults
|
||||
try:
|
||||
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
|
||||
assert i != 9 or LINUX, "Edge TPU export only supported on Linux"
|
||||
if i == 10:
|
||||
assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
|
||||
assert MACOS or LINUX, "TF.js export only supported on macOS and Linux"
|
||||
elif i == 11:
|
||||
assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10'
|
||||
if 'cpu' in device.type:
|
||||
assert cpu, 'inference not supported on CPU'
|
||||
if 'cuda' in device.type:
|
||||
assert gpu, 'inference not supported on GPU'
|
||||
assert sys.version_info < (3, 11), "PaddlePaddle export only supported on Python<=3.10"
|
||||
if "cpu" in device.type:
|
||||
assert cpu, "inference not supported on CPU"
|
||||
if "cuda" in device.type:
|
||||
assert gpu, "inference not supported on GPU"
|
||||
|
||||
# Export
|
||||
if format == '-':
|
||||
if format == "-":
|
||||
filename = model.ckpt_path or model.cfg
|
||||
exported_model = model # PyTorch format
|
||||
else:
|
||||
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
|
||||
exported_model = YOLO(filename, task=model.task)
|
||||
assert suffix in str(filename), 'export failed'
|
||||
emoji = '❎' # indicates export succeeded
|
||||
assert suffix in str(filename), "export failed"
|
||||
emoji = "❎" # indicates export succeeded
|
||||
|
||||
# Predict
|
||||
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
|
||||
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
|
||||
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
|
||||
exported_model.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
|
||||
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
|
||||
assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
|
||||
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
|
||||
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
|
||||
|
||||
# Validate
|
||||
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
|
||||
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
|
||||
results = exported_model.val(data=data,
|
||||
batch=1,
|
||||
imgsz=imgsz,
|
||||
plots=False,
|
||||
device=device,
|
||||
half=half,
|
||||
int8=int8,
|
||||
verbose=False)
|
||||
metric, speed = results.results_dict[key], results.speed['inference']
|
||||
y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
|
||||
results = exported_model.val(
|
||||
data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
|
||||
)
|
||||
metric, speed = results.results_dict[key], results.speed["inference"]
|
||||
y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}'
|
||||
LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
|
||||
assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
|
||||
LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
|
||||
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
|
||||
|
||||
# Print results
|
||||
check_yolo(device=device) # print system info
|
||||
df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
|
||||
df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])
|
||||
|
||||
name = Path(model.ckpt_path).name
|
||||
s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
|
||||
s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
|
||||
LOGGER.info(s)
|
||||
with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
|
||||
with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
|
||||
f.write(s)
|
||||
|
||||
if verbose and isinstance(verbose, float):
|
||||
metrics = df[key].array # values to compare to floor
|
||||
floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
|
||||
assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}'
|
||||
assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
|
||||
|
||||
return df
|
||||
|
||||
|
|
@ -175,15 +167,17 @@ class ProfileModels:
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
paths: list,
|
||||
num_timed_runs=100,
|
||||
num_warmup_runs=10,
|
||||
min_time=60,
|
||||
imgsz=640,
|
||||
half=True,
|
||||
trt=True,
|
||||
device=None):
|
||||
def __init__(
|
||||
self,
|
||||
paths: list,
|
||||
num_timed_runs=100,
|
||||
num_warmup_runs=10,
|
||||
min_time=60,
|
||||
imgsz=640,
|
||||
half=True,
|
||||
trt=True,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Initialize the ProfileModels class for profiling models.
|
||||
|
||||
|
|
@ -204,37 +198,32 @@ class ProfileModels:
|
|||
self.imgsz = imgsz
|
||||
self.half = half
|
||||
self.trt = trt # run TensorRT profiling
|
||||
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
|
||||
self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def profile(self):
|
||||
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
|
||||
files = self.get_files()
|
||||
|
||||
if not files:
|
||||
print('No matching *.pt or *.onnx files found.')
|
||||
print("No matching *.pt or *.onnx files found.")
|
||||
return
|
||||
|
||||
table_rows = []
|
||||
output = []
|
||||
for file in files:
|
||||
engine_file = file.with_suffix('.engine')
|
||||
if file.suffix in ('.pt', '.yaml', '.yml'):
|
||||
engine_file = file.with_suffix(".engine")
|
||||
if file.suffix in (".pt", ".yaml", ".yml"):
|
||||
model = YOLO(str(file))
|
||||
model.fuse() # to report correct params and GFLOPs in model.info()
|
||||
model_info = model.info()
|
||||
if self.trt and self.device.type != 'cpu' and not engine_file.is_file():
|
||||
engine_file = model.export(format='engine',
|
||||
half=self.half,
|
||||
imgsz=self.imgsz,
|
||||
device=self.device,
|
||||
verbose=False)
|
||||
onnx_file = model.export(format='onnx',
|
||||
half=self.half,
|
||||
imgsz=self.imgsz,
|
||||
simplify=True,
|
||||
device=self.device,
|
||||
verbose=False)
|
||||
elif file.suffix == '.onnx':
|
||||
if self.trt and self.device.type != "cpu" and not engine_file.is_file():
|
||||
engine_file = model.export(
|
||||
format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
|
||||
)
|
||||
onnx_file = model.export(
|
||||
format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
|
||||
)
|
||||
elif file.suffix == ".onnx":
|
||||
model_info = self.get_onnx_model_info(file)
|
||||
onnx_file = file
|
||||
else:
|
||||
|
|
@ -254,14 +243,14 @@ class ProfileModels:
|
|||
for path in self.paths:
|
||||
path = Path(path)
|
||||
if path.is_dir():
|
||||
extensions = ['*.pt', '*.onnx', '*.yaml']
|
||||
extensions = ["*.pt", "*.onnx", "*.yaml"]
|
||||
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
|
||||
elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing
|
||||
elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing
|
||||
files.append(str(path))
|
||||
else:
|
||||
files.extend(glob.glob(str(path)))
|
||||
|
||||
print(f'Profiling: {sorted(files)}')
|
||||
print(f"Profiling: {sorted(files)}")
|
||||
return [Path(file) for file in sorted(files)]
|
||||
|
||||
def get_onnx_model_info(self, onnx_file: str):
|
||||
|
|
@ -306,7 +295,7 @@ class ProfileModels:
|
|||
run_times = []
|
||||
for _ in TQDM(range(num_runs), desc=engine_file):
|
||||
results = model(input_data, imgsz=self.imgsz, verbose=False)
|
||||
run_times.append(results[0].speed['inference']) # Convert to milliseconds
|
||||
run_times.append(results[0].speed["inference"]) # Convert to milliseconds
|
||||
|
||||
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
|
||||
return np.mean(run_times), np.std(run_times)
|
||||
|
|
@ -315,31 +304,31 @@ class ProfileModels:
|
|||
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
|
||||
times.
|
||||
"""
|
||||
check_requirements('onnxruntime')
|
||||
check_requirements("onnxruntime")
|
||||
import onnxruntime as ort
|
||||
|
||||
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.intra_op_num_threads = 8 # Limit the number of threads
|
||||
sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
|
||||
sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
|
||||
|
||||
input_tensor = sess.get_inputs()[0]
|
||||
input_type = input_tensor.type
|
||||
|
||||
# Mapping ONNX datatype to numpy datatype
|
||||
if 'float16' in input_type:
|
||||
if "float16" in input_type:
|
||||
input_dtype = np.float16
|
||||
elif 'float' in input_type:
|
||||
elif "float" in input_type:
|
||||
input_dtype = np.float32
|
||||
elif 'double' in input_type:
|
||||
elif "double" in input_type:
|
||||
input_dtype = np.float64
|
||||
elif 'int64' in input_type:
|
||||
elif "int64" in input_type:
|
||||
input_dtype = np.int64
|
||||
elif 'int32' in input_type:
|
||||
elif "int32" in input_type:
|
||||
input_dtype = np.int32
|
||||
else:
|
||||
raise ValueError(f'Unsupported ONNX datatype {input_type}')
|
||||
raise ValueError(f"Unsupported ONNX datatype {input_type}")
|
||||
|
||||
input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
|
||||
input_name = input_tensor.name
|
||||
|
|
@ -369,25 +358,26 @@ class ProfileModels:
|
|||
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
|
||||
"""Generates a formatted string for a table row that includes model performance and metric details."""
|
||||
layers, params, gradients, flops = model_info
|
||||
return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
|
||||
return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
|
||||
|
||||
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
|
||||
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
|
||||
layers, params, gradients, flops = model_info
|
||||
return {
|
||||
'model/name': model_name,
|
||||
'model/parameters': params,
|
||||
'model/GFLOPs': round(flops, 3),
|
||||
'model/speed_ONNX(ms)': round(t_onnx[0], 3),
|
||||
'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
|
||||
"model/name": model_name,
|
||||
"model/parameters": params,
|
||||
"model/GFLOPs": round(flops, 3),
|
||||
"model/speed_ONNX(ms)": round(t_onnx[0], 3),
|
||||
"model/speed_TensorRT(ms)": round(t_engine[0], 3),
|
||||
}
|
||||
|
||||
def print_table(self, table_rows):
|
||||
"""Formats and prints a comparison table for different models with given statistics and performance data."""
|
||||
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
|
||||
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
|
||||
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
|
||||
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
|
||||
header = f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
|
||||
separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|"
|
||||
|
||||
print(f'\n\n{header}')
|
||||
print(f"\n\n{header}")
|
||||
print(separator)
|
||||
for row in table_rows:
|
||||
print(row)
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
|
||||
|
||||
__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
|
||||
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"
|
||||
|
|
|
|||
|
|
@ -143,37 +143,35 @@ def on_export_end(exporter):
|
|||
|
||||
default_callbacks = {
|
||||
# Run in trainer
|
||||
'on_pretrain_routine_start': [on_pretrain_routine_start],
|
||||
'on_pretrain_routine_end': [on_pretrain_routine_end],
|
||||
'on_train_start': [on_train_start],
|
||||
'on_train_epoch_start': [on_train_epoch_start],
|
||||
'on_train_batch_start': [on_train_batch_start],
|
||||
'optimizer_step': [optimizer_step],
|
||||
'on_before_zero_grad': [on_before_zero_grad],
|
||||
'on_train_batch_end': [on_train_batch_end],
|
||||
'on_train_epoch_end': [on_train_epoch_end],
|
||||
'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
|
||||
'on_model_save': [on_model_save],
|
||||
'on_train_end': [on_train_end],
|
||||
'on_params_update': [on_params_update],
|
||||
'teardown': [teardown],
|
||||
|
||||
"on_pretrain_routine_start": [on_pretrain_routine_start],
|
||||
"on_pretrain_routine_end": [on_pretrain_routine_end],
|
||||
"on_train_start": [on_train_start],
|
||||
"on_train_epoch_start": [on_train_epoch_start],
|
||||
"on_train_batch_start": [on_train_batch_start],
|
||||
"optimizer_step": [optimizer_step],
|
||||
"on_before_zero_grad": [on_before_zero_grad],
|
||||
"on_train_batch_end": [on_train_batch_end],
|
||||
"on_train_epoch_end": [on_train_epoch_end],
|
||||
"on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
|
||||
"on_model_save": [on_model_save],
|
||||
"on_train_end": [on_train_end],
|
||||
"on_params_update": [on_params_update],
|
||||
"teardown": [teardown],
|
||||
# Run in validator
|
||||
'on_val_start': [on_val_start],
|
||||
'on_val_batch_start': [on_val_batch_start],
|
||||
'on_val_batch_end': [on_val_batch_end],
|
||||
'on_val_end': [on_val_end],
|
||||
|
||||
"on_val_start": [on_val_start],
|
||||
"on_val_batch_start": [on_val_batch_start],
|
||||
"on_val_batch_end": [on_val_batch_end],
|
||||
"on_val_end": [on_val_end],
|
||||
# Run in predictor
|
||||
'on_predict_start': [on_predict_start],
|
||||
'on_predict_batch_start': [on_predict_batch_start],
|
||||
'on_predict_postprocess_end': [on_predict_postprocess_end],
|
||||
'on_predict_batch_end': [on_predict_batch_end],
|
||||
'on_predict_end': [on_predict_end],
|
||||
|
||||
"on_predict_start": [on_predict_start],
|
||||
"on_predict_batch_start": [on_predict_batch_start],
|
||||
"on_predict_postprocess_end": [on_predict_postprocess_end],
|
||||
"on_predict_batch_end": [on_predict_batch_end],
|
||||
"on_predict_end": [on_predict_end],
|
||||
# Run in exporter
|
||||
'on_export_start': [on_export_start],
|
||||
'on_export_end': [on_export_end]}
|
||||
"on_export_start": [on_export_start],
|
||||
"on_export_end": [on_export_end],
|
||||
}
|
||||
|
||||
|
||||
def get_default_callbacks():
|
||||
|
|
@ -197,10 +195,11 @@ def add_integration_callbacks(instance):
|
|||
|
||||
# Load HUB callbacks
|
||||
from .hub import callbacks as hub_cb
|
||||
|
||||
callbacks_list = [hub_cb]
|
||||
|
||||
# Load training callbacks
|
||||
if 'Trainer' in instance.__class__.__name__:
|
||||
if "Trainer" in instance.__class__.__name__:
|
||||
from .clearml import callbacks as clear_cb
|
||||
from .comet import callbacks as comet_cb
|
||||
from .dvc import callbacks as dvc_cb
|
||||
|
|
@ -209,6 +208,7 @@ def add_integration_callbacks(instance):
|
|||
from .raytune import callbacks as tune_cb
|
||||
from .tensorboard import callbacks as tb_cb
|
||||
from .wb import callbacks as wb_cb
|
||||
|
||||
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
|
||||
|
||||
# Add the callbacks to the callbacks dictionary
|
||||
|
|
|
|||
|
|
@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
|
|||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['clearml'] is True # verify integration is enabled
|
||||
assert SETTINGS["clearml"] is True # verify integration is enabled
|
||||
import clearml
|
||||
from clearml import Task
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
assert hasattr(clearml, '__version__') # verify package is not directory
|
||||
assert hasattr(clearml, "__version__") # verify package is not directory
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
clearml = None
|
||||
|
||||
|
||||
def _log_debug_samples(files, title='Debug Samples') -> None:
|
||||
def _log_debug_samples(files, title="Debug Samples") -> None:
|
||||
"""
|
||||
Log files (images) as debug samples in the ClearML task.
|
||||
|
||||
|
|
@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
|
|||
if task := Task.current_task():
|
||||
for f in files:
|
||||
if f.exists():
|
||||
it = re.search(r'_batch(\d+)', f.name)
|
||||
it = re.search(r"_batch(\d+)", f.name)
|
||||
iteration = int(it.groups()[0]) if it else 0
|
||||
task.get_logger().report_image(title=title,
|
||||
series=f.name.replace(it.group(), ''),
|
||||
local_path=str(f),
|
||||
iteration=iteration)
|
||||
task.get_logger().report_image(
|
||||
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
|
||||
)
|
||||
|
||||
|
||||
def _log_plot(title, plot_path) -> None:
|
||||
|
|
@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None:
|
|||
|
||||
img = mpimg.imread(plot_path)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
|
||||
ax.imshow(img)
|
||||
|
||||
Task.current_task().get_logger().report_matplotlib_figure(title=title,
|
||||
series='',
|
||||
figure=fig,
|
||||
report_interactive=False)
|
||||
Task.current_task().get_logger().report_matplotlib_figure(
|
||||
title=title, series="", figure=fig, report_interactive=False
|
||||
)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
|
|
@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer):
|
|||
PatchPyTorchModelIO.update_current_task(None)
|
||||
PatchedMatplotlib.update_current_task(None)
|
||||
else:
|
||||
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
|
||||
task_name=trainer.args.name,
|
||||
tags=['YOLOv8'],
|
||||
output_uri=True,
|
||||
reuse_last_task_id=False,
|
||||
auto_connect_frameworks={
|
||||
'pytorch': False,
|
||||
'matplotlib': False})
|
||||
LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
|
||||
'please add clearml-init and connect your arguments before initializing YOLO.')
|
||||
task.connect(vars(trainer.args), name='General')
|
||||
task = Task.init(
|
||||
project_name=trainer.args.project or "YOLOv8",
|
||||
task_name=trainer.args.name,
|
||||
tags=["YOLOv8"],
|
||||
output_uri=True,
|
||||
reuse_last_task_id=False,
|
||||
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
|
||||
)
|
||||
LOGGER.warning(
|
||||
"ClearML Initialized a new task. If you want to run remotely, "
|
||||
"please add clearml-init and connect your arguments before initializing YOLO."
|
||||
)
|
||||
task.connect(vars(trainer.args), name="General")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
|
|
@ -88,26 +88,26 @@ def on_train_epoch_end(trainer):
|
|||
if task := Task.current_task():
|
||||
# Log debug samples
|
||||
if trainer.epoch == 1:
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
|
||||
# Report the current training progress
|
||||
for k, v in trainer.label_loss_items(trainer.tloss, prefix='train').items():
|
||||
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
|
||||
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
|
||||
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
|
||||
for k, v in trainer.lr.items():
|
||||
task.get_logger().report_scalar('lr', k, v, iteration=trainer.epoch)
|
||||
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Reports model information to logger at the end of an epoch."""
|
||||
if task := Task.current_task():
|
||||
# You should have access to the validation bboxes under jdict
|
||||
task.get_logger().report_scalar(title='Epoch Time',
|
||||
series='Epoch Time',
|
||||
value=trainer.epoch_time,
|
||||
iteration=trainer.epoch)
|
||||
task.get_logger().report_scalar(
|
||||
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
|
||||
)
|
||||
for k, v in trainer.metrics.items():
|
||||
task.get_logger().report_scalar('val', k, v, iteration=trainer.epoch)
|
||||
task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
|
||||
if trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
for k, v in model_info_for_loggers(trainer).items():
|
||||
task.get_logger().report_single_value(k, v)
|
||||
|
||||
|
|
@ -116,7 +116,7 @@ def on_val_end(validator):
|
|||
"""Logs validation results including labels and predictions."""
|
||||
if Task.current_task():
|
||||
# Log val_labels and val_pred
|
||||
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
|
||||
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
|
|
@ -124,8 +124,11 @@ def on_train_end(trainer):
|
|||
if task := Task.current_task():
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
"results.png",
|
||||
"confusion_matrix.png",
|
||||
"confusion_matrix_normalized.png",
|
||||
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
|
||||
]
|
||||
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
||||
for f in files:
|
||||
_log_plot(title=f.stem, plot_path=f)
|
||||
|
|
@ -136,9 +139,14 @@ def on_train_end(trainer):
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)


callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if clearml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)
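# Illustrative sketch only, not part of this diff: how a callbacks mapping like the one above
# could be registered and dispatched by a trainer. MiniTrainer and its event loop are hypothetical
# stand-ins; only the "empty dict when the integration is unavailable" pattern comes from the source.
from collections import defaultdict

class MiniTrainer:
    def __init__(self):
        self.callbacks = defaultdict(list)

    def add_integration(self, callback_dict):
        # callback_dict is e.g. the ClearML mapping above; it is {} when clearml is None
        for event, fn in callback_dict.items():
            self.callbacks[event].append(fn)

    def run_callbacks(self, event):
        for fn in self.callbacks[event]:
            fn(self)  # each callback receives the trainer instance

trainer = MiniTrainer()
trainer.add_integration({"on_train_end": lambda t: print("train ended")})
trainer.run_callbacks("on_train_end")  # prints "train ended"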
@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
|
|||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['comet'] is True # verify integration is enabled
|
||||
assert SETTINGS["comet"] is True # verify integration is enabled
|
||||
import comet_ml
|
||||
|
||||
assert hasattr(comet_ml, '__version__') # verify package is not directory
|
||||
assert hasattr(comet_ml, "__version__") # verify package is not directory
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Ensures certain logging functions only run for supported tasks
|
||||
COMET_SUPPORTED_TASKS = ['detect']
|
||||
COMET_SUPPORTED_TASKS = ["detect"]
|
||||
|
||||
# Names of plots created by YOLOv8 that are logged to Comet
|
||||
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
|
||||
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
|
||||
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
|
||||
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
|
||||
|
||||
_comet_image_prediction_count = 0
|
||||
|
||||
|
|
@ -27,43 +27,43 @@ except (ImportError, AssertionError):
def _get_comet_mode():
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv('COMET_MODE', 'online')
return os.getenv("COMET_MODE", "online")


def _get_comet_model_name():
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
return os.getenv("COMET_MODEL_NAME", "YOLOv8")


def _get_eval_batch_logging_interval():
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))


def _get_max_image_predictions_to_log():
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))


def _scale_confidence_score(score):
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale


def _should_log_confusion_matrix():
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"


def _should_log_image_predictions():
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"


def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == 'offline':
if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)

return comet_ml.Experiment(project_name=project_name)
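# Illustrative sketch, not part of the diff: the Comet helpers above read optional environment
# variables with defaults. The variable names below match the source; the values set here are
# example values only.
import os

os.environ["COMET_MODE"] = "offline"  # example value for demonstration
mode = os.getenv("COMET_MODE", "online")
interval = int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
max_preds = int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
log_cm = os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
print(mode, interval, max_preds, log_cm)  # offline 1 100 False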
|
||||
|
|
@ -75,18 +75,21 @@ def _create_experiment(args):
|
|||
return
|
||||
try:
|
||||
comet_mode = _get_comet_mode()
|
||||
_project_name = os.getenv('COMET_PROJECT_NAME', args.project)
|
||||
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
|
||||
experiment = _get_experiment_type(comet_mode, _project_name)
|
||||
experiment.log_parameters(vars(args))
|
||||
experiment.log_others({
|
||||
'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
|
||||
'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
|
||||
'log_image_predictions': _should_log_image_predictions(),
|
||||
'max_image_predictions': _get_max_image_predictions_to_log(), })
|
||||
experiment.log_other('Created from', 'yolov8')
|
||||
experiment.log_others(
|
||||
{
|
||||
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
|
||||
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
|
||||
"log_image_predictions": _should_log_image_predictions(),
|
||||
"max_image_predictions": _get_max_image_predictions_to_log(),
|
||||
}
|
||||
)
|
||||
experiment.log_other("Created from", "yolov8")
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
|
||||
|
||||
|
||||
def _fetch_trainer_metadata(trainer):
|
||||
|
|
@ -134,29 +137,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
|
|||
|
||||
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
|
||||
"""Format ground truth annotations for detection."""
|
||||
indices = batch['batch_idx'] == img_idx
|
||||
bboxes = batch['bboxes'][indices]
|
||||
indices = batch["batch_idx"] == img_idx
|
||||
bboxes = batch["bboxes"][indices]
|
||||
if len(bboxes) == 0:
|
||||
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels')
|
||||
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
|
||||
return None
|
||||
|
||||
cls_labels = batch['cls'][indices].squeeze(1).tolist()
|
||||
cls_labels = batch["cls"][indices].squeeze(1).tolist()
|
||||
if class_name_map:
|
||||
cls_labels = [str(class_name_map[label]) for label in cls_labels]
|
||||
|
||||
original_image_shape = batch['ori_shape'][img_idx]
|
||||
resized_image_shape = batch['resized_shape'][img_idx]
|
||||
ratio_pad = batch['ratio_pad'][img_idx]
|
||||
original_image_shape = batch["ori_shape"][img_idx]
|
||||
resized_image_shape = batch["resized_shape"][img_idx]
|
||||
ratio_pad = batch["ratio_pad"][img_idx]
|
||||
|
||||
data = []
|
||||
for box, label in zip(bboxes, cls_labels):
|
||||
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
|
||||
data.append({
|
||||
'boxes': [box],
|
||||
'label': f'gt_{label}',
|
||||
'score': _scale_confidence_score(1.0), })
|
||||
data.append(
|
||||
{
|
||||
"boxes": [box],
|
||||
"label": f"gt_{label}",
|
||||
"score": _scale_confidence_score(1.0),
|
||||
}
|
||||
)
|
||||
|
||||
return {'name': 'ground_truth', 'data': data}
|
||||
return {"name": "ground_truth", "data": data}
|
||||
|
||||
|
||||
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
|
||||
|
|
@ -166,31 +172,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
|
|||
|
||||
predictions = metadata.get(image_id)
|
||||
if not predictions:
|
||||
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
|
||||
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
|
||||
return None
|
||||
|
||||
data = []
|
||||
for prediction in predictions:
|
||||
boxes = prediction['bbox']
|
||||
score = _scale_confidence_score(prediction['score'])
|
||||
cls_label = prediction['category_id']
|
||||
boxes = prediction["bbox"]
|
||||
score = _scale_confidence_score(prediction["score"])
|
||||
cls_label = prediction["category_id"]
|
||||
if class_label_map:
|
||||
cls_label = str(class_label_map[cls_label])
|
||||
|
||||
data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
|
||||
data.append({"boxes": [boxes], "label": cls_label, "score": score})
|
||||
|
||||
return {'name': 'prediction', 'data': data}
|
||||
return {"name": "prediction", "data": data}
|
||||
|
||||
|
||||
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
|
||||
"""Join the ground truth and prediction annotations if they exist."""
|
||||
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
|
||||
class_label_map)
|
||||
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
|
||||
class_label_map)
|
||||
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
|
||||
img_idx, image_path, batch, class_label_map
|
||||
)
|
||||
prediction_annotations = _format_prediction_annotations_for_detection(
|
||||
image_path, prediction_metadata_map, class_label_map
|
||||
)
|
||||
|
||||
annotations = [
|
||||
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
|
||||
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
|
||||
]
|
||||
return [annotations] if annotations else None
|
||||
|
||||
|
||||
|
|
@ -198,8 +207,8 @@ def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction['image_id'], [])
pred_metadata_map[prediction['image_id']].append(prediction)
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)

return pred_metadata_map
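# Illustrative sketch, not part of the diff: grouping COCO-style predictions by image_id with the
# same setdefault pattern used above. The sample predictions are made up.
predictions = [
    {"image_id": 1, "bbox": [10, 10, 50, 50], "score": 0.9, "category_id": 0},
    {"image_id": 1, "bbox": [20, 20, 30, 30], "score": 0.4, "category_id": 2},
    {"image_id": 2, "bbox": [5, 5, 15, 15], "score": 0.7, "category_id": 1},
]

metadata_map = {}
for pred in predictions:
    metadata_map.setdefault(pred["image_id"], [])
    metadata_map[pred["image_id"]].append(pred)

assert len(metadata_map[1]) == 2 and len(metadata_map[2]) == 1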
|
||||
|
||||
|
|
@ -207,7 +216,7 @@ def _create_prediction_metadata_map(model_predictions):
|
|||
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
|
||||
"""Log the confusion matrix to Comet experiment."""
|
||||
conf_mat = trainer.validator.confusion_matrix.matrix
|
||||
names = list(trainer.data['names'].values()) + ['background']
|
||||
names = list(trainer.data["names"].values()) + ["background"]
|
||||
experiment.log_confusion_matrix(
|
||||
matrix=conf_mat,
|
||||
labels=names,
|
||||
|
|
@ -251,7 +260,7 @@ def _log_image_predictions(experiment, validator, curr_step):
|
|||
if (batch_idx + 1) % batch_logging_interval != 0:
|
||||
continue
|
||||
|
||||
image_paths = batch['im_file']
|
||||
image_paths = batch["im_file"]
|
||||
for img_idx, image_path in enumerate(image_paths):
|
||||
if _comet_image_prediction_count >= max_image_predictions:
|
||||
return
|
||||
|
|
@ -275,10 +284,10 @@ def _log_image_predictions(experiment, validator, curr_step):
|
|||
|
||||
def _log_plots(experiment, trainer):
|
||||
"""Logs evaluation plots and label plots for the experiment."""
|
||||
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
|
||||
plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
|
||||
_log_images(experiment, plot_filenames, None)
|
||||
|
||||
label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
|
||||
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
|
||||
_log_images(experiment, label_plot_filenames, None)
|
||||
|
||||
|
||||
|
|
@ -288,7 +297,7 @@ def _log_model(experiment, trainer):
|
|||
experiment.log_model(
|
||||
model_name,
|
||||
file_or_folder=str(trainer.best),
|
||||
file_name='best.pt',
|
||||
file_name="best.pt",
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
|
@ -296,7 +305,7 @@ def _log_model(experiment, trainer):
|
|||
def on_pretrain_routine_start(trainer):
|
||||
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
is_alive = getattr(experiment, 'alive', False)
|
||||
is_alive = getattr(experiment, "alive", False)
|
||||
if not experiment or not is_alive:
|
||||
_create_experiment(trainer.args)
|
||||
|
||||
|
|
@ -308,17 +317,17 @@ def on_train_epoch_end(trainer):
|
|||
return
|
||||
|
||||
metadata = _fetch_trainer_metadata(trainer)
|
||||
curr_epoch = metadata['curr_epoch']
|
||||
curr_step = metadata['curr_step']
|
||||
curr_epoch = metadata["curr_epoch"]
|
||||
curr_step = metadata["curr_step"]
|
||||
|
||||
experiment.log_metrics(
|
||||
trainer.label_loss_items(trainer.tloss, prefix='train'),
|
||||
trainer.label_loss_items(trainer.tloss, prefix="train"),
|
||||
step=curr_step,
|
||||
epoch=curr_epoch,
|
||||
)
|
||||
|
||||
if curr_epoch == 1:
|
||||
_log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
|
||||
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
|
|
@ -328,14 +337,15 @@ def on_fit_epoch_end(trainer):
|
|||
return
|
||||
|
||||
metadata = _fetch_trainer_metadata(trainer)
|
||||
curr_epoch = metadata['curr_epoch']
|
||||
curr_step = metadata['curr_step']
|
||||
save_assets = metadata['save_assets']
|
||||
curr_epoch = metadata["curr_epoch"]
|
||||
curr_step = metadata["curr_step"]
|
||||
save_assets = metadata["save_assets"]
|
||||
|
||||
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
|
||||
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
|
||||
if curr_epoch == 1:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
|
||||
|
||||
if not save_assets:
|
||||
|
|
@ -355,8 +365,8 @@ def on_train_end(trainer):
|
|||
return
|
||||
|
||||
metadata = _fetch_trainer_metadata(trainer)
|
||||
curr_epoch = metadata['curr_epoch']
|
||||
curr_step = metadata['curr_step']
|
||||
curr_epoch = metadata["curr_epoch"]
|
||||
curr_step = metadata["curr_step"]
|
||||
plots = trainer.args.plots
|
||||
|
||||
_log_model(experiment, trainer)
|
||||
|
|
@ -371,8 +381,13 @@ def on_train_end(trainer):
|
|||
_comet_image_prediction_count = 0
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if comet_ml else {}
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_train_epoch_end": on_train_epoch_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_end": on_train_end,
|
||||
}
|
||||
if comet_ml
|
||||
else {}
|
||||
)
@ -4,9 +4,10 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
|
|||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['dvc'] is True # verify integration is enabled
|
||||
assert SETTINGS["dvc"] is True # verify integration is enabled
|
||||
import dvclive
|
||||
assert checks.check_version('dvclive', '2.11.0', verbose=True)
|
||||
|
||||
assert checks.check_version("dvclive", "2.11.0", verbose=True)
|
||||
|
||||
import os
|
||||
import re
|
||||
|
|
@ -24,24 +25,24 @@ except (ImportError, AssertionError, TypeError):
dvclive = None


def _log_images(path, prefix=''):
def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name

# Group images by batch to enable sliders in UI
if m := re.search(r'_batch(\d+)', name):
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)

live.log_image(os.path.join(prefix, name), path)


def _log_plots(plots, prefix=''):
def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
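# Illustrative sketch, not part of the diff: how the "_batch" regex in _log_images above turns
# per-batch image names into a shared stem plus index, which DVCLive can render as a slider.
# The filenames below are examples.
import re
from pathlib import Path

for fname in ["val_batch0_pred.jpg", "val_batch1_pred.jpg", "train_batch2.jpg"]:
    path = Path(fname)
    name = path.name
    if m := re.search(r"_batch(\d+)", name):
        ni = m[1]
        new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
        name = (Path(new_stem) / ni).with_suffix(path.suffix)
    print(name)  # e.g. val_batch_pred/0.jpg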
|
||||
|
|
@ -53,15 +54,15 @@ def _log_confusion_matrix(validator):
|
|||
preds = []
|
||||
matrix = validator.confusion_matrix.matrix
|
||||
names = list(validator.names.values())
|
||||
if validator.confusion_matrix.task == 'detect':
|
||||
names += ['background']
|
||||
if validator.confusion_matrix.task == "detect":
|
||||
names += ["background"]
|
||||
|
||||
for ti, pred in enumerate(matrix.T.astype(int)):
|
||||
for pi, num in enumerate(pred):
|
||||
targets.extend([names[ti]] * num)
|
||||
preds.extend([names[pi]] * num)
|
||||
|
||||
live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
|
||||
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
|
|
@ -71,12 +72,12 @@ def on_pretrain_routine_start(trainer):
|
|||
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
|
||||
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs plots related to the training process at the end of the pretraining routine."""
|
||||
_log_plots(trainer.plots, 'train')
|
||||
_log_plots(trainer.plots, "train")
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
|
|
@ -95,17 +96,18 @@ def on_fit_epoch_end(trainer):
|
|||
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
|
||||
global _training_epoch
|
||||
if live and _training_epoch:
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
|
||||
for metric, value in all_metrics.items():
|
||||
live.log_metric(metric, value)
|
||||
|
||||
if trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
for metric, value in model_info_for_loggers(trainer).items():
|
||||
live.log_metric(metric, value, plot=False)
|
||||
|
||||
_log_plots(trainer.plots, 'train')
|
||||
_log_plots(trainer.validator.plots, 'val')
|
||||
_log_plots(trainer.plots, "train")
|
||||
_log_plots(trainer.validator.plots, "val")
|
||||
|
||||
live.next_step()
|
||||
_training_epoch = False
|
||||
|
|
@ -115,24 +117,29 @@ def on_train_end(trainer):
|
|||
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
|
||||
if live:
|
||||
# At the end log the best metrics. It runs validator on the best model internally.
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
|
||||
for metric, value in all_metrics.items():
|
||||
live.log_metric(metric, value, plot=False)
|
||||
|
||||
_log_plots(trainer.plots, 'val')
|
||||
_log_plots(trainer.validator.plots, 'val')
|
||||
_log_plots(trainer.plots, "val")
|
||||
_log_plots(trainer.validator.plots, "val")
|
||||
_log_confusion_matrix(trainer.validator)
|
||||
|
||||
if trainer.best.exists():
|
||||
live.log_artifact(trainer.best, copy=True, type='model')
|
||||
live.log_artifact(trainer.best, copy=True, type="model")
|
||||
|
||||
live.end()
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_train_start': on_train_start,
|
||||
'on_train_epoch_start': on_train_epoch_start,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if dvclive else {}
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_pretrain_routine_end": on_pretrain_routine_end,
|
||||
"on_train_start": on_train_start,
|
||||
"on_train_epoch_start": on_train_epoch_start,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_end": on_train_end,
|
||||
}
|
||||
if dvclive
|
||||
else {}
|
||||
)
@ -11,60 +11,62 @@ from ultralytics.utils import LOGGER, SETTINGS
|
|||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs info before starting timer for upload rate limit."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
session = getattr(trainer, "hub_session", None)
|
||||
if session:
|
||||
# Start timer for upload rate limit
|
||||
session.timers = {
|
||||
'metrics': time(),
|
||||
'ckpt': time(), } # start timer on session.rate_limit
|
||||
"metrics": time(),
|
||||
"ckpt": time(),
|
||||
} # start timer on session.rate_limit
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix='train'),
**trainer.metrics, }
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers

all_plots = {**all_plots, **model_info_for_loggers(trainer)}

session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers['metrics'] = time() # reset timer
session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
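# Illustrative sketch, not part of the diff: the upload rate limiting above reduces to a simple
# elapsed-time check. The 30-second limit and the print stand-in for session.upload_metrics()
# are made up for demonstration.
from time import time

timers = {"metrics": time()}
rate_limits = {"metrics": 30.0}  # seconds between uploads (example value)
queue = {}

def maybe_upload(epoch, payload):
    queue[epoch] = payload
    if time() - timers["metrics"] > rate_limits["metrics"]:
        print(f"uploading {len(queue)} queued payload(s)")  # stand-in for the actual upload call
        timers["metrics"] = time()  # reset timer
        queue.clear()  # reset queue

maybe_upload(0, {"loss": 1.0})  # only queued; not enough time has elapsed yet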
|
||||
|
||||
|
||||
def on_model_save(trainer):
|
||||
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
session = getattr(trainer, "hub_session", None)
|
||||
if session:
|
||||
# Upload checkpoints with rate limiting
|
||||
is_best = trainer.best_fitness == trainer.fitness
|
||||
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
|
||||
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}')
|
||||
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
|
||||
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}")
|
||||
session.upload_model(trainer.epoch, trainer.last, is_best)
|
||||
session.timers['ckpt'] = time() # reset timer
|
||||
session.timers["ckpt"] = time() # reset timer
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
session = getattr(trainer, "hub_session", None)
|
||||
if session:
|
||||
# Upload final model and metrics with exponential standoff
|
||||
LOGGER.info(f'{PREFIX}Syncing final model...')
|
||||
LOGGER.info(f"{PREFIX}Syncing final model...")
|
||||
session.upload_model(
|
||||
trainer.epoch,
|
||||
trainer.best,
|
||||
map=trainer.metrics.get('metrics/mAP50-95(B)', 0),
|
||||
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
|
||||
final=True,
|
||||
)
|
||||
session.alive = False # stop heartbeats
|
||||
LOGGER.info(f'{PREFIX}Done ✅\n'
|
||||
f'{PREFIX}View model at {session.model_url} 🚀')
|
||||
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
|
|
@ -87,12 +89,17 @@ def on_export_start(exporter):
|
|||
events(exporter.args)
|
||||
|
||||
|
||||
callbacks = ({
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_model_save': on_model_save,
|
||||
'on_train_end': on_train_end,
|
||||
'on_train_start': on_train_start,
|
||||
'on_val_start': on_val_start,
|
||||
'on_predict_start': on_predict_start,
|
||||
'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {}) # verify enabled
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_end": on_pretrain_routine_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_model_save": on_model_save,
|
||||
"on_train_end": on_train_end,
|
||||
"on_train_start": on_train_start,
|
||||
"on_val_start": on_val_start,
|
||||
"on_predict_start": on_predict_start,
|
||||
"on_export_start": on_export_start,
|
||||
}
|
||||
if SETTINGS["hub"] is True
|
||||
else {}
|
||||
) # verify enabled
@ -26,15 +26,15 @@ from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorst
try:
import os

assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '') # do not log pytest
assert SETTINGS['mlflow'] is True # verify integration is enabled
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow

assert hasattr(mlflow, '__version__') # verify package is not directory
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path

PREFIX = colorstr('MLflow: ')
SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()}
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

except (ImportError, AssertionError):
mlflow = None
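# Illustrative sketch, not part of the diff: what the SANITIZE lambda above does to metric keys,
# presumably because parentheses are not accepted in MLflow metric names. The sample metrics are made up.
metrics = {"metrics/mAP50(B)": 0.52, "metrics/precision(B)": 0.68}
sanitize = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
print(sanitize(metrics))  # {'metrics/mAP50B': 0.52, 'metrics/precisionB': 0.68}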
|
||||
|
|
@ -61,33 +61,33 @@ def on_pretrain_routine_end(trainer):
|
|||
"""
|
||||
global mlflow
|
||||
|
||||
uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow')
|
||||
LOGGER.debug(f'{PREFIX} tracking uri: {uri}')
|
||||
uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
|
||||
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
|
||||
mlflow.set_tracking_uri(uri)
|
||||
|
||||
# Set experiment and run names
|
||||
experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
|
||||
run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
|
||||
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
|
||||
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
|
||||
mlflow.set_experiment(experiment_name)
|
||||
|
||||
mlflow.autolog()
|
||||
try:
|
||||
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
|
||||
LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}')
|
||||
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
|
||||
if Path(uri).is_dir():
|
||||
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
|
||||
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
|
||||
mlflow.log_params(dict(trainer.args))
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n'
|
||||
f'{PREFIX}WARNING ⚠️ Not tracking this run')
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log training metrics at the end of each train epoch to MLflow."""
|
||||
if mlflow:
|
||||
mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')),
|
||||
step=trainer.epoch)
|
||||
mlflow.log_metrics(
|
||||
metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), step=trainer.epoch
|
||||
)
|
||||
mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch)
|
||||
|
||||
|
||||
|
|
@ -101,16 +101,23 @@ def on_train_end(trainer):
|
|||
"""Log model artifacts at the end of the training."""
|
||||
if mlflow:
|
||||
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
|
||||
for f in trainer.save_dir.glob('*'): # log all other files in save_dir
|
||||
if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}:
|
||||
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
|
||||
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
|
||||
mlflow.log_artifact(str(f))
|
||||
|
||||
mlflow.end_run()
|
||||
LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n'
|
||||
f"{PREFIX}disable with 'yolo settings mlflow=False'")
|
||||
LOGGER.info(
|
||||
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
|
||||
f"{PREFIX}disable with 'yolo settings mlflow=False'"
|
||||
)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if mlflow else {}
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_end": on_pretrain_routine_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_end": on_train_end,
|
||||
}
|
||||
if mlflow
|
||||
else {}
|
||||
)
@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
|
|||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['neptune'] is True # verify integration is enabled
|
||||
assert SETTINGS["neptune"] is True # verify integration is enabled
|
||||
import neptune
|
||||
from neptune.types import File
|
||||
|
||||
assert hasattr(neptune, '__version__')
|
||||
assert hasattr(neptune, "__version__")
|
||||
|
||||
run = None # NeptuneAI experiment logger instance
|
||||
|
||||
|
|
@ -23,11 +23,11 @@ def _log_scalars(scalars, step=0):
|
|||
run[k].append(value=v, step=step)
|
||||
|
||||
|
||||
def _log_images(imgs_dict, group=''):
|
||||
def _log_images(imgs_dict, group=""):
|
||||
"""Log scalars to the NeptuneAI experiment logger."""
|
||||
if run:
|
||||
for k, v in imgs_dict.items():
|
||||
run[f'{group}/{k}'].upload(File(v))
|
||||
run[f"{group}/{k}"].upload(File(v))
|
||||
|
||||
|
||||
def _log_plot(title, plot_path):
|
||||
|
|
@ -43,34 +43,35 @@ def _log_plot(title, plot_path):
|
|||
|
||||
img = mpimg.imread(plot_path)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
|
||||
ax.imshow(img)
|
||||
run[f'Plots/{title}'].upload(fig)
|
||||
run[f"Plots/{title}"].upload(fig)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Callback function called before the training routine starts."""
|
||||
try:
|
||||
global run
|
||||
run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
|
||||
run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
|
||||
run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
|
||||
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Callback function called at end of each training epoch."""
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
|
||||
_log_scalars(trainer.lr, trainer.epoch + 1)
|
||||
if trainer.epoch == 1:
|
||||
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
|
||||
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Callback function called at end of each fit (train+val) epoch."""
|
||||
if run and trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
run['Configuration/Model'] = model_info_for_loggers(trainer)
|
||||
|
||||
run["Configuration/Model"] = model_info_for_loggers(trainer)
|
||||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||
|
||||
|
||||
|
|
@ -78,7 +79,7 @@ def on_val_end(validator):
|
|||
"""Callback function called at end of each validation."""
|
||||
if run:
|
||||
# Log val_labels and val_pred
|
||||
_log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
|
||||
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
|
|
@ -86,19 +87,28 @@ def on_train_end(trainer):
|
|||
if run:
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
"results.png",
|
||||
"confusion_matrix.png",
|
||||
"confusion_matrix_normalized.png",
|
||||
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
|
||||
]
|
||||
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
||||
for f in files:
|
||||
_log_plot(title=f.stem, plot_path=f)
|
||||
# Log the final model
|
||||
run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
|
||||
trainer.best)))
|
||||
run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload(
|
||||
File(str(trainer.best))
|
||||
)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_val_end': on_val_end,
|
||||
'on_train_end': on_train_end} if neptune else {}
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_train_epoch_end": on_train_epoch_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_val_end": on_val_end,
|
||||
"on_train_end": on_train_end,
|
||||
}
|
||||
if neptune
|
||||
else {}
|
||||
)
@ -3,7 +3,7 @@
from ultralytics.utils import SETTINGS

try:
assert SETTINGS['raytune'] is True # verify integration is enabled
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session

@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics['epoch'] = trainer.epoch
metrics["epoch"] = trainer.epoch
session.report(metrics)


callbacks = {
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)
@ -7,7 +7,7 @@ try:
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['tensorboard'] is True # verify integration is enabled
|
||||
assert SETTINGS["tensorboard"] is True # verify integration is enabled
|
||||
WRITER = None # TensorBoard SummaryWriter instance
|
||||
|
||||
except (ImportError, AssertionError, TypeError):
|
||||
|
|
@ -34,10 +34,10 @@ def _log_tensorboard_graph(trainer):
|
|||
p = next(trainer.model.parameters()) # for device, type
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
|
||||
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
|
||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
|
|
@ -46,10 +46,10 @@ def on_pretrain_routine_start(trainer):
|
|||
try:
|
||||
global WRITER
|
||||
WRITER = SummaryWriter(str(trainer.save_dir))
|
||||
prefix = colorstr('TensorBoard: ')
|
||||
prefix = colorstr("TensorBoard: ")
|
||||
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
|
|
@ -60,7 +60,7 @@ def on_train_start(trainer):
|
|||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Logs scalar statistics at the end of a training epoch."""
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
|
||||
_log_scalars(trainer.lr, trainer.epoch + 1)
|
||||
|
||||
|
||||
|
|
@ -69,8 +69,13 @@ def on_fit_epoch_end(trainer):
|
|||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_start': on_train_start,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {}
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_train_start": on_train_start,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_epoch_end": on_train_epoch_end,
|
||||
}
|
||||
if SummaryWriter
|
||||
else {}
|
||||
)
@ -5,10 +5,10 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
|
|||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['wandb'] is True # verify integration is enabled
|
||||
assert SETTINGS["wandb"] is True # verify integration is enabled
|
||||
import wandb as wb
|
||||
|
||||
assert hasattr(wb, '__version__') # verify package is not directory
|
||||
assert hasattr(wb, "__version__") # verify package is not directory
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
|
@ -19,7 +19,7 @@ except (ImportError, AssertionError):
|
|||
wb = None
|
||||
|
||||
|
||||
def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'):
|
||||
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
|
||||
"""
|
||||
Create and log a custom metric visualization to wandb.plot.pr_curve.
|
||||
|
||||
|
|
@ -37,24 +37,25 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall
|
|||
Returns:
|
||||
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
|
||||
"""
|
||||
df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3)
|
||||
fields = {'x': 'x', 'y': 'y', 'class': 'class'}
|
||||
string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title}
|
||||
return wb.plot_table('wandb/area-under-curve/v0',
|
||||
wb.Table(dataframe=df),
|
||||
fields=fields,
|
||||
string_fields=string_fields)
|
||||
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
||||
fields = {"x": "x", "y": "y", "class": "class"}
|
||||
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
|
||||
return wb.plot_table(
|
||||
"wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
|
||||
)
|
||||
|
||||
|
||||
def _plot_curve(x,
|
||||
y,
|
||||
names=None,
|
||||
id='precision-recall',
|
||||
title='Precision Recall Curve',
|
||||
x_title='Recall',
|
||||
y_title='Precision',
|
||||
num_x=100,
|
||||
only_mean=False):
|
||||
def _plot_curve(
|
||||
x,
|
||||
y,
|
||||
names=None,
|
||||
id="precision-recall",
|
||||
title="Precision Recall Curve",
|
||||
x_title="Recall",
|
||||
y_title="Precision",
|
||||
num_x=100,
|
||||
only_mean=False,
|
||||
):
|
||||
"""
|
||||
Log a metric curve visualization.
|
||||
|
||||
|
|
@ -88,7 +89,7 @@ def _plot_curve(x,
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
classes = ['mean'] * len(x_log)
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
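# Illustrative sketch, not part of the diff: resampling a per-class curve onto a fixed x grid with
# np.interp, as the loop above does before logging to W&B. The toy curve values are made up.
import numpy as np

x = np.array([0.0, 0.5, 1.0])          # original recall values
y = np.array([1.0, 0.8, 0.2])          # original precision values
x_new = np.linspace(0.0, 1.0, num=5)   # fixed grid (the source uses num_x=100)
y_new = np.interp(x_new, x, y)         # piecewise-linear interpolation onto x_new
print(np.round(y_new, 2))              # [1.   0.9  0.8  0.5  0.2]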
|
||||
|
|
@ -99,7 +100,7 @@ def _plot_curve(x,
|
|||
def _log_plots(plots, step):
|
||||
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
|
||||
for name, params in plots.items():
|
||||
timestamp = params['timestamp']
|
||||
timestamp = params["timestamp"]
|
||||
if _processed_plots.get(name) != timestamp:
|
||||
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
|
||||
_processed_plots[name] = timestamp
|
||||
|
|
@ -107,7 +108,7 @@ def _log_plots(plots, step):
|
|||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Initiate and start project if module is present."""
|
||||
wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
|
||||
wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
|
|
@ -121,7 +122,7 @@ def on_fit_epoch_end(trainer):
|
|||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log metrics and save images at the end of each training epoch."""
|
||||
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
|
||||
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
|
||||
wb.run.log(trainer.lr, step=trainer.epoch + 1)
|
||||
if trainer.epoch == 1:
|
||||
_log_plots(trainer.plots, step=trainer.epoch + 1)
|
||||
|
|
@ -131,17 +132,17 @@ def on_train_end(trainer):
|
|||
"""Save the best model as an artifact at end of training."""
|
||||
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
|
||||
_log_plots(trainer.plots, step=trainer.epoch + 1)
|
||||
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
|
||||
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
|
||||
if trainer.best.exists():
|
||||
art.add_file(trainer.best)
|
||||
wb.run.log_artifact(art, aliases=['best'])
|
||||
wb.run.log_artifact(art, aliases=["best"])
|
||||
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
|
||||
x, y, x_title, y_title = curve_values
|
||||
_plot_curve(
|
||||
x,
|
||||
y,
|
||||
names=list(trainer.validator.metrics.names.values()),
|
||||
id=f'curves/{curve_name}',
|
||||
id=f"curves/{curve_name}",
|
||||
title=curve_name,
|
||||
x_title=x_title,
|
||||
y_title=y_title,
|
||||
|
|
@ -149,8 +150,13 @@ def on_train_end(trainer):
|
|||
wb.run.finish() # required or run continues on dashboard
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if wb else {}
|
||||
callbacks = (
|
||||
{
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_train_epoch_end": on_train_epoch_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_end": on_train_end,
|
||||
}
|
||||
if wb
|
||||
else {}
|
||||
)
@ -21,12 +21,33 @@ import requests
|
|||
import torch
|
||||
from matplotlib import font_manager
|
||||
|
||||
from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, SimpleNamespace,
|
||||
ThreadingLocked, TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker,
|
||||
is_github_action_running, is_jupyter, is_kaggle, is_online, is_pip_package, url2file)
|
||||
from ultralytics.utils import (
|
||||
ASSETS,
|
||||
AUTOINSTALL,
|
||||
LINUX,
|
||||
LOGGER,
|
||||
ONLINE,
|
||||
ROOT,
|
||||
USER_CONFIG_DIR,
|
||||
SimpleNamespace,
|
||||
ThreadingLocked,
|
||||
TryExcept,
|
||||
clean_url,
|
||||
colorstr,
|
||||
downloads,
|
||||
emojis,
|
||||
is_colab,
|
||||
is_docker,
|
||||
is_github_action_running,
|
||||
is_jupyter,
|
||||
is_kaggle,
|
||||
is_online,
|
||||
is_pip_package,
|
||||
url2file,
|
||||
)
|
||||
|
||||
|
||||
def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
|
||||
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
||||
"""
|
||||
Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
|
||||
|
||||
|
|
@ -46,23 +67,23 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
"""

if package:
requires = [x for x in metadata.distribution(package).requires if 'extra == ' not in x]
requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
else:
requires = Path(file_path).read_text().splitlines()

requirements = []
for line in requires:
line = line.strip()
if line and not line.startswith('#'):
line = line.split('#')[0].strip() # ignore inline comments
match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line)
if line and not line.startswith("#"):
line = line.split("#")[0].strip() # ignore inline comments
match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
if match:
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ''))
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

return requirements
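# Illustrative sketch, not part of the diff: how the requirement-parsing regex above splits a
# specifier line into a package name and a version constraint. The sample lines are made up.
import re
from types import SimpleNamespace

for line in ["numpy>=1.22.2  # comment", "opencv-python", "torch!=2.0.0,>=1.8"]:
    line = line.split("#")[0].strip()  # ignore inline comments
    match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
    if match:
        req = SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")
        print(req.name, repr(req.specifier))
# numpy '>=1.22.2'
# opencv-python ''
# torch '!=2.0.0,>=1.8'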
|
||||
|
||||
|
||||
def parse_version(version='0.0.0') -> tuple:
|
||||
def parse_version(version="0.0.0") -> tuple:
|
||||
"""
|
||||
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
|
||||
function replaces deprecated 'pkg_resources.parse_version(v)'.
|
||||
|
|
@ -74,9 +95,9 @@ def parse_version(version='0.0.0') -> tuple:
(tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1)
"""
try:
return tuple(map(int, re.findall(r'\d+', version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}')
LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
return 0, 0, 0
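# Illustrative sketch, not part of the diff: the regex above keeps only the numeric components of a
# version string, so local/build suffixes are ignored. Sample version strings are examples.
import re

for version in ["2.0.1+cpu", "1.13.0a0", "22.04"]:
    print(tuple(map(int, re.findall(r"\d+", version)[:3])))
# (2, 0, 1)
# (1, 13, 0)
# (22, 4)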
|
||||
|
||||
|
||||
|
|
@ -121,15 +142,19 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
|||
elif isinstance(imgsz, (list, tuple)):
|
||||
imgsz = list(imgsz)
|
||||
else:
|
||||
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
|
||||
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
|
||||
raise TypeError(
|
||||
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
|
||||
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
|
||||
)
|
||||
|
||||
# Apply max_dim
|
||||
if len(imgsz) > max_dim:
|
||||
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
|
||||
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
|
||||
msg = (
|
||||
"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
|
||||
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
|
||||
)
|
||||
if max_dim != 1:
|
||||
raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
|
||||
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
|
||||
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
|
||||
imgsz = [max(imgsz)]
|
||||
# Make image size a multiple of the stride
|
||||
|
|
@ -137,7 +162,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
|||
|
||||
# Print warning message if image size was updated
|
||||
if sz != imgsz:
|
||||
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||
LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
|
||||
|
||||
# Add missing dimensions if necessary
|
||||
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
|
||||
|
|
@ -145,12 +170,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
return sz
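# Illustrative sketch, not part of the diff: the stride-divisibility step referenced by the warning
# above, assuming the usual round-up-to-the-nearest-stride behaviour (the exact line is outside this hunk).
import math

imgsz, stride, floor = [638, 478], 32, 0
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
print(sz)  # [640, 480]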
|
||||
|
||||
|
||||
def check_version(current: str = '0.0.0',
|
||||
required: str = '0.0.0',
|
||||
name: str = 'version',
|
||||
hard: bool = False,
|
||||
verbose: bool = False,
|
||||
msg: str = '') -> bool:
|
||||
def check_version(
|
||||
current: str = "0.0.0",
|
||||
required: str = "0.0.0",
|
||||
name: str = "version",
|
||||
hard: bool = False,
|
||||
verbose: bool = False,
|
||||
msg: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
Check current version against the required version or range.
|
||||
|
||||
|
|
@ -181,7 +208,7 @@ def check_version(current: str = '0.0.0',
|
|||
```
|
||||
"""
|
||||
if not current: # if current is '' or None
|
||||
LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.')
|
||||
LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
|
||||
return True
|
||||
elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
|
||||
try:
|
||||
|
|
@ -189,34 +216,34 @@ def check_version(current: str = '0.0.0',
|
|||
current = metadata.version(current) # get version string from package name
|
||||
except metadata.PackageNotFoundError:
|
||||
if hard:
|
||||
raise ModuleNotFoundError(emojis(f'WARNING ⚠️ {current} package is required but not installed'))
|
||||
raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed"))
|
||||
else:
|
||||
return False
|
||||
|
||||
if not required: # if required is '' or None
|
||||
return True
|
||||
|
||||
op = ''
|
||||
version = ''
|
||||
op = ""
|
||||
version = ""
|
||||
result = True
|
||||
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
|
||||
for r in required.strip(',').split(','):
|
||||
op, version = re.match(r'([^0-9]*)([\d.]+)', r).groups() # split '>=22.04' -> ('>=', '22.04')
|
||||
for r in required.strip(",").split(","):
|
||||
op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
|
||||
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
|
||||
if op == '==' and c != v:
|
||||
if op == "==" and c != v:
|
||||
result = False
|
||||
elif op == '!=' and c == v:
|
||||
elif op == "!=" and c == v:
|
||||
result = False
|
||||
elif op in ('>=', '') and not (c >= v): # if no constraint passed assume '>=required'
|
||||
elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
|
||||
result = False
|
||||
elif op == '<=' and not (c <= v):
|
||||
elif op == "<=" and not (c <= v):
|
||||
result = False
|
||||
elif op == '>' and not (c > v):
|
||||
elif op == ">" and not (c > v):
|
||||
result = False
|
||||
elif op == '<' and not (c < v):
|
||||
elif op == "<" and not (c < v):
|
||||
result = False
|
||||
if not result:
|
||||
warning = f'WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}'
|
||||
warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
|
||||
if hard:
|
||||
raise ModuleNotFoundError(emojis(warning)) # assert version requirements met
|
||||
if verbose:
|
||||
|
|
@ -224,7 +251,7 @@ def check_version(current: str = '0.0.0',
return result
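# Illustrative sketch, not part of the diff: the constraint loop in check_version above reduced to a
# standalone helper. Operators and comma-separated constraints follow the same convention as the source;
# the helper and its name are made up.
import re

def parse_version(version):
    return tuple(map(int, re.findall(r"\d+", version)[:3]))

def satisfies(current, required):
    c = parse_version(current)
    for r in required.strip(",").split(","):
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
        v = parse_version(version)
        if op == "==" and c != v:
            return False
        if op == "!=" and c == v:
            return False
        if op in (">=", "") and not c >= v:  # no operator means '>=required'
            return False
        if op == "<=" and not c <= v:
            return False
        if op == ">" and not c > v:
            return False
        if op == "<" and not c < v:
            return False
    return True

print(satisfies("8.0.239", ">=8.0.0,<9.0.0"))  # True
print(satisfies("2.0.1+cpu", "==2.0.0"))       # False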
|
||||
|
||||
|
||||
def check_latest_pypi_version(package_name='ultralytics'):
|
||||
def check_latest_pypi_version(package_name="ultralytics"):
|
||||
"""
|
||||
Returns the latest version of a PyPI package without downloading or installing it.
|
||||
|
||||
|
|
@ -236,9 +263,9 @@ def check_latest_pypi_version(package_name='ultralytics'):
"""
with contextlib.suppress(Exception):
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()['info']['version']
return response.json()["info"]["version"]
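# Illustrative sketch, not part of the diff: querying the public PyPI JSON API as above. This needs
# network access, so failures are suppressed just like in the source.
import contextlib
import requests

with contextlib.suppress(Exception):
    response = requests.get("https://pypi.org/pypi/ultralytics/json", timeout=3)
    if response.status_code == 200:
        print(response.json()["info"]["version"])  # latest published version string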
|
||||
|
||||
|
||||
def check_pip_update_available():
|
||||
|
|
@ -251,16 +278,19 @@ def check_pip_update_available():
|
|||
if ONLINE and is_pip_package():
|
||||
with contextlib.suppress(Exception):
|
||||
from ultralytics import __version__
|
||||
|
||||
latest = check_latest_pypi_version()
|
||||
if check_version(__version__, f'<{latest}'): # check if current version is < latest version
|
||||
LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
|
||||
f"Update with 'pip install -U ultralytics'")
|
||||
if check_version(__version__, f"<{latest}"): # check if current version is < latest version
|
||||
LOGGER.info(
|
||||
f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
|
||||
f"Update with 'pip install -U ultralytics'"
|
||||
)
|
||||
return True
|
||||
return False
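For reference, a hedged usage sketch of the PyPI JSON endpoint queried by check_latest_pypi_version() above (network access and the requests package assumed):

```python
import requests

# The pypi.org JSON API exposes the newest release under ['info']['version'].
response = requests.get("https://pypi.org/pypi/ultralytics/json", timeout=3)
if response.status_code == 200:
    print(response.json()["info"]["version"])  # e.g. '8.0.239' around the time of this commit
```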
|
||||
|
||||
|
||||
@ThreadingLocked()
|
||||
def check_font(font='Arial.ttf'):
|
||||
def check_font(font="Arial.ttf"):
|
||||
"""
|
||||
Find font locally or download to user's configuration directory if it does not already exist.
|
||||
|
||||
|
|
@ -283,13 +313,13 @@ def check_font(font='Arial.ttf'):
|
|||
return matches[0]
|
||||
|
||||
# Download to USER_CONFIG_DIR if missing
|
||||
url = f'https://ultralytics.com/assets/{name}'
|
||||
url = f"https://ultralytics.com/assets/{name}"
|
||||
if downloads.is_url(url):
|
||||
downloads.safe_download(url=url, file=file)
|
||||
return file
|
||||
|
||||
|
||||
def check_python(minimum: str = '3.8.0') -> bool:
|
||||
def check_python(minimum: str = "3.8.0") -> bool:
|
||||
"""
|
||||
Check current python version against the required minimum version.
|
||||
|
||||
|
|
@ -299,11 +329,11 @@ def check_python(minimum: str = '3.8.0') -> bool:
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
return check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
||||
return check_version(platform.python_version(), minimum, name="Python ", hard=True)
|
||||
|
||||
|
||||
@TryExcept()
|
||||
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
|
||||
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
|
||||
"""
|
||||
Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
|
||||
|
||||
|
|
@ -329,41 +359,42 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
|||
```
|
||||
"""
|
||||
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
prefix = colorstr("red", "bold", "requirements:")
|
||||
check_python() # check python version
|
||||
check_torchvision() # check torch-torchvision compatibility
|
||||
if isinstance(requirements, Path): # requirements.txt file
|
||||
file = requirements.resolve()
|
||||
assert file.exists(), f'{prefix} {file} not found, check failed.'
|
||||
requirements = [f'{x.name}{x.specifier}' for x in parse_requirements(file) if x.name not in exclude]
|
||||
assert file.exists(), f"{prefix} {file} not found, check failed."
|
||||
requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
|
||||
elif isinstance(requirements, str):
|
||||
requirements = [requirements]
|
||||
|
||||
pkgs = []
|
||||
for r in requirements:
|
||||
r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
|
||||
match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', r_stripped)
|
||||
name, required = match[1], match[2].strip() if match[2] else ''
|
||||
r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
|
||||
match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
|
||||
name, required = match[1], match[2].strip() if match[2] else ""
|
||||
try:
|
||||
assert check_version(metadata.version(name), required) # exception if requirements not met
|
||||
except (AssertionError, metadata.PackageNotFoundError):
|
||||
pkgs.append(r)
|
||||
|
||||
s = ' '.join(f'"{x}"' for x in pkgs) # console string
|
||||
s = " ".join(f'"{x}"' for x in pkgs) # console string
|
||||
if s:
|
||||
if install and AUTOINSTALL: # check environment variable
|
||||
n = len(pkgs) # number of packages updates
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
||||
try:
|
||||
t = time.time()
|
||||
assert is_online(), 'AutoUpdate skipped (offline)'
|
||||
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
|
||||
assert is_online(), "AutoUpdate skipped (offline)"
|
||||
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
|
||||
dt = time.time() - t
|
||||
LOGGER.info(
|
||||
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix} ❌ {e}')
|
||||
LOGGER.warning(f"{prefix} ❌ {e}")
|
||||
return False
|
||||
else:
|
||||
return False
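A small sketch of the requirement-string parsing used by check_requirements() above; the sample strings are illustrative only:

```python
import re

for r in ["torch>=1.8.0", "git+https://github.com/org/repo.git", "opencv-python"]:
    r_stripped = r.split("/")[-1].replace(".git", "")  # 'git+https://org/repo.git' -> 'repo'
    match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
    name, required = match[1], match[2].strip() if match[2] else ""
    print(f"{name!r} {required or '(any version)'}")
```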
|
||||
|
|
@ -386,76 +417,82 @@ def check_torchvision():
|
|||
import torchvision
|
||||
|
||||
# Compatibility table
|
||||
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
|
||||
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
|
||||
|
||||
# Extract only the major and minor versions
|
||||
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
|
||||
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
|
||||
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
|
||||
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
|
||||
|
||||
if v_torch in compatibility_table:
|
||||
compatible_versions = compatibility_table[v_torch]
|
||||
if all(v_torchvision != v for v in compatible_versions):
|
||||
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
|
||||
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
|
||||
"'pip install -U torch torchvision' to update both.\n"
|
||||
'For a full compatibility table see https://github.com/pytorch/vision#installation')
|
||||
print(
|
||||
f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
|
||||
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
|
||||
"'pip install -U torch torchvision' to update both.\n"
|
||||
"For a full compatibility table see https://github.com/pytorch/vision#installation"
|
||||
)
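A hedged sketch of the version matching behind check_torchvision() above, with illustrative version strings:

```python
# Keep only the major.minor part of each version and look it up in the table.
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}


def major_minor(version):
    return ".".join(version.split("+")[0].split(".")[:2])  # '2.0.1+cu118' -> '2.0'


v_torch, v_torchvision = major_minor("2.0.1+cu118"), major_minor("0.15.2+cu118")
print(v_torchvision in compatibility_table.get(v_torch, []))  # True -> compatible pair
```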
|
||||
|
||||
|
||||
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
|
||||
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
|
||||
"""Check file(s) for acceptable suffix."""
|
||||
if file and suffix:
|
||||
if isinstance(suffix, str):
|
||||
suffix = (suffix, )
|
||||
suffix = (suffix,)
|
||||
for f in file if isinstance(file, (list, tuple)) else [file]:
|
||||
s = Path(f).suffix.lower().strip() # file suffix
|
||||
if len(s):
|
||||
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
|
||||
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"
|
||||
|
||||
|
||||
def check_yolov5u_filename(file: str, verbose: bool = True):
|
||||
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
|
||||
if 'yolov3' in file or 'yolov5' in file:
|
||||
if 'u.yaml' in file:
|
||||
file = file.replace('u.yaml', '.yaml') # i.e. yolov5nu.yaml -> yolov5n.yaml
|
||||
elif '.pt' in file and 'u' not in file:
|
||||
if "yolov3" in file or "yolov5" in file:
|
||||
if "u.yaml" in file:
|
||||
file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
|
||||
elif ".pt" in file and "u" not in file:
|
||||
original_file = file
|
||||
file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt
|
||||
file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt
|
||||
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
|
||||
file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
|
||||
file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
|
||||
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
|
||||
if file != original_file and verbose:
|
||||
LOGGER.info(
|
||||
f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
|
||||
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
|
||||
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
|
||||
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
|
||||
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
|
||||
)
|
||||
return file
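The renaming rules above can be seen in isolation with a short sketch (filenames are examples):

```python
import re

for file in ["yolov5n.pt", "yolov5s6.pt", "yolov3-spp.pt"]:
    file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # yolov5n.pt -> yolov5nu.pt
    file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # yolov5n6.pt -> yolov5n6u.pt
    file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # yolov3-spp.pt -> yolov3-sppu.pt
    print(file)  # yolov5nu.pt, yolov5s6u.pt, yolov3-sppu.pt
```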
|
||||
|
||||
|
||||
def check_model_file_from_stem(model='yolov8n'):
|
||||
def check_model_file_from_stem(model="yolov8n"):
|
||||
"""Return a model filename from a valid model stem."""
|
||||
if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
|
||||
return Path(model).with_suffix('.pt') # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def check_file(file, suffix='', download=True, hard=True):
|
||||
def check_file(file, suffix="", download=True, hard=True):
|
||||
"""Search/download file (if necessary) and return path."""
|
||||
check_suffix(file, suffix) # optional
|
||||
file = str(file).strip() # convert to string and strip spaces
|
||||
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
|
||||
if (not file or ('://' not in file and Path(file).exists()) or # '://' check required in Windows Python<3.10
|
||||
file.lower().startswith('grpc://')): # file exists or gRPC Triton images
|
||||
if (
|
||||
not file
|
||||
or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
|
||||
or file.lower().startswith("grpc://")
|
||||
): # file exists or gRPC Triton images
|
||||
return file
|
||||
elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')): # download
|
||||
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
|
||||
url = file # warning: Pathlib turns :// -> :/
|
||||
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
|
||||
if Path(file).exists():
|
||||
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
|
||||
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
|
||||
else:
|
||||
downloads.safe_download(url=url, file=file, unzip=False)
|
||||
return file
|
||||
else: # search
|
||||
files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True) # find file
|
||||
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
|
||||
if not files and hard:
|
||||
raise FileNotFoundError(f"'{file}' does not exist")
|
||||
elif len(files) > 1 and hard:
|
||||
|
|
@ -463,7 +500,7 @@ def check_file(file, suffix='', download=True, hard=True):
|
|||
return files[0] if len(files) else [] # return file
|
||||
|
||||
|
||||
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
|
||||
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
|
||||
"""Search/download YAML file (if necessary) and return path, checking suffix."""
|
||||
return check_file(file, suffix, hard=hard)
|
||||
|
||||
|
|
@ -482,51 +519,52 @@ def check_is_path_safe(basedir, path):
|
|||
base_dir_resolved = Path(basedir).resolve()
|
||||
path_resolved = Path(path).resolve()
|
||||
|
||||
return path_resolved.is_file() and path_resolved.parts[:len(base_dir_resolved.parts)] == base_dir_resolved.parts
|
||||
return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
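A minimal sketch of the prefix comparison used by check_is_path_safe() above; the is_file() requirement is dropped here so the example runs on paths that need not exist:

```python
from pathlib import Path


def is_under(basedir, path):
    """True when `path` resolves to a location inside `basedir` (no '..' escape)."""
    base = Path(basedir).resolve()
    p = Path(path).resolve()
    return p.parts[: len(base.parts)] == base.parts


print(is_under("/tmp/datasets", "/tmp/datasets/coco8/images/1.jpg"))  # True
print(is_under("/tmp/datasets", "/tmp/datasets/../etc/passwd"))  # False
```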
|
||||
|
||||
|
||||
def check_imshow(warn=False):
|
||||
"""Check if environment supports image displays."""
|
||||
try:
|
||||
if LINUX:
|
||||
assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle()
|
||||
cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
|
||||
assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
|
||||
cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
|
||||
cv2.waitKey(1)
|
||||
cv2.destroyAllWindows()
|
||||
cv2.waitKey(1)
|
||||
return True
|
||||
except Exception as e:
|
||||
if warn:
|
||||
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_yolo(verbose=True, device=''):
|
||||
def check_yolo(verbose=True, device=""):
|
||||
"""Return a human-readable YOLO software and hardware summary."""
|
||||
import psutil
|
||||
|
||||
from ultralytics.utils.torch_utils import select_device
|
||||
|
||||
if is_jupyter():
|
||||
if check_requirements('wandb', install=False):
|
||||
os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with infinite hang
|
||||
if check_requirements("wandb", install=False):
|
||||
os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang
|
||||
if is_colab():
|
||||
shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
|
||||
shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
|
||||
|
||||
if verbose:
|
||||
# System info
|
||||
gib = 1 << 30 # bytes per GiB
|
||||
ram = psutil.virtual_memory().total
|
||||
total, used, free = shutil.disk_usage('/')
|
||||
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
|
||||
total, used, free = shutil.disk_usage("/")
|
||||
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
|
||||
with contextlib.suppress(Exception): # clear display if ipython is installed
|
||||
from IPython import display
|
||||
|
||||
display.clear_output()
|
||||
else:
|
||||
s = ''
|
||||
s = ""
|
||||
|
||||
select_device(device=device, newline=False)
|
||||
LOGGER.info(f'Setup complete ✅ {s}')
|
||||
LOGGER.info(f"Setup complete ✅ {s}")
|
||||
|
||||
|
||||
def collect_system_info():
|
||||
|
|
@ -537,32 +575,36 @@ def collect_system_info():
|
|||
from ultralytics.utils import ENVIRONMENT, is_git_dir
|
||||
from ultralytics.utils.torch_utils import get_cpu_info
|
||||
|
||||
ram_info = psutil.virtual_memory().total / (1024 ** 3) # Convert bytes to GB
|
||||
ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
|
||||
check_yolo()
|
||||
LOGGER.info(f"\n{'OS':<20}{platform.platform()}\n"
|
||||
f"{'Environment':<20}{ENVIRONMENT}\n"
|
||||
f"{'Python':<20}{sys.version.split()[0]}\n"
|
||||
f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
|
||||
f"{'RAM':<20}{ram_info:.2f} GB\n"
|
||||
f"{'CPU':<20}{get_cpu_info()}\n"
|
||||
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n")
|
||||
LOGGER.info(
|
||||
f"\n{'OS':<20}{platform.platform()}\n"
|
||||
f"{'Environment':<20}{ENVIRONMENT}\n"
|
||||
f"{'Python':<20}{sys.version.split()[0]}\n"
|
||||
f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
|
||||
f"{'RAM':<20}{ram_info:.2f} GB\n"
|
||||
f"{'CPU':<20}{get_cpu_info()}\n"
|
||||
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
|
||||
)
|
||||
|
||||
for r in parse_requirements(package='ultralytics'):
|
||||
for r in parse_requirements(package="ultralytics"):
|
||||
try:
|
||||
current = metadata.version(r.name)
|
||||
is_met = '✅ ' if check_version(current, str(r.specifier), hard=True) else '❌ '
|
||||
is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ "
|
||||
except metadata.PackageNotFoundError:
|
||||
current = '(not installed)'
|
||||
is_met = '❌ '
|
||||
LOGGER.info(f'{r.name:<20}{is_met}{current}{r.specifier}')
|
||||
current = "(not installed)"
|
||||
is_met = "❌ "
|
||||
LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")
|
||||
|
||||
if is_github_action_running():
|
||||
LOGGER.info(f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
|
||||
f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
|
||||
f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
|
||||
f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
|
||||
f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
|
||||
f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n")
|
||||
LOGGER.info(
|
||||
f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
|
||||
f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
|
||||
f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
|
||||
f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
|
||||
f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
|
||||
f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
|
||||
)
|
||||
|
||||
|
||||
def check_amp(model):
|
||||
|
|
@ -587,7 +629,7 @@ def check_amp(model):
|
|||
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
||||
"""
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type in ('cpu', 'mps'):
|
||||
if device.type in ("cpu", "mps"):
|
||||
return False # AMP only used on CUDA devices
|
||||
|
||||
def amp_allclose(m, im):
|
||||
|
|
@ -598,22 +640,27 @@ def check_amp(model):
|
|||
del m
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
|
||||
|
||||
im = ASSETS / 'bus.jpg' # image to check
|
||||
prefix = colorstr('AMP: ')
|
||||
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
|
||||
im = ASSETS / "bus.jpg" # image to check
|
||||
prefix = colorstr("AMP: ")
|
||||
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
|
||||
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
assert amp_allclose(YOLO('yolov8n.pt'), im)
|
||||
LOGGER.info(f'{prefix}checks passed ✅')
|
||||
|
||||
assert amp_allclose(YOLO("yolov8n.pt"), im)
|
||||
LOGGER.info(f"{prefix}checks passed ✅")
|
||||
except ConnectionError:
|
||||
LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
|
||||
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
|
||||
except (AttributeError, ModuleNotFoundError):
|
||||
LOGGER.warning(f'{prefix}checks skipped ⚠️. '
|
||||
f'Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
|
||||
LOGGER.warning(
|
||||
f"{prefix}checks skipped ⚠️. "
|
||||
f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
|
||||
)
|
||||
except AssertionError:
|
||||
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
|
||||
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
|
||||
LOGGER.warning(
|
||||
f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
|
||||
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
@ -621,8 +668,8 @@ def check_amp(model):
|
|||
def git_describe(path=ROOT): # path must be a directory
|
||||
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
|
||||
with contextlib.suppress(Exception):
|
||||
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
||||
return ''
|
||||
return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
|
||||
return ""
|
||||
|
||||
|
||||
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||
|
|
@ -630,7 +677,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
|||
|
||||
def strip_auth(v):
|
||||
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
|
||||
return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
|
||||
return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
|
||||
|
||||
x = inspect.currentframe().f_back # previous frame
|
||||
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||
|
|
@ -638,11 +685,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
|||
args, _, _, frm = inspect.getargvalues(x)
|
||||
args = {k: v for k, v in frm.items() if k in args}
|
||||
try:
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
|
||||
except ValueError:
|
||||
file = Path(file).stem
|
||||
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
||||
LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
|
||||
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
|
||||
LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
|
||||
|
||||
|
||||
def cuda_device_count() -> int:
|
||||
|
|
@ -654,11 +701,12 @@ def cuda_device_count() -> int:
|
|||
"""
|
||||
try:
|
||||
# Run the nvidia-smi command and capture its output
|
||||
output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
|
||||
encoding='utf-8')
|
||||
output = subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
|
||||
)
|
||||
|
||||
# Take the first line and strip any leading/trailing white space
|
||||
first_line = output.strip().split('\n')[0]
|
||||
first_line = output.strip().split("\n")[0]
|
||||
|
||||
return int(first_line)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
@ -18,13 +18,13 @@ def find_free_network_port() -> int:
|
|||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1] # port
|
||||
|
||||
|
||||
def generate_ddp_file(trainer):
|
||||
"""Generates a DDP file and returns its file name."""
|
||||
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
|
||||
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
|
||||
|
||||
content = f"""
|
||||
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
|
||||
|
|
@ -39,13 +39,15 @@ if __name__ == "__main__":
|
|||
trainer = {name}(cfg=cfg, overrides=overrides)
|
||||
results = trainer.train()
|
||||
"""
|
||||
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(prefix='_temp_',
|
||||
suffix=f'{id(trainer)}.py',
|
||||
mode='w+',
|
||||
encoding='utf-8',
|
||||
dir=USER_CONFIG_DIR / 'DDP',
|
||||
delete=False) as file:
|
||||
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
prefix="_temp_",
|
||||
suffix=f"{id(trainer)}.py",
|
||||
mode="w+",
|
||||
encoding="utf-8",
|
||||
dir=USER_CONFIG_DIR / "DDP",
|
||||
delete=False,
|
||||
) as file:
|
||||
file.write(content)
|
||||
return file.name
|
||||
|
||||
|
|
@ -53,16 +55,17 @@ if __name__ == "__main__":
|
|||
def generate_ddp_command(world_size, trainer):
|
||||
"""Generates and returns command for distributed training."""
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
file = generate_ddp_file(trainer)
|
||||
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
||||
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
||||
port = find_free_network_port()
|
||||
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
|
||||
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
|
||||
return cmd, file
|
||||
|
||||
|
||||
def ddp_cleanup(trainer, file):
|
||||
"""Delete temp file if created."""
|
||||
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
|
||||
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
|
||||
os.remove(file)
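The command list assembled by generate_ddp_command() above, shown standalone with illustrative values (the temp file path and port below are made up):

```python
import sys

world_size, port = 2, 51234  # example values
file = "/home/user/.config/Ultralytics/DDP/_temp_140211.py"  # hypothetical temp file
dist_cmd = "torch.distributed.run"  # used when torch>=1.9, else torch.distributed.launch
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
print(" ".join(cmd))
```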
@ -15,15 +15,17 @@ import torch
|
|||
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
|
||||
|
||||
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
|
||||
GITHUB_ASSETS_REPO = 'ultralytics/assets'
|
||||
GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose', '-obb')] + \
|
||||
[f'yolov5{k}{resolution}u.pt' for k in 'nsmlx' for resolution in ('', '6')] + \
|
||||
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
|
||||
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
|
||||
[f'sam_{k}.pt' for k in 'bl'] + \
|
||||
[f'FastSAM-{k}.pt' for k in 'sx'] + \
|
||||
[f'rtdetr-{k}.pt' for k in 'lx'] + \
|
||||
['mobile_sam.pt']
|
||||
GITHUB_ASSETS_REPO = "ultralytics/assets"
|
||||
GITHUB_ASSETS_NAMES = (
|
||||
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
|
||||
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
|
||||
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
|
||||
+ [f"yolo_nas_{k}.pt" for k in "sml"]
|
||||
+ [f"sam_{k}.pt" for k in "bl"]
|
||||
+ [f"FastSAM-{k}.pt" for k in "sx"]
|
||||
+ [f"rtdetr-{k}.pt" for k in "lx"]
|
||||
+ ["mobile_sam.pt"]
|
||||
)
|
||||
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
|
||||
|
||||
|
||||
|
|
@ -56,7 +58,7 @@ def is_url(url, check=True):
|
|||
return False
|
||||
|
||||
|
||||
def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
|
||||
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
|
||||
"""
|
||||
Deletes all ".DS_store" files under a specified directory.
|
||||
|
||||
|
|
@ -77,12 +79,12 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
|
|||
"""
|
||||
for file in files_to_delete:
|
||||
matches = list(Path(path).rglob(file))
|
||||
LOGGER.info(f'Deleting {file} files: {matches}')
|
||||
LOGGER.info(f"Deleting {file} files: {matches}")
|
||||
for f in matches:
|
||||
f.unlink()
|
||||
|
||||
|
||||
def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
|
||||
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
|
||||
"""
|
||||
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
|
||||
named after the directory and placed alongside it.
|
||||
|
|
@ -111,17 +113,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
|
|||
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
|
||||
|
||||
# Unzip with progress bar
|
||||
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
|
||||
zip_file = directory.with_suffix('.zip')
|
||||
files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
|
||||
zip_file = directory.with_suffix(".zip")
|
||||
compression = ZIP_DEFLATED if compress else ZIP_STORED
|
||||
with ZipFile(zip_file, 'w', compression) as f:
|
||||
for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
|
||||
with ZipFile(zip_file, "w", compression) as f:
|
||||
for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
|
||||
f.write(file, file.relative_to(directory))
|
||||
|
||||
return zip_file # return path to zip file
|
||||
|
||||
|
||||
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
|
||||
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
|
||||
"""
|
||||
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
|
||||
|
||||
|
|
@ -161,7 +163,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
|
|||
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
|
||||
top_level_dirs = {Path(f).parts[0] for f in files}
|
||||
|
||||
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith('/')):
|
||||
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
|
||||
# Zip has multiple files at top level
|
||||
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
|
||||
else:
|
||||
|
|
@ -172,20 +174,20 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
|
|||
# Check if destination directory already exists and contains files
|
||||
if path.exists() and any(path.iterdir()) and not exist_ok:
|
||||
# If it exists and is not empty, return the path without unzipping
|
||||
LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.')
|
||||
LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
|
||||
return path
|
||||
|
||||
for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
|
||||
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
|
||||
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
|
||||
if '..' in Path(f).parts:
|
||||
LOGGER.warning(f'Potentially insecure file path: {f}, skipping extraction.')
|
||||
if ".." in Path(f).parts:
|
||||
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
|
||||
continue
|
||||
zipObj.extract(f, extract_path)
|
||||
|
||||
return path # return unzip dir
|
||||
|
||||
|
||||
def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
|
||||
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", sf=1.5, hard=True):
|
||||
"""
|
||||
Check if there is sufficient disk space to download and store a file.
|
||||
|
||||
|
|
@ -199,20 +201,23 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
|
|||
"""
|
||||
try:
|
||||
r = requests.head(url) # response
|
||||
assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}' # check response
|
||||
assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
|
||||
except Exception:
|
||||
return True # requests issue, default to True
|
||||
|
||||
# Check file size
|
||||
gib = 1 << 30 # bytes per GiB
|
||||
data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
|
||||
data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
|
||||
total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd())) # bytes
|
||||
|
||||
if data * sf < free:
|
||||
return True # sufficient space
|
||||
|
||||
# Insufficient space
|
||||
text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
|
||||
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
|
||||
text = (
|
||||
f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
|
||||
f"Please free {data * sf - free:.1f} GB additional disk space and try again."
|
||||
)
|
||||
if hard:
|
||||
raise MemoryError(text)
|
||||
LOGGER.warning(text)
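A standalone sketch of the free-space arithmetic in check_disk_space() above (the Content-Length value is illustrative):

```python
import shutil
from pathlib import Path

gib = 1 << 30  # bytes per GiB
content_length = 250 * (1 << 20)  # pretend the server reported a ~250 MiB file
sf = 1.5  # safety factor

data = content_length / gib
total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd()))
print(data * sf < free)  # True when there is room for the download plus headroom
```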
|
||||
|
|
@ -238,36 +243,41 @@ def get_google_drive_file_info(link):
|
|||
url, filename = get_google_drive_file_info(link)
|
||||
```
|
||||
"""
|
||||
file_id = link.split('/d/')[1].split('/view')[0]
|
||||
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
|
||||
file_id = link.split("/d/")[1].split("/view")[0]
|
||||
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
|
||||
filename = None
|
||||
|
||||
# Start session
|
||||
with requests.Session() as session:
|
||||
response = session.get(drive_url, stream=True)
|
||||
if 'quota exceeded' in str(response.content.lower()):
|
||||
if "quota exceeded" in str(response.content.lower()):
|
||||
raise ConnectionError(
|
||||
emojis(f'❌ Google Drive file download quota exceeded. '
|
||||
f'Please try again later or download this file manually at {link}.'))
|
||||
emojis(
|
||||
f"❌ Google Drive file download quota exceeded. "
|
||||
f"Please try again later or download this file manually at {link}."
|
||||
)
|
||||
)
|
||||
for k, v in response.cookies.items():
|
||||
if k.startswith('download_warning'):
|
||||
drive_url += f'&confirm={v}' # v is token
|
||||
cd = response.headers.get('content-disposition')
|
||||
if k.startswith("download_warning"):
|
||||
drive_url += f"&confirm={v}" # v is token
|
||||
cd = response.headers.get("content-disposition")
|
||||
if cd:
|
||||
filename = re.findall('filename="(.+)"', cd)[0]
|
||||
return drive_url, filename
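The link parsing above reduces to two string splits; a sketch with a made-up share link:

```python
link = "https://drive.google.com/file/d/EXAMPLE_FILE_ID/view?usp=sharing"  # hypothetical link
file_id = link.split("/d/")[1].split("/view")[0]
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
print(drive_url)  # https://drive.google.com/uc?export=download&id=EXAMPLE_FILE_ID
```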
|
||||
|
||||
|
||||
def safe_download(url,
|
||||
file=None,
|
||||
dir=None,
|
||||
unzip=True,
|
||||
delete=False,
|
||||
curl=False,
|
||||
retry=3,
|
||||
min_bytes=1E0,
|
||||
exist_ok=False,
|
||||
progress=True):
|
||||
def safe_download(
|
||||
url,
|
||||
file=None,
|
||||
dir=None,
|
||||
unzip=True,
|
||||
delete=False,
|
||||
curl=False,
|
||||
retry=3,
|
||||
min_bytes=1e0,
|
||||
exist_ok=False,
|
||||
progress=True,
|
||||
):
|
||||
"""
|
||||
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
|
||||
|
||||
|
|
@ -294,36 +304,38 @@ def safe_download(url,
|
|||
path = safe_download(link)
|
||||
```
|
||||
"""
|
||||
gdrive = url.startswith('https://drive.google.com/') # check if the URL is a Google Drive link
|
||||
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
|
||||
if gdrive:
|
||||
url, file = get_google_drive_file_info(url)
|
||||
|
||||
f = Path(dir or '.') / (file or url2file(url)) # URL converted to filename
|
||||
if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
|
||||
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(url) # filename
|
||||
elif not f.is_file(): # URL and file do not exist
|
||||
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
|
||||
LOGGER.info(f'{desc}...')
|
||||
LOGGER.info(f"{desc}...")
|
||||
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
|
||||
check_disk_space(url)
|
||||
for i in range(retry + 1):
|
||||
try:
|
||||
if curl or i > 0: # curl download with retry, continue
|
||||
s = 'sS' * (not progress) # silent
|
||||
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
|
||||
assert r == 0, f'Curl return value {r}'
|
||||
s = "sS" * (not progress) # silent
|
||||
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
|
||||
assert r == 0, f"Curl return value {r}"
|
||||
else: # urllib download
|
||||
method = 'torch'
|
||||
if method == 'torch':
|
||||
method = "torch"
|
||||
if method == "torch":
|
||||
torch.hub.download_url_to_file(url, f, progress=progress)
|
||||
else:
|
||||
with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
|
||||
desc=desc,
|
||||
disable=not progress,
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024) as pbar:
|
||||
with open(f, 'wb') as f_opened:
|
||||
with request.urlopen(url) as response, TQDM(
|
||||
total=int(response.getheader("Content-Length", 0)),
|
||||
desc=desc,
|
||||
disable=not progress,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as pbar:
|
||||
with open(f, "wb") as f_opened:
|
||||
for data in response:
|
||||
f_opened.write(data)
|
||||
pbar.update(len(data))
|
||||
|
|
@ -334,26 +346,26 @@ def safe_download(url,
|
|||
f.unlink() # remove partial downloads
|
||||
except Exception as e:
|
||||
if i == 0 and not is_online():
|
||||
raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
|
||||
raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
|
||||
elif i >= retry:
|
||||
raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
|
||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
|
||||
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
|
||||
|
||||
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
|
||||
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
|
||||
from zipfile import is_zipfile
|
||||
|
||||
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
|
||||
if is_zipfile(f):
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
|
||||
elif f.suffix in ('.tar', '.gz'):
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
|
||||
subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True)
|
||||
elif f.suffix in (".tar", ".gz"):
|
||||
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
|
||||
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
|
||||
if delete:
|
||||
f.unlink() # remove zip
|
||||
return unzip_dir
|
||||
|
||||
|
||||
def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
|
||||
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
|
||||
"""
|
||||
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
|
||||
function fetches the latest release assets.
|
||||
|
|
@ -372,20 +384,20 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
|
|||
```
|
||||
"""
|
||||
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
url = f'https://api.github.com/repos/{repo}/releases/{version}'
|
||||
if version != "latest":
|
||||
version = f"tags/{version}" # i.e. tags/v6.2
|
||||
url = f"https://api.github.com/repos/{repo}/releases/{version}"
|
||||
r = requests.get(url) # github api
|
||||
if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded
|
||||
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
|
||||
r = requests.get(url) # try again
|
||||
if r.status_code != 200:
|
||||
LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
|
||||
return '', []
|
||||
LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
|
||||
return "", []
|
||||
data = r.json()
|
||||
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
|
||||
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
|
||||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **kwargs):
|
||||
def attempt_download_asset(file, repo="ultralytics/assets", release="v0.0.0", **kwargs):
|
||||
"""
|
||||
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
|
||||
locally first, then tries to download it from the specified GitHub repository release.
|
||||
|
|
@ -409,32 +421,32 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **
|
|||
# YOLOv3/5u updates
|
||||
file = str(file)
|
||||
file = checks.check_yolov5u_filename(file)
|
||||
file = Path(file.strip().replace("'", ''))
|
||||
file = Path(file.strip().replace("'", ""))
|
||||
if file.exists():
|
||||
return str(file)
|
||||
elif (SETTINGS['weights_dir'] / file).exists():
|
||||
return str(SETTINGS['weights_dir'] / file)
|
||||
elif (SETTINGS["weights_dir"] / file).exists():
|
||||
return str(SETTINGS["weights_dir"] / file)
|
||||
else:
|
||||
# URL specified
|
||||
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
|
||||
download_url = f'https://github.com/{repo}/releases/download'
|
||||
if str(file).startswith(('http:/', 'https:/')): # download
|
||||
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
|
||||
download_url = f"https://github.com/{repo}/releases/download"
|
||||
if str(file).startswith(("http:/", "https:/")): # download
|
||||
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
|
||||
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
|
||||
if Path(file).is_file():
|
||||
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
|
||||
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
|
||||
else:
|
||||
safe_download(url=url, file=file, min_bytes=1E5, **kwargs)
|
||||
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
|
||||
|
||||
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
|
||||
safe_download(url=f'{download_url}/{release}/{name}', file=file, min_bytes=1E5, **kwargs)
|
||||
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
|
||||
|
||||
else:
|
||||
tag, assets = get_github_assets(repo, release)
|
||||
if not assets:
|
||||
tag, assets = get_github_assets(repo) # latest release
|
||||
if name in assets:
|
||||
safe_download(url=f'{download_url}/{tag}/{name}', file=file, min_bytes=1E5, **kwargs)
|
||||
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
|
||||
|
||||
return str(file)
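A hedged usage sketch of the function above (module path assumed to be ultralytics.utils.downloads); the weight file is fetched from the GitHub assets release only if it is not already present locally:

```python
from ultralytics.utils.downloads import attempt_download_asset

weights = attempt_download_asset("yolov8n.pt")  # returns the local path as a string
print(weights)
```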
|
||||
|
||||
|
|
@ -464,14 +476,18 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=
|
|||
if threads > 1:
|
||||
with ThreadPool(threads) as pool:
|
||||
pool.map(
|
||||
lambda x: safe_download(url=x[0],
|
||||
dir=x[1],
|
||||
unzip=unzip,
|
||||
delete=delete,
|
||||
curl=curl,
|
||||
retry=retry,
|
||||
exist_ok=exist_ok,
|
||||
progress=threads <= 1), zip(url, repeat(dir)))
|
||||
lambda x: safe_download(
|
||||
url=x[0],
|
||||
dir=x[1],
|
||||
unzip=unzip,
|
||||
delete=delete,
|
||||
curl=curl,
|
||||
retry=retry,
|
||||
exist_ok=exist_ok,
|
||||
progress=threads <= 1,
|
||||
),
|
||||
zip(url, repeat(dir)),
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
@ -17,6 +17,6 @@ class HUBModelError(Exception):
|
|||
The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
|
||||
"""
|
||||
|
||||
def __init__(self, message='Model not found. Please check model URL and try again.'):
|
||||
def __init__(self, message="Model not found. Please check model URL and try again."):
|
||||
"""Create an exception for when a model is not found."""
|
||||
super().__init__(emojis(message))
@ -50,13 +50,13 @@ def spaces_in_path(path):
|
|||
"""
|
||||
|
||||
# If path has spaces, replace them with underscores
|
||||
if ' ' in str(path):
|
||||
if " " in str(path):
|
||||
string = isinstance(path, str) # input type
|
||||
path = Path(path)
|
||||
|
||||
# Create a temporary directory and construct the new path
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
|
||||
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
|
||||
|
||||
# Copy file/directory
|
||||
if path.is_dir():
|
||||
|
|
@ -82,7 +82,7 @@ def spaces_in_path(path):
|
|||
yield path
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||
"""
|
||||
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
|
||||
|
|
@ -102,11 +102,11 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
|||
"""
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
||||
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
|
||||
|
||||
# Method 1
|
||||
for n in range(2, 9999):
|
||||
p = f'{path}{sep}{n}{suffix}' # increment path
|
||||
p = f"{path}{sep}{n}{suffix}" # increment path
|
||||
if not os.path.exists(p):
|
||||
break
|
||||
path = Path(p)
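A trimmed sketch of the probing loop above (suffix handling and mkdir are omitted):

```python
import os
from pathlib import Path


def next_free_path(path, sep=""):
    path = Path(path)
    if not path.exists():
        return path
    for n in range(2, 9999):
        p = f"{path}{sep}{n}"  # runs/exp -> runs/exp2, runs/exp3, ...
        if not os.path.exists(p):
            return Path(p)


print(next_free_path("runs/exp"))  # runs/exp, or runs/exp2 if it already exists
```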
|
||||
|
|
@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
|||
|
||||
def file_age(path=__file__):
|
||||
"""Return days since last file update."""
|
||||
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
|
||||
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
|
||||
return dt.days # + dt.seconds / 86400 # fractional days
|
||||
|
||||
|
||||
def file_date(path=__file__):
|
||||
"""Return human-readable file modification date, i.e. '2021-3-26'."""
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f'{t.year}-{t.month}-{t.day}'
|
||||
return f"{t.year}-{t.month}-{t.day}"
|
||||
|
||||
|
||||
def file_size(path):
|
||||
|
|
@ -137,11 +137,11 @@ def file_size(path):
|
|||
if path.is_file():
|
||||
return path.stat().st_size / mb
|
||||
elif path.is_dir():
|
||||
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
|
||||
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_latest_run(search_dir='.'):
|
||||
def get_latest_run(search_dir="."):
|
||||
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
||||
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ''
|
||||
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ""
@ -26,9 +26,9 @@ to_4tuple = _ntuple(4)
|
|||
# `xyxy` means left top and right bottom
|
||||
# `xywh` means center x, center y and width, height(YOLO format)
|
||||
# `ltwh` means left top and width, height(COCO format)
|
||||
_formats = ['xyxy', 'xywh', 'ltwh']
|
||||
_formats = ["xyxy", "xywh", "ltwh"]
|
||||
|
||||
__all__ = 'Bboxes', # tuple or list
|
||||
__all__ = ("Bboxes",) # tuple or list
|
||||
|
||||
|
||||
class Bboxes:
|
||||
|
|
@ -46,9 +46,9 @@ class Bboxes:
|
|||
This class does not handle normalization or denormalization of bounding boxes.
|
||||
"""
|
||||
|
||||
def __init__(self, bboxes, format='xyxy') -> None:
|
||||
def __init__(self, bboxes, format="xyxy") -> None:
|
||||
"""Initializes the Bboxes class with bounding box data in a specified format."""
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
|
||||
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
||||
assert bboxes.ndim == 2
|
||||
assert bboxes.shape[1] == 4
|
||||
|
|
@ -58,21 +58,21 @@ class Bboxes:
|
|||
|
||||
def convert(self, format):
|
||||
"""Converts bounding box format from one type to another."""
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
|
||||
if self.format == format:
|
||||
return
|
||||
elif self.format == 'xyxy':
|
||||
func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
|
||||
elif self.format == 'xywh':
|
||||
func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
|
||||
elif self.format == "xyxy":
|
||||
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
|
||||
elif self.format == "xywh":
|
||||
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
|
||||
else:
|
||||
func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
|
||||
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
|
||||
self.bboxes = func(self.bboxes)
|
||||
self.format = format
|
||||
|
||||
def areas(self):
|
||||
"""Return box areas."""
|
||||
self.convert('xyxy')
|
||||
self.convert("xyxy")
|
||||
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
|
||||
|
||||
# def denormalize(self, w, h):
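The conversion helpers dispatched above (xyxy2xywh and friends) boil down to simple coordinate arithmetic; a NumPy sketch, not the Ultralytics helpers themselves:

```python
import numpy as np


def xyxy2xywh(x):
    """Corner format (x1, y1, x2, y2) -> center format (cx, cy, w, h)."""
    y = np.empty_like(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # center x
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # center y
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


print(xyxy2xywh(np.array([[10.0, 10.0, 50.0, 30.0]])))  # [[30. 20. 40. 20.]]
```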
|
||||
|
|
@ -124,7 +124,7 @@ class Bboxes:
|
|||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
|
||||
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
|
||||
"""
|
||||
Concatenate a list of Bboxes objects into a single Bboxes object.
|
||||
|
||||
|
|
@ -148,7 +148,7 @@ class Bboxes:
|
|||
return boxes_list[0]
|
||||
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
||||
|
||||
def __getitem__(self, index) -> 'Bboxes':
|
||||
def __getitem__(self, index) -> "Bboxes":
|
||||
"""
|
||||
Retrieve a specific bounding box or a set of bounding boxes using indexing.
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ class Bboxes:
|
|||
if isinstance(index, int):
|
||||
return Bboxes(self.bboxes[index].view(1, -1))
|
||||
b = self.bboxes[index]
|
||||
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
|
||||
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
|
||||
return Bboxes(b)
|
||||
|
||||
|
||||
|
|
@ -205,7 +205,7 @@ class Instances:
|
|||
This class does not perform input validation, and it assumes the inputs are well-formed.
|
||||
"""
|
||||
|
||||
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
|
||||
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
|
||||
"""
|
||||
Args:
|
||||
bboxes (ndarray): bboxes with shape [N, 4].
|
||||
|
|
@ -263,7 +263,7 @@ class Instances:
|
|||
|
||||
def add_padding(self, padw, padh):
|
||||
"""Handle rect and mosaic situation."""
|
||||
assert not self.normalized, 'you should add padding with absolute coordinates.'
|
||||
assert not self.normalized, "you should add padding with absolute coordinates."
|
||||
self._bboxes.add(offset=(padw, padh, padw, padh))
|
||||
self.segments[..., 0] += padw
|
||||
self.segments[..., 1] += padh
|
||||
|
|
@ -271,7 +271,7 @@ class Instances:
|
|||
self.keypoints[..., 0] += padw
|
||||
self.keypoints[..., 1] += padh
|
||||
|
||||
def __getitem__(self, index) -> 'Instances':
|
||||
def __getitem__(self, index) -> "Instances":
|
||||
"""
|
||||
Retrieve a specific instance or a set of instances using indexing.
|
||||
|
||||
|
|
@ -301,7 +301,7 @@ class Instances:
|
|||
|
||||
def flipud(self, h):
|
||||
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
if self._bboxes.format == "xyxy":
|
||||
y1 = self.bboxes[:, 1].copy()
|
||||
y2 = self.bboxes[:, 3].copy()
|
||||
self.bboxes[:, 1] = h - y2
|
||||
|
|
@ -314,7 +314,7 @@ class Instances:
|
|||
|
||||
def fliplr(self, w):
|
||||
"""Reverses the order of the bounding boxes and segments horizontally."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
if self._bboxes.format == "xyxy":
|
||||
x1 = self.bboxes[:, 0].copy()
|
||||
x2 = self.bboxes[:, 2].copy()
|
||||
self.bboxes[:, 0] = w - x2
|
||||
|
|
@ -328,10 +328,10 @@ class Instances:
|
|||
def clip(self, w, h):
|
||||
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
|
||||
ori_format = self._bboxes.format
|
||||
self.convert_bbox(format='xyxy')
|
||||
self.convert_bbox(format="xyxy")
|
||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
||||
if ori_format != 'xyxy':
|
||||
if ori_format != "xyxy":
|
||||
self.convert_bbox(format=ori_format)
|
||||
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
||||
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
||||
|
|
@ -367,7 +367,7 @@ class Instances:
|
|||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
|
||||
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
|
||||
"""
|
||||
Concatenates a list of Instances objects into a single Instances object.
@ -28,22 +28,27 @@ class VarifocalLoss(nn.Module):
|
|||
"""Computes varfocal loss."""
|
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
|
||||
weight).mean(1).sum()
|
||||
loss = (
|
||||
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
|
||||
.mean(1)
|
||||
.sum()
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
|
||||
|
||||
def __init__(self, ):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initializer for FocalLoss class with no parameters."""
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def forward(pred, label, gamma=1.5, alpha=0.25):
|
||||
"""Calculates and updates confusion matrix for object detection/classification tasks."""
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
|
||||
# p_t = torch.exp(-loss)
|
||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
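The varifocal weighting at the top of this hunk can be inspected in isolation; the tensors below are illustrative:

```python
import torch

pred_score = torch.tensor([2.0, -1.0])  # raw logits for one positive and one negative anchor
gt_score = torch.tensor([0.8, 0.0])  # IoU-aware targets
label = torch.tensor([1.0, 0.0])  # positive / negative mask
alpha, gamma = 0.75, 2.0

weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
print(weight)  # positive keeps its gt_score (0.8); negative is down-weighted (~0.05)
```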
|
||||
|
||||
|
|
@ -91,8 +96,10 @@ class BboxLoss(nn.Module):
|
|||
tr = tl + 1 # target right
|
||||
wl = tr - target # weight left
|
||||
wr = 1 - wl # weight right
|
||||
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
|
||||
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
|
||||
return (
|
||||
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
|
||||
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
|
||||
).mean(-1, keepdim=True)
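The left/right bin weighting used in the distribution focal loss above, shown on toy targets:

```python
import torch

target = torch.tensor([2.3, 5.9])  # continuous regression targets
tl = target.long()  # target left bin:  [2, 5]
tr = tl + 1  # target right bin: [3, 6]
wl = tr - target  # weight on the left bin:  [0.7, 0.1]
wr = 1 - wl  # weight on the right bin: [0.3, 0.9]
print(tl, tr, wl, wr)
```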
|
||||
|
||||
|
||||
class RotatedBboxLoss(BboxLoss):
|
||||
|
|
@ -145,7 +152,7 @@ class v8DetectionLoss:
|
|||
h = model.args # hyperparameters
|
||||
|
||||
m = model.model[-1] # Detect() module
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction='none')
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction="none")
|
||||
self.hyp = h
|
||||
self.stride = m.stride # model strides
|
||||
self.nc = m.nc # number of classes
|
||||
|
|
@ -190,7 +197,8 @@ class v8DetectionLoss:
|
|||
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
||||
feats = preds[1] if isinstance(preds, tuple) else preds
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
||||
|
|
@ -201,7 +209,7 @@ class v8DetectionLoss:
|
|||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# Targets
|
||||
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
|
|
@ -210,8 +218,13 @@ class v8DetectionLoss:
|
|||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
pred_scores.detach().sigmoid(),
|
||||
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
|
|
@ -222,8 +235,9 @@ class v8DetectionLoss:
|
|||
# Bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
|
||||
target_scores_sum, fg_mask)
|
||||
loss[0], loss[2] = self.bbox_loss(
|
||||
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||
)
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.cls # cls gain
|
||||
|
|
@ -246,7 +260,8 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
||||
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
# B, grids, ..
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
|
|
@ -259,24 +274,31 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
|
||||
# Targets
|
||||
try:
|
||||
batch_idx = batch['batch_idx'].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
batch_idx = batch["batch_idx"].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
except RuntimeError as e:
|
||||
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
|
||||
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
|
||||
"i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
|
||||
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
|
||||
'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e
|
||||
raise TypeError(
|
||||
"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
|
||||
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
|
||||
"i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
|
||||
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
|
||||
"as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
|
||||
) from e
|
||||
|
||||
# Pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
pred_scores.detach().sigmoid(),
|
||||
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
|
|
@ -286,15 +308,23 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
|
||||
if fg_mask.sum():
|
||||
# Bbox loss
|
||||
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
|
||||
target_scores, target_scores_sum, fg_mask)
|
||||
loss[0], loss[3] = self.bbox_loss(
|
||||
pred_distri,
|
||||
pred_bboxes,
|
||||
anchor_points,
|
||||
target_bboxes / stride_tensor,
|
||||
target_scores,
|
||||
target_scores_sum,
|
||||
fg_mask,
|
||||
)
|
||||
# Masks loss
|
||||
masks = batch['masks'].to(self.device).float()
|
||||
masks = batch["masks"].to(self.device).float()
|
||||
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
||||
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
|
||||
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
|
||||
|
||||
loss[1] = self.calculate_segmentation_loss(fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto,
|
||||
pred_masks, imgsz, self.overlap)
|
||||
loss[1] = self.calculate_segmentation_loss(
|
||||
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
|
||||
)
|
||||
|
||||
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
|
||||
else:
|
||||
|
|
@ -308,8 +338,9 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
@staticmethod
|
||||
def single_mask_loss(gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor,
|
||||
area: torch.Tensor) -> torch.Tensor:
|
||||
def single_mask_loss(
|
||||
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the instance segmentation loss for a single image.
|
||||
|
||||
|
|
@ -327,8 +358,8 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
|
||||
predicted masks from the prototype masks and predicted mask coefficients.
|
||||
"""
|
||||
pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
|
||||
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
|
||||
pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
|
||||
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
|
||||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
|
||||
|
||||
def calculate_segmentation_loss(
|
||||
|
|
@ -387,8 +418,9 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
else:
|
||||
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
|
||||
|
||||
loss += self.single_mask_loss(gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i],
|
||||
marea_i[fg_mask_i])
|
||||
loss += self.single_mask_loss(
|
||||
gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
|
||||
)
|
||||
|
||||
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
|
||||
else:
|
||||
|
|
@ -415,7 +447,8 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
|
||||
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
# B, grids, ..
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
|
|
@ -428,8 +461,8 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
|
||||
# Targets
|
||||
batch_size = pred_scores.shape[0]
|
||||
batch_idx = batch['batch_idx'].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
batch_idx = batch["batch_idx"].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
|
|
@ -439,8 +472,13 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
pred_scores.detach().sigmoid(),
|
||||
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
|
|
@ -451,14 +489,16 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
# Bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
|
||||
target_scores_sum, fg_mask)
|
||||
keypoints = batch['keypoints'].to(self.device).float().clone()
|
||||
loss[0], loss[4] = self.bbox_loss(
|
||||
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||
)
|
||||
keypoints = batch["keypoints"].to(self.device).float().clone()
|
||||
keypoints[..., 0] *= imgsz[1]
|
||||
keypoints[..., 1] *= imgsz[0]
|
||||
|
||||
loss[1], loss[2] = self.calculate_keypoints_loss(fg_mask, target_gt_idx, keypoints, batch_idx,
|
||||
stride_tensor, target_bboxes, pred_kpts)
|
||||
loss[1], loss[2] = self.calculate_keypoints_loss(
|
||||
fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
|
||||
)
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.pose # pose gain
|
||||
|
|
@ -477,8 +517,9 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
y[..., 1] += anchor_points[:, [1]] - 0.5
|
||||
return y
|
||||
|
||||
def calculate_keypoints_loss(self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes,
|
||||
pred_kpts):
|
||||
def calculate_keypoints_loss(
|
||||
self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
|
||||
):
|
||||
"""
|
||||
Calculate the keypoints loss for the model.
|
||||
|
||||
|
|
@ -507,21 +548,23 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
|
||||
|
||||
# Create a tensor to hold batched keypoints
|
||||
batched_keypoints = torch.zeros((batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]),
|
||||
device=keypoints.device)
|
||||
batched_keypoints = torch.zeros(
|
||||
(batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
|
||||
)
|
||||
|
||||
# TODO: any idea how to vectorize this?
|
||||
# Fill batched_keypoints with keypoints based on batch_idx
|
||||
for i in range(batch_size):
|
||||
keypoints_i = keypoints[batch_idx == i]
|
||||
batched_keypoints[i, :keypoints_i.shape[0]] = keypoints_i
|
||||
batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
|
||||
|
||||
# Expand dimensions of target_gt_idx to match the shape of batched_keypoints
|
||||
target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# Use target_gt_idx_expanded to select keypoints from batched_keypoints
|
||||
selected_keypoints = batched_keypoints.gather(
|
||||
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]))
|
||||
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
|
||||
)
|
||||
|
||||
# Divide coordinates by stride
|
||||
selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
|
||||
|
|
@ -547,13 +590,12 @@ class v8ClassificationLoss:
|
|||
|
||||
def __call__(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
||||
|
||||
|
||||
class v8OBBLoss(v8DetectionLoss):
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
super().__init__(model)
|
||||
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
|
|
@ -583,7 +625,8 @@ class v8OBBLoss(v8DetectionLoss):
|
|||
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
|
||||
batch_size = pred_angle.shape[0]  # batch size
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
# b, grids, ..
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
|
|
@ -596,19 +639,21 @@ class v8OBBLoss(v8DetectionLoss):
|
|||
|
||||
# targets
|
||||
try:
|
||||
batch_idx = batch['batch_idx'].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1)
|
||||
batch_idx = batch["batch_idx"].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
|
||||
rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
|
||||
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
except RuntimeError as e:
|
||||
raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n'
|
||||
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
|
||||
"i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
|
||||
"correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
|
||||
'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e
|
||||
raise TypeError(
|
||||
"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
|
||||
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
|
||||
"i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
|
||||
"correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
|
||||
"as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
|
||||
) from e
|
||||
|
||||
# Pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
|
||||
|
|
@ -616,10 +661,14 @@ class v8OBBLoss(v8DetectionLoss):
|
|||
bboxes_for_assigner = pred_bboxes.clone().detach()
|
||||
# Only the first four elements need to be scaled
|
||||
bboxes_for_assigner[..., :4] *= stride_tensor
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(),
|
||||
bboxes_for_assigner.type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes,
|
||||
mask_gt)
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
|
||||
pred_scores.detach().sigmoid(),
|
||||
bboxes_for_assigner.type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
|
|
@ -630,8 +679,9 @@ class v8OBBLoss(v8DetectionLoss):
|
|||
# Bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes[..., :4] /= stride_tensor
|
||||
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
|
||||
target_scores_sum, fg_mask)
|
||||
loss[0], loss[2] = self.bbox_loss(
|
||||
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||
)
|
||||
else:
|
||||
loss[0] += (pred_angle * 0).sum()
@ -11,7 +11,10 @@ import torch
|
|||
|
||||
from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
|
||||
|
||||
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
|
||||
OKS_SIGMA = (
|
||||
np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
|
||||
/ 10.0
|
||||
)
|
||||
|
||||
|
||||
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
||||
|
|
@ -33,8 +36,9 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
|||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
|
||||
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
|
||||
np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
|
||||
).clip(0)
|
||||
|
||||
# Box2 area
|
||||
area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
|
||||
|
|
@ -99,8 +103,9 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
# Intersection area
|
||||
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
|
||||
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
|
||||
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
|
||||
b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
|
||||
).clamp_(0)
|
||||
|
||||
# Union Area
|
||||
union = w1 * h1 + w2 * h2 - inter + eps
|
||||
|
|
@ -111,10 +116,10 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|||
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
|
||||
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
|
||||
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
||||
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
|
||||
c2 = cw**2 + ch**2 + eps # convex diagonal squared
|
||||
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
|
||||
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
|
||||
v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
||||
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
||||
with torch.no_grad():
|
||||
alpha = v / (v - iou + (1 + eps))
|
||||
return iou - (rho2 / c2 + v * alpha) # CIoU
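A hedged usage sketch of bbox_iou with the CIoU branch above; the import path is assumed from the file being edited:

import torch
from ultralytics.utils.metrics import bbox_iou  # assumed import path

box1 = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
box2 = torch.tensor([[20.0, 20.0, 60.0, 60.0]])
iou = bbox_iou(box1, box2, xywh=False)              # plain IoU of xyxy boxes
ciou = bbox_iou(box1, box2, xywh=False, CIoU=True)  # IoU minus the center-distance and aspect-ratio penalties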
@ -202,12 +207,19 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|||
a1, b1, c1 = _get_covariance_matrix(obb1)
|
||||
a2, b2, c2 = _get_covariance_matrix(obb2)
|
||||
|
||||
t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
|
||||
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
|
||||
t1 = (
|
||||
((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
|
||||
/ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
|
||||
) * 0.25
|
||||
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
|
||||
t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
|
||||
(4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
|
||||
(a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
|
||||
t3 = (
|
||||
torch.log(
|
||||
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
|
||||
/ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
|
||||
+ eps
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
bd = t1 + t2 + t3
|
||||
bd = torch.clamp(bd, eps, 100.0)
|
||||
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
|
||||
|
|
@ -215,7 +227,7 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|||
if CIoU: # only include the wh aspect ratio part
|
||||
w1, h1 = obb1[..., 2:4].split(1, dim=-1)
|
||||
w2, h2 = obb2[..., 2:4].split(1, dim=-1)
|
||||
v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
||||
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
||||
with torch.no_grad():
|
||||
alpha = v / (v - iou + (1 + eps))
|
||||
return iou - v * alpha # CIoU
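A small usage sketch for probiou; the (cx, cy, w, h, angle-in-radians) layout of the oriented boxes is an assumption based on the surrounding code:

import torch
from ultralytics.utils.metrics import probiou  # assumed import path

obb1 = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.0]])   # cx, cy, w, h, angle (rad), assumed layout
obb2 = torch.tensor([[55.0, 50.0, 40.0, 20.0, 0.3]])
iou = probiou(obb1, obb2)  # Gaussian-overlap based IoU in (0, 1)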
@ -239,12 +251,19 @@ def batch_probiou(obb1, obb2, eps=1e-7):
|
|||
a1, b1, c1 = _get_covariance_matrix(obb1)
|
||||
a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
|
||||
|
||||
t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
|
||||
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
|
||||
t1 = (
|
||||
((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
|
||||
/ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
|
||||
) * 0.25
|
||||
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
|
||||
t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
|
||||
(4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
|
||||
(a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
|
||||
t3 = (
|
||||
torch.log(
|
||||
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
|
||||
/ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
|
||||
+ eps
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
bd = t1 + t2 + t3
|
||||
bd = torch.clamp(bd, eps, 100.0)
|
||||
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
|
||||
|
|
@ -279,10 +298,10 @@ class ConfusionMatrix:
|
|||
iou_thres (float): The Intersection over Union threshold.
|
||||
"""
|
||||
|
||||
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
|
||||
def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
|
||||
"""Initialize attributes for the YOLO model."""
|
||||
self.task = task
|
||||
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
|
||||
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
|
||||
self.nc = nc # number of classes
|
||||
self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
|
||||
self.iou_thres = iou_thres
|
||||
|
|
@ -361,11 +380,11 @@ class ConfusionMatrix:
|
|||
tp = self.matrix.diagonal() # true positives
|
||||
fp = self.matrix.sum(1) - tp # false positives
|
||||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
||||
return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect
|
||||
return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect
|
||||
|
||||
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
|
||||
@TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")
|
||||
@plt_settings()
|
||||
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
|
||||
def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
|
||||
"""
|
||||
Plot the confusion matrix using seaborn and save it to a file.
|
||||
|
||||
|
|
@ -377,30 +396,31 @@ class ConfusionMatrix:
|
|||
"""
|
||||
import seaborn as sn
|
||||
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
|
||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||
nc, nn = self.nc, len(names) # number of classes, names
|
||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||
ticklabels = (list(names) + ['background']) if labels else 'auto'
|
||||
ticklabels = (list(names) + ["background"]) if labels else "auto"
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
annot_kws={
|
||||
'size': 8},
|
||||
cmap='Blues',
|
||||
fmt='.2f' if normalize else '.0f',
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=ticklabels,
|
||||
yticklabels=ticklabels).set_facecolor((1, 1, 1))
|
||||
title = 'Confusion Matrix' + ' Normalized' * normalize
|
||||
ax.set_xlabel('True')
|
||||
ax.set_ylabel('Predicted')
|
||||
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(
|
||||
array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
annot_kws={"size": 8},
|
||||
cmap="Blues",
|
||||
fmt=".2f" if normalize else ".0f",
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=ticklabels,
|
||||
yticklabels=ticklabels,
|
||||
).set_facecolor((1, 1, 1))
|
||||
title = "Confusion Matrix" + " Normalized" * normalize
|
||||
ax.set_xlabel("True")
|
||||
ax.set_ylabel("Predicted")
|
||||
ax.set_title(title)
|
||||
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
|
||||
fig.savefig(plot_fname, dpi=250)
|
||||
|
|
@ -411,7 +431,7 @@ class ConfusionMatrix:
|
|||
def print(self):
|
||||
"""Print the confusion matrix to the console."""
|
||||
for i in range(self.nc + 1):
|
||||
LOGGER.info(' '.join(map(str, self.matrix[i])))
|
||||
LOGGER.info(" ".join(map(str, self.matrix[i])))
def smooth(y, f=0.05):
|
||||
|
|
@ -419,28 +439,28 @@ def smooth(y, f=0.05):
|
|||
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
|
||||
p = np.ones(nf // 2) # ones padding
|
||||
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
|
||||
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
|
||||
return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed
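For example, the box filter above smooths a short metric curve while keeping its length (values here are made up):

import numpy as np
# assuming `smooth` as defined above
y = np.array([0.20, 0.80, 0.30, 0.90, 0.40])
ys = smooth(y, f=0.4)  # filter length 3 in this case; output has the same length as y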
@plt_settings()
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
|
||||
"""Plots a precision-recall curve."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
py = np.stack(py, axis=1)
|
||||
|
||||
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
||||
for i, y in enumerate(py.T):
|
||||
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
|
||||
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
|
||||
else:
|
||||
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
|
||||
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
|
||||
|
||||
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
|
||||
ax.set_xlabel('Recall')
|
||||
ax.set_ylabel('Precision')
|
||||
ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
|
||||
ax.set_xlabel("Recall")
|
||||
ax.set_ylabel("Precision")
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
||||
ax.set_title('Precision-Recall Curve')
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||
ax.set_title("Precision-Recall Curve")
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
|
|
@ -448,24 +468,24 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N
|
|||
|
||||
|
||||
@plt_settings()
|
||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
|
||||
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
|
||||
"""Plots a metric-confidence curve."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
|
||||
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
||||
for i, y in enumerate(py):
|
||||
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
|
||||
ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
|
||||
else:
|
||||
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
|
||||
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
|
||||
|
||||
y = smooth(py.mean(0), 0.05)
|
||||
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
|
||||
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
||||
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||
ax.set_title(f"{ylabel}-Confidence Curve")
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
|
|
@ -494,8 +514,8 @@ def compute_ap(recall, precision):
|
|||
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
|
||||
|
||||
# Integrate area under curve
|
||||
method = 'interp' # methods: 'continuous', 'interp'
|
||||
if method == 'interp':
|
||||
method = "interp" # methods: 'continuous', 'interp'
|
||||
if method == "interp":
|
||||
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
|
||||
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
|
||||
else: # 'continuous'
|
||||
|
|
@ -505,16 +525,9 @@ def compute_ap(recall, precision):
|
|||
return ap, mpre, mrec
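A toy call to compute_ap; with an everywhere-perfect precision curve the 101-point interpolation should return an AP of about 1.0:

import numpy as np
# assuming `compute_ap` as defined above
recall = np.array([0.0, 0.5, 1.0])
precision = np.array([1.0, 1.0, 1.0])
ap, mpre, mrec = compute_ap(recall, precision)  # ap ~= 1.0 with the 'interp' method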
def ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=False,
|
||||
on_plot=None,
|
||||
save_dir=Path(),
|
||||
names=(),
|
||||
eps=1e-16,
|
||||
prefix=''):
|
||||
def ap_per_class(
|
||||
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
|
||||
):
|
||||
"""
|
||||
Computes the average precision per class for object detection evaluation.
|
||||
|
||||
|
|
@ -591,10 +604,10 @@ def ap_per_class(tp,
|
|||
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
|
||||
names = dict(enumerate(names)) # to dict
|
||||
if plot:
|
||||
plot_pr_curve(x, prec_values, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
|
||||
plot_mc_curve(x, f1_curve, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
|
||||
plot_mc_curve(x, p_curve, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
|
||||
plot_mc_curve(x, r_curve, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
|
||||
plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
|
||||
plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
|
||||
plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
|
||||
plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
|
||||
|
||||
i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index
|
||||
p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values
|
||||
|
|
@ -746,8 +759,18 @@ class Metric(SimpleClass):
|
|||
Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
|
||||
on the values provided in the `results` tuple.
|
||||
"""
|
||||
(self.p, self.r, self.f1, self.all_ap, self.ap_class_index, self.p_curve, self.r_curve, self.f1_curve, self.px,
|
||||
self.prec_values) = results
|
||||
(
|
||||
self.p,
|
||||
self.r,
|
||||
self.f1,
|
||||
self.all_ap,
|
||||
self.ap_class_index,
|
||||
self.p_curve,
|
||||
self.r_curve,
|
||||
self.f1_curve,
|
||||
self.px,
|
||||
self.prec_values,
|
||||
) = results
|
||||
|
||||
@property
|
||||
def curves(self):
|
||||
|
|
@ -757,8 +780,12 @@ class Metric(SimpleClass):
|
|||
@property
|
||||
def curves_results(self):
|
||||
"""Returns a list of curves for accessing specific metrics curves."""
|
||||
return [[self.px, self.prec_values, 'Recall', 'Precision'], [self.px, self.f1_curve, 'Confidence', 'F1'],
|
||||
[self.px, self.p_curve, 'Confidence', 'Precision'], [self.px, self.r_curve, 'Confidence', 'Recall']]
|
||||
return [
|
||||
[self.px, self.prec_values, "Recall", "Precision"],
|
||||
[self.px, self.f1_curve, "Confidence", "F1"],
|
||||
[self.px, self.p_curve, "Confidence", "Precision"],
|
||||
[self.px, self.r_curve, "Confidence", "Recall"],
|
||||
]
|
||||
|
||||
|
||||
class DetMetrics(SimpleClass):
|
||||
|
|
@ -793,33 +820,35 @@ class DetMetrics(SimpleClass):
|
|||
curves_results: TODO
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
||||
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.task = 'detect'
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
self.task = "detect"
|
||||
|
||||
def process(self, tp, conf, pred_cls, target_cls):
|
||||
"""Process predicted results for object detection and update metrics."""
|
||||
results = ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)[2:]
|
||||
results = ap_per_class(
|
||||
tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
)[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns a list of keys for accessing specific metrics."""
|
||||
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
|
||||
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
||||
|
||||
def mean_results(self):
|
||||
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
||||
|
|
@ -847,12 +876,12 @@ class DetMetrics(SimpleClass):
|
|||
@property
|
||||
def results_dict(self):
|
||||
"""Returns dictionary of computed performance metrics and statistics."""
|
||||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
||||
|
||||
@property
|
||||
def curves(self):
|
||||
"""Returns a list of curves for accessing specific metrics curves."""
|
||||
return ['Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)']
|
||||
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
|
||||
|
||||
@property
|
||||
def curves_results(self):
|
||||
|
|
@ -889,7 +918,7 @@ class SegmentMetrics(SimpleClass):
|
|||
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
||||
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
|
|
@ -897,8 +926,8 @@ class SegmentMetrics(SimpleClass):
|
|||
self.names = names
|
||||
self.box = Metric()
|
||||
self.seg = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.task = 'segment'
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
self.task = "segment"
|
||||
|
||||
def process(self, tp, tp_m, conf, pred_cls, target_cls):
|
||||
"""
|
||||
|
|
@ -912,26 +941,30 @@ class SegmentMetrics(SimpleClass):
|
|||
target_cls (list): List of target classes.
|
||||
"""
|
||||
|
||||
results_mask = ap_per_class(tp_m,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Mask')[2:]
|
||||
results_mask = ap_per_class(
|
||||
tp_m,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix="Mask",
|
||||
)[2:]
|
||||
self.seg.nc = len(self.names)
|
||||
self.seg.update(results_mask)
|
||||
results_box = ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Box')[2:]
|
||||
results_box = ap_per_class(
|
||||
tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix="Box",
|
||||
)[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results_box)
|
||||
|
||||
|
|
@ -939,8 +972,15 @@ class SegmentMetrics(SimpleClass):
|
|||
def keys(self):
|
||||
"""Returns a list of keys for accessing metrics."""
|
||||
return [
|
||||
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
|
||||
'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
|
||||
"metrics/precision(B)",
|
||||
"metrics/recall(B)",
|
||||
"metrics/mAP50(B)",
|
||||
"metrics/mAP50-95(B)",
|
||||
"metrics/precision(M)",
|
||||
"metrics/recall(M)",
|
||||
"metrics/mAP50(M)",
|
||||
"metrics/mAP50-95(M)",
|
||||
]
|
||||
|
||||
def mean_results(self):
|
||||
"""Return the mean metrics for bounding box and segmentation results."""
|
||||
|
|
@ -968,14 +1008,21 @@ class SegmentMetrics(SimpleClass):
|
|||
@property
|
||||
def results_dict(self):
|
||||
"""Returns results of object detection model for evaluation."""
|
||||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
||||
|
||||
@property
|
||||
def curves(self):
|
||||
"""Returns a list of curves for accessing specific metrics curves."""
|
||||
return [
|
||||
'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)',
|
||||
'Precision-Recall(M)', 'F1-Confidence(M)', 'Precision-Confidence(M)', 'Recall-Confidence(M)']
|
||||
"Precision-Recall(B)",
|
||||
"F1-Confidence(B)",
|
||||
"Precision-Confidence(B)",
|
||||
"Recall-Confidence(B)",
|
||||
"Precision-Recall(M)",
|
||||
"F1-Confidence(M)",
|
||||
"Precision-Confidence(M)",
|
||||
"Recall-Confidence(M)",
|
||||
]
|
||||
|
||||
@property
|
||||
def curves_results(self):
|
||||
|
|
@ -1012,7 +1059,7 @@ class PoseMetrics(SegmentMetrics):
|
|||
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
||||
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
|
||||
super().__init__(save_dir, plot, names)
|
||||
self.save_dir = save_dir
|
||||
|
|
@ -1021,8 +1068,8 @@ class PoseMetrics(SegmentMetrics):
|
|||
self.names = names
|
||||
self.box = Metric()
|
||||
self.pose = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.task = 'pose'
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
self.task = "pose"
|
||||
|
||||
def process(self, tp, tp_p, conf, pred_cls, target_cls):
|
||||
"""
|
||||
|
|
@ -1036,26 +1083,30 @@ class PoseMetrics(SegmentMetrics):
|
|||
target_cls (list): List of target classes.
|
||||
"""
|
||||
|
||||
results_pose = ap_per_class(tp_p,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Pose')[2:]
|
||||
results_pose = ap_per_class(
|
||||
tp_p,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix="Pose",
|
||||
)[2:]
|
||||
self.pose.nc = len(self.names)
|
||||
self.pose.update(results_pose)
|
||||
results_box = ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Box')[2:]
|
||||
results_box = ap_per_class(
|
||||
tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix="Box",
|
||||
)[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results_box)
|
||||
|
||||
|
|
@ -1063,8 +1114,15 @@ class PoseMetrics(SegmentMetrics):
|
|||
def keys(self):
|
||||
"""Returns list of evaluation metric keys."""
|
||||
return [
|
||||
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
|
||||
'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
|
||||
"metrics/precision(B)",
|
||||
"metrics/recall(B)",
|
||||
"metrics/mAP50(B)",
|
||||
"metrics/mAP50-95(B)",
|
||||
"metrics/precision(P)",
|
||||
"metrics/recall(P)",
|
||||
"metrics/mAP50(P)",
|
||||
"metrics/mAP50-95(P)",
|
||||
]
|
||||
|
||||
def mean_results(self):
|
||||
"""Return the mean results of box and pose."""
|
||||
|
|
@ -1088,8 +1146,15 @@ class PoseMetrics(SegmentMetrics):
|
|||
def curves(self):
|
||||
"""Returns a list of curves for accessing specific metrics curves."""
|
||||
return [
|
||||
'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)',
|
||||
'Precision-Recall(P)', 'F1-Confidence(P)', 'Precision-Confidence(P)', 'Recall-Confidence(P)']
|
||||
"Precision-Recall(B)",
|
||||
"F1-Confidence(B)",
|
||||
"Precision-Confidence(B)",
|
||||
"Recall-Confidence(B)",
|
||||
"Precision-Recall(P)",
|
||||
"F1-Confidence(P)",
|
||||
"Precision-Confidence(P)",
|
||||
"Recall-Confidence(P)",
|
||||
]
|
||||
|
||||
@property
|
||||
def curves_results(self):
|
||||
|
|
@ -1119,8 +1184,8 @@ class ClassifyMetrics(SimpleClass):
|
|||
"""Initialize a ClassifyMetrics instance."""
|
||||
self.top1 = 0
|
||||
self.top5 = 0
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.task = 'classify'
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
self.task = "classify"
|
||||
|
||||
def process(self, targets, pred):
|
||||
"""Target classes and predicted classes."""
|
||||
|
|
@ -1137,12 +1202,12 @@ class ClassifyMetrics(SimpleClass):
|
|||
@property
|
||||
def results_dict(self):
|
||||
"""Returns a dictionary with model's performance metrics and fitness score."""
|
||||
return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
|
||||
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns a list of keys for the results_dict property."""
|
||||
return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
|
||||
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
||||
|
||||
@property
|
||||
def curves(self):
|
||||
|
|
@ -1156,32 +1221,33 @@ class ClassifyMetrics(SimpleClass):
|
|||
|
||||
|
||||
class OBBMetrics(SimpleClass):
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
|
||||
def process(self, tp, conf, pred_cls, target_cls):
|
||||
"""Process predicted results for object detection and update metrics."""
|
||||
results = ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)[2:]
|
||||
results = ap_per_class(
|
||||
tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
)[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns a list of keys for accessing specific metrics."""
|
||||
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
|
||||
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
||||
|
||||
def mean_results(self):
|
||||
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
||||
|
|
@ -1209,7 +1275,7 @@ class OBBMetrics(SimpleClass):
|
|||
@property
|
||||
def results_dict(self):
|
||||
"""Returns dictionary of computed performance metrics and statistics."""
|
||||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
||||
|
||||
@property
|
||||
def curves(self):
@ -52,7 +52,7 @@ class Profile(contextlib.ContextDecorator):
|
|||
|
||||
def __str__(self):
|
||||
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
|
||||
return f'Elapsed time is {self.t} s'
|
||||
return f"Elapsed time is {self.t} s"
|
||||
|
||||
def time(self):
|
||||
"""Get current time."""
|
||||
|
|
@ -76,9 +76,13 @@ def segment2box(segment, width=640, height=640):
|
|||
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
||||
x, y = segment.T # segment xy
|
||||
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
|
||||
x, y, = x[inside], y[inside]
|
||||
return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
|
||||
4, dtype=segment.dtype) # xyxy
|
||||
x = x[inside]
|
||||
y = y[inside]
|
||||
return (
|
||||
np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
|
||||
if any(x)
|
||||
else np.zeros(4, dtype=segment.dtype)
|
||||
) # xyxy
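A quick sketch of the conversion above on a three-point polygon (values are made up):

import numpy as np
# assuming `segment2box` as defined above
segment = np.array([[10.0, 20.0], [30.0, 5.0], [25.0, 40.0]])
box = segment2box(segment, width=640, height=640)  # -> array([10., 5., 30., 40.]) in xyxy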
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
|
||||
|
|
@ -101,8 +105,10 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xyw
|
|||
"""
|
||||
if ratio_pad is None: # calculate from img0_shape
|
||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||
pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
|
||||
(img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding
|
||||
pad = (
|
||||
round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
|
||||
round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
|
||||
) # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
|
|
@ -145,7 +151,7 @@ def nms_rotated(boxes, scores, threshold=0.45):
|
|||
Returns:
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return np.empty((0, ), dtype=np.int8)
|
||||
return np.empty((0,), dtype=np.int8)
|
||||
sorted_idx = torch.argsort(scores, descending=True)
|
||||
boxes = boxes[sorted_idx]
|
||||
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
|
||||
|
|
@ -199,8 +205,8 @@ def non_max_suppression(
|
|||
"""
|
||||
|
||||
# Checks
|
||||
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
||||
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
||||
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
||||
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
|
||||
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
|
||||
prediction = prediction[0] # select only inference output
|
||||
|
||||
|
|
@ -284,7 +290,7 @@ def non_max_suppression(
|
|||
|
||||
output[xi] = x[i]
|
||||
if (time.time() - t) > time_limit:
|
||||
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
||||
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
|
||||
break # time limit exceeded
|
||||
|
||||
return output
|
||||
|
|
@ -378,7 +384,7 @@ def xyxy2xywh(x):
|
|||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
||||
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
||||
|
|
@ -398,7 +404,7 @@ def xywh2xyxy(x):
|
|||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
dw = x[..., 2] / 2 # half-width
|
||||
dh = x[..., 3] / 2 # half-height
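For reference, a center-format box converts as follows, using the function above on an assumed input:

import torch
# assuming `xywh2xyxy` as defined above
xywh = torch.tensor([[50.0, 50.0, 20.0, 10.0]])  # cx, cy, w, h
xyxy = xywh2xyxy(xywh)                           # -> tensor([[40., 45., 60., 55.]])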
@ -423,7 +429,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
|||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
||||
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
||||
|
|
@ -449,7 +455,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
|||
"""
|
||||
if clip:
|
||||
x = clip_boxes(x, (h - eps, w - eps))
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
||||
|
|
@ -526,8 +532,11 @@ def xyxyxyxy2xywhr(corners):
|
|||
# especially some objects are cut off by augmentations in dataloader.
|
||||
(x, y), (w, h), angle = cv2.minAreaRect(pts)
|
||||
rboxes.append([x, y, w, h, angle / 180 * np.pi])
|
||||
rboxes = torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) if is_torch else np.asarray(
|
||||
rboxes, dtype=points.dtype)
|
||||
rboxes = (
|
||||
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
|
||||
if is_torch
|
||||
else np.asarray(rboxes, dtype=points.dtype)
|
||||
)
|
||||
return rboxes
|
||||
|
||||
|
||||
|
|
@ -546,7 +555,7 @@ def xywhr2xyxyxyxy(center):
|
|||
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
|
||||
|
||||
ctr = center[..., :2]
|
||||
w, h, angle = (center[..., i:i + 1] for i in range(2, 5))
|
||||
w, h, angle = (center[..., i : i + 1] for i in range(2, 5))
|
||||
cos_value, sin_value = cos(angle), sin(angle)
|
||||
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
|
||||
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
|
||||
|
|
@ -607,8 +616,9 @@ def resample_segments(segments, n=1000):
|
|||
s = np.concatenate((s, s[0:1, :]), axis=0)
|
||||
x = np.linspace(0, len(s) - 1, n)
|
||||
xp = np.arange(len(s))
|
||||
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
|
||||
dtype=np.float32).reshape(2, -1).T # segment xy
|
||||
segments[i] = (
|
||||
np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
|
||||
) # segment xy
|
||||
return segments
|
||||
|
||||
|
||||
|
|
@ -647,7 +657,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
|
|||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
|
@ -680,7 +690,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
|||
|
||||
masks = crop_mask(masks, downsampled_bboxes) # CHW
|
||||
if upsample:
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
|
|
@ -724,7 +734,7 @@ def scale_masks(masks, shape, padding=True):
|
|||
bottom, right = (int(round(mh - pad[1] + 0.1)), int(round(mw - pad[0] + 0.1)))
|
||||
masks = masks[..., top:bottom, left:right]
|
||||
|
||||
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
|
||||
masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
|
||||
return masks
|
||||
|
||||
|
||||
|
|
@ -763,7 +773,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
|
|||
return coords
|
||||
|
||||
|
||||
def masks2segments(masks, strategy='largest'):
|
||||
def masks2segments(masks, strategy="largest"):
|
||||
"""
|
||||
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
|
||||
|
||||
|
|
@ -775,16 +785,16 @@ def masks2segments(masks, strategy='largest'):
|
|||
segments (List): list of segment masks
|
||||
"""
|
||||
segments = []
|
||||
for x in masks.int().cpu().numpy().astype('uint8'):
|
||||
for x in masks.int().cpu().numpy().astype("uint8"):
|
||||
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
|
||||
if c:
|
||||
if strategy == 'concat': # concatenate all segments
|
||||
if strategy == "concat": # concatenate all segments
|
||||
c = np.concatenate([x.reshape(-1, 2) for x in c])
|
||||
elif strategy == 'largest': # select largest segment
|
||||
elif strategy == "largest": # select largest segment
|
||||
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
|
||||
else:
|
||||
c = np.zeros((0, 2)) # no segments found
|
||||
segments.append(c.astype('float32'))
|
||||
segments.append(c.astype("float32"))
|
||||
return segments
|
||||
|
||||
|
||||
|
|
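Illustrative sketch (not part of the commit): the two strategies handled in masks2segments above, applied to a toy binary mask with two blobs.

import cv2
import numpy as np

mask = np.zeros((32, 32), dtype="uint8")
mask[4:10, 4:10] = 1    # first blob
mask[20:30, 20:30] = 1  # second blob
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
concat = np.concatenate([c.reshape(-1, 2) for c in contours])                               # "concat" strategy
largest = np.array(contours[np.array([len(c) for c in contours]).argmax()]).reshape(-1, 2)  # "largest" strategy
print(concat.shape, largest.shape)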
@ -811,4 +821,4 @@ def clean_str(s):
|
|||
Returns:
|
||||
(str): a string with special characters replaced by an underscore _
|
||||
"""
|
||||
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
|
||||
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
@ -52,7 +52,7 @@ def imshow(winname: str, mat: np.ndarray):
|
|||
winname (str): Name of the window.
|
||||
mat (np.ndarray): Image to be shown.
|
||||
"""
|
||||
_imshow(winname.encode('unicode_escape').decode(), mat)
|
||||
_imshow(winname.encode("unicode_escape").decode(), mat)
|
||||
|
||||
|
||||
# PyTorch functions ----------------------------------------------------------------------------------------------------
|
||||
|
|
@ -72,6 +72,6 @@ def torch_save(*args, **kwargs):
|
|||
except ImportError:
|
||||
import pickle
|
||||
|
||||
if 'pickle_module' not in kwargs:
|
||||
kwargs['pickle_module'] = pickle # noqa
|
||||
if "pickle_module" not in kwargs:
|
||||
kwargs["pickle_module"] = pickle # noqa
|
||||
return _torch_save(*args, **kwargs)
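Illustrative sketch (not part of the commit): the pickle_module fallback pattern used by torch_save above, written as a standalone function that calls torch.save directly (the real module wraps the original torch.save it captured as _torch_save; the checkpoint path here is hypothetical).

import torch

def torch_save(*args, **kwargs):
    """Save with dill if it is installed, so lambdas inside checkpoints can be pickled."""
    try:
        import dill as pickle  # noqa
    except ImportError:
        import pickle
    if "pickle_module" not in kwargs:
        kwargs["pickle_module"] = pickle
    return torch.save(*args, **kwargs)

torch_save({"step": 1}, "example_checkpoint.pt")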
@ -33,15 +33,55 @@ class Colors:
|
|||
|
||||
def __init__(self):
|
||||
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
|
||||
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
||||
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
||||
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
||||
hexs = (
|
||||
"FF3838",
|
||||
"FF9D97",
|
||||
"FF701F",
|
||||
"FFB21D",
|
||||
"CFD231",
|
||||
"48F90A",
|
||||
"92CC17",
|
||||
"3DDB86",
|
||||
"1A9334",
|
||||
"00D4BB",
|
||||
"2C99A8",
|
||||
"00C2FF",
|
||||
"344593",
|
||||
"6473FF",
|
||||
"0018EC",
|
||||
"8438FF",
|
||||
"520085",
|
||||
"CB38FF",
|
||||
"FF95C8",
|
||||
"FF37C7",
|
||||
)
|
||||
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
|
||||
self.n = len(self.palette)
|
||||
self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
|
||||
[153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
|
||||
[255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
|
||||
[51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
|
||||
dtype=np.uint8)
|
||||
self.pose_palette = np.array(
|
||||
[
|
||||
[255, 128, 0],
|
||||
[255, 153, 51],
|
||||
[255, 178, 102],
|
||||
[230, 230, 0],
|
||||
[255, 153, 255],
|
||||
[153, 204, 255],
|
||||
[255, 102, 255],
|
||||
[255, 51, 255],
|
||||
[102, 178, 255],
|
||||
[51, 153, 255],
|
||||
[255, 153, 153],
|
||||
[255, 102, 102],
|
||||
[255, 51, 51],
|
||||
[153, 255, 153],
|
||||
[102, 255, 102],
|
||||
[51, 255, 51],
|
||||
[0, 255, 0],
|
||||
[0, 0, 255],
|
||||
[255, 0, 0],
|
||||
[255, 255, 255],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
def __call__(self, i, bgr=False):
|
||||
"""Converts hex color codes to RGB values."""
|
||||
|
|
@ -51,7 +91,7 @@ class Colors:
|
|||
@staticmethod
|
||||
def hex2rgb(h):
|
||||
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
|
||||
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
||||
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
|
||||
|
||||
|
||||
colors = Colors() # create instance for 'from utils.plots import colors'
|
||||
|
|
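Illustrative sketch (not part of the commit): hex2rgb applied to the first palette entry above, plus a paraphrase of how the palette is cycled by index; the exact cycling and BGR return form is an assumption for illustration.

h = "#FF3838"
rgb = tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
print(rgb)  # (255, 56, 56), PIL-order RGB

palette = [rgb]                      # a one-entry stand-in for the full 20-color palette
i, bgr = 7, True                     # any class index; bgr flips the channel order for OpenCV
c = palette[int(i) % len(palette)]   # index wraps around the palette
print((c[2], c[1], c[0]) if bgr else c)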
@ -71,9 +111,9 @@ class Annotator:
|
|||
kpt_color (List[int]): Color palette for keypoints.
|
||||
"""
|
||||
|
||||
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
|
||||
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
||||
assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
|
||||
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
||||
self.pil = pil or non_ascii
|
||||
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
||||
|
|
@ -81,26 +121,45 @@ class Annotator:
|
|||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
||||
self.draw = ImageDraw.Draw(self.im)
|
||||
try:
|
||||
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
|
||||
font = check_font("Arial.Unicode.ttf" if non_ascii else font)
|
||||
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
|
||||
self.font = ImageFont.truetype(str(font), size)
|
||||
except Exception:
|
||||
self.font = ImageFont.load_default()
|
||||
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
|
||||
if check_version(pil_version, '9.2.0'):
|
||||
if check_version(pil_version, "9.2.0"):
|
||||
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
|
||||
else: # use cv2
|
||||
self.im = im if im.flags.writeable else im.copy()
|
||||
self.tf = max(self.lw - 1, 1) # font thickness
|
||||
self.sf = self.lw / 3 # font scale
|
||||
# Pose
|
||||
self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
|
||||
[8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
|
||||
self.skeleton = [
|
||||
[16, 14],
|
||||
[14, 12],
|
||||
[17, 15],
|
||||
[15, 13],
|
||||
[12, 13],
|
||||
[6, 12],
|
||||
[7, 13],
|
||||
[6, 7],
|
||||
[6, 8],
|
||||
[7, 9],
|
||||
[8, 10],
|
||||
[9, 11],
|
||||
[2, 3],
|
||||
[1, 2],
|
||||
[1, 3],
|
||||
[2, 4],
|
||||
[3, 5],
|
||||
[4, 6],
|
||||
[5, 7],
|
||||
]
|
||||
|
||||
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
|
||||
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
|
||||
|
||||
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
|
||||
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
|
||||
"""Add one xyxy box to image with label."""
|
||||
if isinstance(box, torch.Tensor):
|
||||
box = box.tolist()
|
||||
|
|
@ -134,13 +193,16 @@ class Annotator:
|
|||
outside = p1[1] - h >= 3
|
||||
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
||||
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
||||
cv2.putText(self.im,
|
||||
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
||||
0,
|
||||
self.sf,
|
||||
txt_color,
|
||||
thickness=self.tf,
|
||||
lineType=cv2.LINE_AA)
|
||||
cv2.putText(
|
||||
self.im,
|
||||
label,
|
||||
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
||||
0,
|
||||
self.sf,
|
||||
txt_color,
|
||||
thickness=self.tf,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
|
||||
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
||||
"""
|
||||
|
|
@ -171,7 +233,7 @@ class Annotator:
|
|||
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
||||
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
||||
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
|
||||
im_mask = (im_gpu * 255)
|
||||
im_mask = im_gpu * 255
|
||||
im_mask_np = im_mask.byte().cpu().numpy()
|
||||
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
|
||||
if self.pil:
|
||||
|
|
@ -230,9 +292,9 @@ class Annotator:
|
|||
"""Add rectangle to image (PIL-only)."""
|
||||
self.draw.rectangle(xy, fill, outline, width)
|
||||
|
||||
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
|
||||
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
|
||||
"""Adds text to an image using PIL or cv2."""
|
||||
if anchor == 'bottom': # start y from font bottom
|
||||
if anchor == "bottom": # start y from font bottom
|
||||
w, h = self.font.getsize(text) # text width, height
|
||||
xy[1] += 1 - h
|
||||
if self.pil:
|
||||
|
|
@ -241,8 +303,8 @@ class Annotator:
|
|||
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
|
||||
# Using `txt_color` for background and draw fg with white color
|
||||
txt_color = (255, 255, 255)
|
||||
if '\n' in text:
|
||||
lines = text.split('\n')
|
||||
if "\n" in text:
|
||||
lines = text.split("\n")
|
||||
_, h = self.font.getsize(text)
|
||||
for line in lines:
|
||||
self.draw.text(xy, line, fill=txt_color, font=self.font)
|
||||
|
|
@ -314,15 +376,12 @@ class Annotator:
|
|||
text_y = t_size_in[1]
|
||||
|
||||
# Create a rounded rectangle for in_count
|
||||
cv2.rectangle(self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color,
|
||||
-1)
|
||||
cv2.putText(self.im,
|
||||
str(counts), (text_x, text_y + t_size_in[1]),
|
||||
0,
|
||||
tl / 2,
|
||||
txt_color,
|
||||
self.tf,
|
||||
lineType=cv2.LINE_AA)
|
||||
cv2.rectangle(
|
||||
self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
|
||||
)
|
||||
cv2.putText(
|
||||
self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def estimate_pose_angle(a, b, c):
|
||||
|
|
@ -375,7 +434,7 @@ class Annotator:
|
|||
center_kpt (int): centroid pose index for workout monitoring
|
||||
line_thickness (int): thickness for text display
|
||||
"""
|
||||
angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}')
|
||||
angle_text, count_text, stage_text = (f" {angle_text:.2f}", "Steps : " + f"{count_text}", f" {stage_text}")
|
||||
font_scale = 0.6 + (line_thickness / 10.0)
|
||||
|
||||
# Draw angle
|
||||
|
|
@ -383,21 +442,37 @@ class Annotator:
|
|||
angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
|
||||
angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
|
||||
angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
|
||||
cv2.rectangle(self.im, angle_background_position, (angle_background_position[0] + angle_background_size[0],
|
||||
angle_background_position[1] + angle_background_size[1]),
|
||||
(255, 255, 255), -1)
|
||||
cv2.rectangle(
|
||||
self.im,
|
||||
angle_background_position,
|
||||
(
|
||||
angle_background_position[0] + angle_background_size[0],
|
||||
angle_background_position[1] + angle_background_size[1],
|
||||
),
|
||||
(255, 255, 255),
|
||||
-1,
|
||||
)
|
||||
cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)
|
||||
|
||||
# Draw Counts
|
||||
(count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
|
||||
count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
|
||||
count_background_position = (angle_background_position[0],
|
||||
angle_background_position[1] + angle_background_size[1] + 5)
|
||||
count_background_position = (
|
||||
angle_background_position[0],
|
||||
angle_background_position[1] + angle_background_size[1] + 5,
|
||||
)
|
||||
count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2))
|
||||
|
||||
cv2.rectangle(self.im, count_background_position, (count_background_position[0] + count_background_size[0],
|
||||
count_background_position[1] + count_background_size[1]),
|
||||
(255, 255, 255), -1)
|
||||
cv2.rectangle(
|
||||
self.im,
|
||||
count_background_position,
|
||||
(
|
||||
count_background_position[0] + count_background_size[0],
|
||||
count_background_position[1] + count_background_size[1],
|
||||
),
|
||||
(255, 255, 255),
|
||||
-1,
|
||||
)
|
||||
cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)
|
||||
|
||||
# Draw Stage
|
||||
|
|
@ -406,9 +481,16 @@ class Annotator:
|
|||
stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
|
||||
stage_background_size = (stage_text_width + 10, stage_text_height + 10)
|
||||
|
||||
cv2.rectangle(self.im, stage_background_position, (stage_background_position[0] + stage_background_size[0],
|
||||
stage_background_position[1] + stage_background_size[1]),
|
||||
(255, 255, 255), -1)
|
||||
cv2.rectangle(
|
||||
self.im,
|
||||
stage_background_position,
|
||||
(
|
||||
stage_background_position[0] + stage_background_size[0],
|
||||
stage_background_position[1] + stage_background_size[1],
|
||||
),
|
||||
(255, 255, 255),
|
||||
-1,
|
||||
)
|
||||
cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
|
||||
|
||||
def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
|
||||
|
|
@ -423,14 +505,20 @@ class Annotator:
|
|||
"""
|
||||
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
|
||||
|
||||
label = f'Track ID: {track_label}' if track_label else det_label
|
||||
label = f"Track ID: {track_label}" if track_label else det_label
|
||||
text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
|
||||
|
||||
cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
|
||||
(int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1)
|
||||
cv2.rectangle(
|
||||
self.im,
|
||||
(int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
|
||||
(int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
|
||||
mask_color,
|
||||
-1,
|
||||
)
|
||||
|
||||
cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255),
|
||||
2)
|
||||
cv2.putText(
|
||||
self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
|
||||
)
|
||||
|
||||
def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
|
||||
"""
|
||||
|
|
@ -452,24 +540,24 @@ class Annotator:
|
|||
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
@plt_settings()
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
||||
"""Plot training labels including class histograms and box statistics."""
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
|
||||
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
# Plot dataset labels
|
||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||
nc = int(cls.max() + 1) # number of classes
|
||||
boxes = boxes[:1000000] # limit to 1M boxes
|
||||
x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
|
||||
x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
|
||||
|
||||
# Seaborn correlogram
|
||||
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
||||
sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
|
||||
plt.close()
|
||||
|
||||
# Matplotlib labels
|
||||
|
|
@ -477,14 +565,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
|
|||
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
||||
for i in range(nc):
|
||||
y[2].patches[i].set_color([x / 255 for x in colors(i)])
|
||||
ax[0].set_ylabel('instances')
|
||||
ax[0].set_ylabel("instances")
|
||||
if 0 < len(names) < 30:
|
||||
ax[0].set_xticks(range(len(names)))
|
||||
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
||||
else:
|
||||
ax[0].set_xlabel('classes')
|
||||
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
||||
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
||||
ax[0].set_xlabel("classes")
|
||||
sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
|
||||
sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
|
||||
|
||||
# Rectangles
|
||||
boxes[:, 0:2] = 0.5 # center
|
||||
|
|
@ -493,20 +581,20 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
|
|||
for cls, box in zip(cls[:500], boxes[:500]):
|
||||
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
||||
ax[1].imshow(img)
|
||||
ax[1].axis('off')
|
||||
ax[1].axis("off")
|
||||
|
||||
for a in [0, 1, 2, 3]:
|
||||
for s in ['top', 'right', 'left', 'bottom']:
|
||||
for s in ["top", "right", "left", "bottom"]:
|
||||
ax[a].spines[s].set_visible(False)
|
||||
|
||||
fname = save_dir / 'labels.jpg'
|
||||
fname = save_dir / "labels.jpg"
|
||||
plt.savefig(fname, dpi=200)
|
||||
plt.close()
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
"""
|
||||
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
||||
|
||||
|
|
@ -545,29 +633,31 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
|
|||
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
||||
xyxy = ops.xywh2xyxy(b).long()
|
||||
xyxy = ops.clip_boxes(xyxy, im.shape)
|
||||
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
||||
crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
|
||||
if save:
|
||||
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
||||
f = str(increment_path(file).with_suffix('.jpg'))
|
||||
f = str(increment_path(file).with_suffix(".jpg"))
|
||||
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
||||
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
||||
return crop
|
||||
|
||||
|
||||
@threaded
|
||||
def plot_images(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes=np.zeros(0, dtype=np.float32),
|
||||
confs=None,
|
||||
masks=np.zeros(0, dtype=np.uint8),
|
||||
kpts=np.zeros((0, 51), dtype=np.float32),
|
||||
paths=None,
|
||||
fname='images.jpg',
|
||||
names=None,
|
||||
on_plot=None,
|
||||
max_subplots=16,
|
||||
save=True):
|
||||
def plot_images(
|
||||
images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes=np.zeros(0, dtype=np.float32),
|
||||
confs=None,
|
||||
masks=np.zeros(0, dtype=np.uint8),
|
||||
kpts=np.zeros((0, 51), dtype=np.float32),
|
||||
paths=None,
|
||||
fname="images.jpg",
|
||||
names=None,
|
||||
on_plot=None,
|
||||
max_subplots=16,
|
||||
save=True,
|
||||
):
|
||||
"""Plot image grid with labels."""
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = images.cpu().float().numpy()
|
||||
|
|
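Illustrative sketch (not part of the commit): the gain/pad crop expansion performed by save_one_box in the hunk above, done with plain NumPy instead of the ops.xyxy2xywh/xywh2xyxy helpers and with a made-up box.

import numpy as np

xyxy = np.array([[100.0, 120.0, 180.0, 200.0]])  # hypothetical detection box
xywh = np.concatenate([(xyxy[:, :2] + xyxy[:, 2:]) / 2, xyxy[:, 2:] - xyxy[:, :2]], 1)  # xyxy -> xywh
gain, pad = 1.02, 10
xywh[:, 2:] = xywh[:, 2:] * gain + pad  # box wh * gain + pad, as in the hunk
expanded = np.concatenate([xywh[:, :2] - xywh[:, 2:] / 2, xywh[:, :2] + xywh[:, 2:] / 2], 1)
print(expanded)  # [[ 94.2 114.2 185.8 205.8]] crop bounds before clipping to the image shape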
@ -585,7 +675,7 @@ def plot_images(images,
|
|||
max_size = 1920 # max image size
|
||||
bs, _, h, w = images.shape # batch size, _, height, width
|
||||
bs = min(bs, max_subplots) # limit plot images
|
||||
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
||||
ns = np.ceil(bs**0.5) # number of subplots (square)
|
||||
if np.max(images[0]) <= 1:
|
||||
images *= 255 # de-normalise (optional)
|
||||
|
||||
|
|
@ -593,7 +683,7 @@ def plot_images(images,
|
|||
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
||||
for i in range(bs):
|
||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||
mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
|
||||
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
|
||||
|
||||
# Resize (optional)
|
||||
scale = max_size / ns / max(h, w)
|
||||
|
|
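Illustrative sketch (not part of the commit): the mosaic block-origin arithmetic reformatted in the hunk above, on a tiny synthetic batch.

import numpy as np

bs, h, w = 4, 32, 32
images = np.random.randint(0, 255, (bs, 3, h, w), dtype=np.uint8)
ns = int(np.ceil(bs**0.5))  # number of subplots per side
mosaic = np.full((ns * h, ns * w, 3), 255, dtype=np.uint8)
for i in range(bs):
    x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin for image i
    mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
print(mosaic.shape)  # (64, 64, 3)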
@ -612,7 +702,7 @@ def plot_images(images,
|
|||
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
||||
if len(cls) > 0:
|
||||
idx = batch_idx == i
|
||||
classes = cls[idx].astype('int')
|
||||
classes = cls[idx].astype("int")
|
||||
labels = confs is None
|
||||
|
||||
if len(bboxes):
|
||||
|
|
@ -633,14 +723,14 @@ def plot_images(images,
|
|||
color = colors(c)
|
||||
c = names.get(c, c) if names else c
|
||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
||||
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
|
||||
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
|
||||
annotator.box_label(box, label, color=color, rotated=is_obb)
|
||||
|
||||
elif len(classes):
|
||||
for c in classes:
|
||||
color = colors(c)
|
||||
c = names.get(c, c) if names else c
|
||||
annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
|
||||
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
|
||||
|
||||
# Plot keypoints
|
||||
if len(kpts):
|
||||
|
|
@ -680,7 +770,9 @@ def plot_images(images,
|
|||
else:
|
||||
mask = image_masks[j].astype(bool)
|
||||
with contextlib.suppress(Exception):
|
||||
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||
im[y : y + h, x : x + w, :][mask] = (
|
||||
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||
)
|
||||
annotator.fromarray(im)
|
||||
if save:
|
||||
annotator.im.save(fname) # save
|
||||
|
|
@ -691,7 +783,7 @@ def plot_images(images,
|
|||
|
||||
|
||||
@plt_settings()
|
||||
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
|
||||
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
|
||||
"""
|
||||
Plot training results from a results CSV file. The function supports various types of data including segmentation,
|
||||
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
|
||||
|
|
@ -714,6 +806,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
|
|||
"""
|
||||
import pandas as pd
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
save_dir = Path(file).parent if file else Path(dir)
|
||||
if classify:
|
||||
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
|
||||
|
|
@ -728,32 +821,32 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
|
|||
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
||||
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
|
||||
ax = ax.ravel()
|
||||
files = list(save_dir.glob('results*.csv'))
|
||||
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
||||
files = list(save_dir.glob("results*.csv"))
|
||||
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
|
||||
for f in files:
|
||||
try:
|
||||
data = pd.read_csv(f)
|
||||
s = [x.strip() for x in data.columns]
|
||||
x = data.values[:, 0]
|
||||
for i, j in enumerate(index):
|
||||
y = data.values[:, j].astype('float')
|
||||
y = data.values[:, j].astype("float")
|
||||
# y[y == 0] = np.nan # don't show zero values
|
||||
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
|
||||
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
|
||||
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
|
||||
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
|
||||
ax[i].set_title(s[j], fontsize=12)
|
||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
|
||||
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
|
||||
ax[1].legend()
|
||||
fname = save_dir / 'results.png'
|
||||
fname = save_dir / "results.png"
|
||||
fig.savefig(fname, dpi=200)
|
||||
plt.close()
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
|
||||
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
|
||||
"""
|
||||
Plots a scatter plot with points colored based on a 2D histogram.
|
||||
|
||||
|
|
@ -774,14 +867,18 @@ def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none
|
|||
# Calculate 2D histogram and corresponding colors
|
||||
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
|
||||
colors = [
|
||||
hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
|
||||
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
|
||||
hist[
|
||||
min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
|
||||
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
|
||||
]
|
||||
for i in range(len(v))
|
||||
]
|
||||
|
||||
# Scatter plot
|
||||
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
|
||||
|
||||
|
||||
def plot_tune_results(csv_file='tune_results.csv'):
|
||||
def plot_tune_results(csv_file="tune_results.csv"):
|
||||
"""
|
||||
Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
|
||||
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
|
||||
|
|
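Illustrative sketch (not part of the commit): the histogram-based point coloring from the plt_color_scatter hunk above, run on random data.

import numpy as np

v, f = np.random.randn(200), np.random.randn(200)
hist, xedges, yedges = np.histogram2d(v, f, bins=20)
colors = [
    hist[
        min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
        min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
    ]
    for i in range(len(v))
]
print(len(colors))  # one local-density value per point, used as its scatter color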
@ -810,33 +907,33 @@ def plot_tune_results(csv_file='tune_results.csv'):
|
|||
v = x[:, i + num_metrics_columns]
|
||||
mu = v[j] # best single result
|
||||
plt.subplot(n, n, i + 1)
|
||||
plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
|
||||
plt.plot(mu, fitness.max(), 'k+', markersize=15)
|
||||
plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
|
||||
plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8
|
||||
plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
|
||||
plt.plot(mu, fitness.max(), "k+", markersize=15)
|
||||
plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
|
||||
plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
|
||||
if i % n != 0:
|
||||
plt.yticks([])
|
||||
|
||||
file = csv_file.with_name('tune_scatter_plots.png') # filename
|
||||
file = csv_file.with_name("tune_scatter_plots.png") # filename
|
||||
plt.savefig(file, dpi=200)
|
||||
plt.close()
|
||||
LOGGER.info(f'Saved {file}')
|
||||
LOGGER.info(f"Saved {file}")
|
||||
|
||||
# Fitness vs iteration
|
||||
x = range(1, len(fitness) + 1)
|
||||
plt.figure(figsize=(10, 6), tight_layout=True)
|
||||
plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
|
||||
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line
|
||||
plt.title('Fitness vs Iteration')
|
||||
plt.xlabel('Iteration')
|
||||
plt.ylabel('Fitness')
|
||||
plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
|
||||
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
|
||||
plt.title("Fitness vs Iteration")
|
||||
plt.xlabel("Iteration")
|
||||
plt.ylabel("Fitness")
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
file = csv_file.with_name('tune_fitness.png') # filename
|
||||
file = csv_file.with_name("tune_fitness.png") # filename
|
||||
plt.savefig(file, dpi=200)
|
||||
plt.close()
|
||||
LOGGER.info(f'Saved {file}')
|
||||
LOGGER.info(f"Saved {file}")
|
||||
|
||||
|
||||
def output_to_target(output, max_det=300):
|
||||
|
|
@ -861,7 +958,7 @@ def output_to_rotated_target(output, max_det=300):
|
|||
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
||||
|
||||
|
||||
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
||||
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
|
||||
"""
|
||||
Visualize feature maps of a given model module during inference.
|
||||
|
||||
|
|
@ -872,7 +969,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
|
|||
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
|
||||
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
|
||||
"""
|
||||
for m in ['Detect', 'Pose', 'Segment']:
|
||||
for m in ["Detect", "Pose", "Segment"]:
|
||||
if m in module_type:
|
||||
return
|
||||
batch, channels, height, width = x.shape # batch, channels, height, width
|
||||
|
|
@ -886,9 +983,9 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
|
|||
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||
for i in range(n):
|
||||
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
||||
ax[i].axis('off')
|
||||
ax[i].axis("off")
|
||||
|
||||
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
||||
plt.savefig(f, dpi=300, bbox_inches='tight')
|
||||
LOGGER.info(f"Saving {f}... ({n}/{channels})")
|
||||
plt.savefig(f, dpi=300, bbox_inches="tight")
|
||||
plt.close()
|
||||
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
||||
np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save
@ -7,7 +7,7 @@ from .checks import check_version
|
|||
from .metrics import bbox_iou, probiou
|
||||
from .ops import xywhr2xyxyxyxy
|
||||
|
||||
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
|
||||
TORCH_1_10 = check_version(torch.__version__, "1.10.0")
|
||||
|
||||
|
||||
class TaskAlignedAssigner(nn.Module):
|
||||
|
|
@ -61,12 +61,17 @@ class TaskAlignedAssigner(nn.Module):
|
|||
|
||||
if self.n_max_boxes == 0:
|
||||
device = gt_bboxes.device
|
||||
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
|
||||
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device))
|
||||
return (
|
||||
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
|
||||
torch.zeros_like(pd_bboxes).to(device),
|
||||
torch.zeros_like(pd_scores).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
)
|
||||
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
|
||||
mask_gt)
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(
|
||||
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
||||
)
|
||||
|
||||
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
|
||||
|
||||
|
|
@ -148,7 +153,7 @@ class TaskAlignedAssigner(nn.Module):
|
|||
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
|
||||
for k in range(self.topk):
|
||||
# Expand topk_idxs for each value of k and add 1 at the specified positions
|
||||
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
|
||||
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
|
||||
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
|
||||
# Filter invalid bboxes
|
||||
count_tensor.masked_fill_(count_tensor > 1, 0)
|
||||
|
|
@ -192,9 +197,11 @@ class TaskAlignedAssigner(nn.Module):
|
|||
target_labels.clamp_(0)
|
||||
|
||||
# 10x faster than F.one_hot()
|
||||
target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
|
||||
dtype=torch.int64,
|
||||
device=target_labels.device) # (b, h*w, 80)
|
||||
target_scores = torch.zeros(
|
||||
(target_labels.shape[0], target_labels.shape[1], self.num_classes),
|
||||
dtype=torch.int64,
|
||||
device=target_labels.device,
|
||||
) # (b, h*w, 80)
|
||||
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
|
||||
|
||||
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
|
||||
|
|
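Illustrative sketch (not part of the commit): the scatter_-based one-hot encoding noted above as roughly 10x faster than F.one_hot(), checked against F.one_hot on random labels.

import torch

b, n, num_classes = 2, 5, 80
target_labels = torch.randint(0, num_classes, (b, n))
target_scores = torch.zeros((b, n, num_classes), dtype=torch.int64, device=target_labels.device)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
assert torch.equal(target_scores, torch.nn.functional.one_hot(target_labels, num_classes))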
@ -252,7 +259,6 @@ class TaskAlignedAssigner(nn.Module):
|
|||
|
||||
|
||||
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
||||
|
||||
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
||||
"""Iou calculation for rotated bounding boxes."""
|
||||
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
|
||||
|
|
@ -295,7 +301,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
|
|||
_, _, h, w = feats[i].shape
|
||||
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
|
||||
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
|
||||
sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
|
||||
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
|
||||
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
||||
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
||||
|
|
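Illustrative sketch (not part of the commit): the anchor grid built in make_anchors above for a single feature map, assuming torch>=1.10 for the indexing="ij" form.

import torch

h, w, stride, offset = 4, 6, 8, 0.5  # hypothetical feature-map size and stride
sx = torch.arange(end=w, dtype=torch.float32) + offset  # shift x
sy = torch.arange(end=h, dtype=torch.float32) + offset  # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
anchor_points = torch.stack((sx, sy), -1).view(-1, 2)
stride_tensor = torch.full((h * w, 1), stride, dtype=torch.float32)
print(anchor_points.shape, stride_tensor.shape)  # torch.Size([24, 2]) torch.Size([24, 1])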
@ -25,11 +25,11 @@ try:
|
|||
except ImportError:
|
||||
thop = None
|
||||
|
||||
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
|
||||
TORCH_2_0 = check_version(torch.__version__, '2.0.0')
|
||||
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
|
||||
TORCHVISION_0_11 = check_version(torchvision.__version__, '0.11.0')
|
||||
TORCHVISION_0_13 = check_version(torchvision.__version__, '0.13.0')
|
||||
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
||||
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
||||
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
|
||||
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
|
||||
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -60,13 +60,13 @@ def get_cpu_info():
|
|||
"""Return a string with system CPU information, i.e. 'Apple M2'."""
|
||||
import cpuinfo # pip install py-cpuinfo
|
||||
|
||||
k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available)
|
||||
k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
|
||||
info = cpuinfo.get_cpu_info() # info dict
|
||||
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown')
|
||||
return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '')
|
||||
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
|
||||
return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
|
||||
|
||||
|
||||
def select_device(device='', batch=0, newline=False, verbose=True):
|
||||
def select_device(device="", batch=0, newline=False, verbose=True):
|
||||
"""
|
||||
Selects the appropriate PyTorch device based on the provided arguments.
|
||||
|
||||
|
|
@ -103,49 +103,57 @@ def select_device(device='', batch=0, newline=False, verbose=True):
|
|||
if isinstance(device, torch.device):
|
||||
return device
|
||||
|
||||
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "
|
||||
device = str(device).lower()
|
||||
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
||||
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||
cpu = device == 'cpu'
|
||||
mps = device in ('mps', 'mps:0') # Apple Metal Performance Shaders (MPS)
|
||||
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
|
||||
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||
cpu = device == "cpu"
|
||||
mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
|
||||
if cpu or mps:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
|
||||
elif device: # non-cpu device requested
|
||||
if device == 'cuda':
|
||||
device = '0'
|
||||
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
||||
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
|
||||
if device == "cuda":
|
||||
device = "0"
|
||||
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
|
||||
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", ""))):
|
||||
LOGGER.info(s)
|
||||
install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
|
||||
'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
|
||||
raise ValueError(f"Invalid CUDA 'device={device}' requested."
|
||||
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
|
||||
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
|
||||
f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
|
||||
f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
|
||||
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
|
||||
f'{install}')
|
||||
install = (
|
||||
"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
|
||||
"CUDA devices are seen by torch.\n"
|
||||
if torch.cuda.device_count() == 0
|
||||
else ""
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid CUDA 'device={device}' requested."
|
||||
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
|
||||
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
|
||||
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
|
||||
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
|
||||
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
|
||||
f"{install}"
|
||||
)
|
||||
|
||||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
n = len(devices) # device count
|
||||
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
|
||||
space = ' ' * (len(s) + 1)
|
||||
raise ValueError(
|
||||
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
|
||||
)
|
||||
space = " " * (len(s) + 1)
|
||||
for i, d in enumerate(devices):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
||||
arg = 'cuda:0'
|
||||
arg = "cuda:0"
|
||||
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
|
||||
# Prefer MPS if available
|
||||
s += f'MPS ({get_cpu_info()})\n'
|
||||
arg = 'mps'
|
||||
s += f"MPS ({get_cpu_info()})\n"
|
||||
arg = "mps"
|
||||
else: # revert to CPU
|
||||
s += f'CPU ({get_cpu_info()})\n'
|
||||
arg = 'cpu'
|
||||
s += f"CPU ({get_cpu_info()})\n"
|
||||
arg = "cpu"
|
||||
|
||||
if verbose:
|
||||
LOGGER.info(s if newline else s.rstrip())
|
||||
|
|
@ -161,14 +169,20 @@ def time_sync():
|
|||
|
||||
def fuse_conv_and_bn(conv, bn):
|
||||
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
|
||||
fusedconv = nn.Conv2d(conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
dilation=conv.dilation,
|
||||
groups=conv.groups,
|
||||
bias=True).requires_grad_(False).to(conv.weight.device)
|
||||
fusedconv = (
|
||||
nn.Conv2d(
|
||||
conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
dilation=conv.dilation,
|
||||
groups=conv.groups,
|
||||
bias=True,
|
||||
)
|
||||
.requires_grad_(False)
|
||||
.to(conv.weight.device)
|
||||
)
|
||||
|
||||
# Prepare filters
|
||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||
|
|
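Illustrative sketch (not part of the commit): the standard Conv+BatchNorm fusion algebra behind fuse_conv_and_bn, verified numerically on small random layers (a simplification that assumes the conv has no bias, so the fused bias is only the BN term).

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)   # pretend the BN layer has seen data
bn.running_var.uniform_(0.5, 2.0)

fused = nn.Conv2d(3, 8, 3, padding=1, bias=True).requires_grad_(False)
w_conv = conv.weight.clone().view(8, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
with torch.no_grad():
    fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.shape))
    fused.bias.copy_(b_bn)

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True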
@ -185,15 +199,21 @@ def fuse_conv_and_bn(conv, bn):
|
|||
|
||||
def fuse_deconv_and_bn(deconv, bn):
|
||||
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
|
||||
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
|
||||
deconv.out_channels,
|
||||
kernel_size=deconv.kernel_size,
|
||||
stride=deconv.stride,
|
||||
padding=deconv.padding,
|
||||
output_padding=deconv.output_padding,
|
||||
dilation=deconv.dilation,
|
||||
groups=deconv.groups,
|
||||
bias=True).requires_grad_(False).to(deconv.weight.device)
|
||||
fuseddconv = (
|
||||
nn.ConvTranspose2d(
|
||||
deconv.in_channels,
|
||||
deconv.out_channels,
|
||||
kernel_size=deconv.kernel_size,
|
||||
stride=deconv.stride,
|
||||
padding=deconv.padding,
|
||||
output_padding=deconv.output_padding,
|
||||
dilation=deconv.dilation,
|
||||
groups=deconv.groups,
|
||||
bias=True,
|
||||
)
|
||||
.requires_grad_(False)
|
||||
.to(deconv.weight.device)
|
||||
)
|
||||
|
||||
# Prepare filters
|
||||
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
|
||||
|
|
@ -221,18 +241,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
|
|||
n_l = len(list(model.modules())) # number of layers
|
||||
if detailed:
|
||||
LOGGER.info(
|
||||
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
||||
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
|
||||
)
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace('module_list.', '')
|
||||
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
|
||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
|
||||
name = name.replace("module_list.", "")
|
||||
LOGGER.info(
|
||||
"%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
|
||||
% (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
|
||||
)
|
||||
|
||||
flops = get_flops(model, imgsz)
|
||||
fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
|
||||
fs = f', {flops:.1f} GFLOPs' if flops else ''
|
||||
yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
|
||||
model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
|
||||
LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
|
||||
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
|
||||
fs = f", {flops:.1f} GFLOPs" if flops else ""
|
||||
yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
|
||||
model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
|
||||
LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
|
||||
return n_l, n_p, n_g, flops
|
||||
|
||||
|
||||
|
|
@ -262,13 +285,15 @@ def model_info_for_loggers(trainer):
|
|||
"""
|
||||
if trainer.args.profile: # profile ONNX and TensorRT times
|
||||
from ultralytics.utils.benchmarks import ProfileModels
|
||||
|
||||
results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
|
||||
results.pop('model/name')
|
||||
results.pop("model/name")
|
||||
else: # only return PyTorch times from most recent validation
|
||||
results = {
|
||||
'model/parameters': get_num_params(trainer.model),
|
||||
'model/GFLOPs': round(get_flops(trainer.model), 3)}
|
||||
results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
|
||||
"model/parameters": get_num_params(trainer.model),
|
||||
"model/GFLOPs": round(get_flops(trainer.model), 3),
|
||||
}
|
||||
results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
|
||||
return results
|
||||
|
||||
|
||||
|
|
@ -284,14 +309,14 @@ def get_flops(model, imgsz=640):
|
|||
imgsz = [imgsz, imgsz] # expand if int/float
|
||||
try:
|
||||
# Use stride size for input tensor
|
||||
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
||||
stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride
|
||||
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
||||
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
|
||||
return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
|
||||
except Exception:
|
||||
# Use actual image size for input tensor (i.e. required for RTDETR models)
|
||||
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
|
||||
return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # imgsz GFLOPs
|
||||
return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
|
@ -301,11 +326,11 @@ def get_flops_with_torch_profiler(model, imgsz=640):
|
|||
if TORCH_2_0:
|
||||
model = de_parallel(model)
|
||||
p = next(model.parameters())
|
||||
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
|
||||
stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
|
||||
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
with torch.profiler.profile(with_flops=True) as prof:
|
||||
model(im)
|
||||
flops = sum(x.flops for x in prof.key_averages()) / 1E9
|
||||
flops = sum(x.flops for x in prof.key_averages()) / 1e9
|
||||
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
||||
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
|
||||
return flops
|
||||
|
|
@ -333,7 +358,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
|||
return img
|
||||
h, w = img.shape[2:]
|
||||
s = (int(h * ratio), int(w * ratio)) # new size
|
||||
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
||||
img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
|
||||
if not same_shape: # pad/crop img
|
||||
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
|
||||
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
||||
|
|
@ -349,7 +374,7 @@ def make_divisible(x, divisor):
|
|||
def copy_attr(a, b, include=(), exclude=()):
|
||||
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
|
||||
for k, v in b.__dict__.items():
|
||||
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
||||
if (len(include) and k not in include) or k.startswith("_") or k in exclude:
|
||||
continue
|
||||
else:
|
||||
setattr(a, k, v)
|
||||
|
|
@ -357,7 +382,7 @@ def copy_attr(a, b, include=(), exclude=()):
|
|||
|
||||
def get_latest_opset():
|
||||
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
|
||||
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
|
||||
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
|
||||
|
||||
|
||||
def intersect_dicts(da, db, exclude=()):
|
||||
|
|
@ -392,10 +417,10 @@ def init_seeds(seed=0, deterministic=False):
|
|||
if TORCH_2_0:
|
||||
torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
|
||||
torch.backends.cudnn.deterministic = True
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
else:
|
||||
LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
|
||||
LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
|
||||
else:
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = False
|
||||
|
|
@ -430,13 +455,13 @@ class ModelEMA:
|
|||
v += (1 - d) * msd[k].detach()
|
||||
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
|
||||
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
|
||||
"""Updates attributes and saves stripped model with optimizer removed."""
|
||||
if self.enabled:
|
||||
copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
|
||||
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
|
||||
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
||||
"""
|
||||
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
||||
|
||||
|
|
@ -456,26 +481,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
|
|||
strip_optimizer(f)
|
||||
```
|
||||
"""
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
if 'model' not in x:
|
||||
LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
|
||||
x = torch.load(f, map_location=torch.device("cpu"))
|
||||
if "model" not in x:
|
||||
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
||||
return
|
||||
|
||||
if hasattr(x['model'], 'args'):
|
||||
x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict
|
||||
args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args
|
||||
if x.get('ema'):
|
||||
x['model'] = x['ema'] # replace model with ema
|
||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||
if hasattr(x["model"], "args"):
|
||||
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
|
||||
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
|
||||
if x.get("ema"):
|
||||
x["model"] = x["ema"] # replace model with ema
|
||||
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
||||
x[k] = None
|
||||
x['epoch'] = -1
|
||||
x['model'].half() # to FP16
|
||||
for p in x['model'].parameters():
|
||||
x["epoch"] = -1
|
||||
x["model"].half() # to FP16
|
||||
for p in x["model"].parameters():
|
||||
p.requires_grad = False
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||
# x['model'].args = x['train_args']
|
||||
torch.save(x, s or f)
|
||||
mb = os.path.getsize(s or f) / 1E6 # file size
|
||||
mb = os.path.getsize(s or f) / 1e6 # file size
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
||||
|
||||
|
|
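Illustrative sketch (not part of the commit): typical strip_optimizer usage after training, mirroring the docstring example in the hunk above; the import path and run directory are assumptions.

from pathlib import Path

from ultralytics.utils.torch_utils import strip_optimizer

for f in Path("runs/detect/train/weights").rglob("*.pt"):  # hypothetical training run directory
    strip_optimizer(f)  # drops optimizer/EMA/updates keys and converts the model to FP16 in place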
@ -496,18 +521,20 @@ def profile(input, ops, n=10, device=None):
|
|||
results = []
|
||||
if not isinstance(device, torch.device):
|
||||
device = select_device(device)
|
||||
LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||
f"{'input':>24s}{'output':>24s}")
|
||||
LOGGER.info(
|
||||
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||
f"{'input':>24s}{'output':>24s}"
|
||||
)
|
||||
|
||||
for x in input if isinstance(input, list) else [input]:
|
||||
x = x.to(device)
|
||||
x.requires_grad = True
|
||||
for m in ops if isinstance(ops, list) else [ops]:
|
||||
m = m.to(device) if hasattr(m, 'to') else m # device
|
||||
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
||||
m = m.to(device) if hasattr(m, "to") else m # device
|
||||
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
||||
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
||||
try:
|
||||
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
|
||||
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
|
||||
except Exception:
|
||||
flops = 0
|
||||
|
||||
|
|
@ -521,13 +548,13 @@ def profile(input, ops, n=10, device=None):
|
|||
t[2] = time_sync()
|
||||
except Exception: # no backward method
|
||||
# print(e) # for debug
|
||||
t[2] = float('nan')
|
||||
t[2] = float("nan")
|
||||
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
||||
mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
|
||||
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
||||
LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||
LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
|
||||
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
||||
except Exception as e:
|
||||
LOGGER.info(e)
|
||||
|
|
@ -548,7 +575,7 @@ class EarlyStopping:
|
|||
"""
|
||||
self.best_fitness = 0.0 # i.e. mAP
|
||||
self.best_epoch = 0
|
||||
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
||||
self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
|
||||
self.possible_stop = False # possible stop may occur next epoch
|
||||
|
||||
def __call__(self, epoch, fitness):
|
||||
|
|
@ -572,8 +599,10 @@ class EarlyStopping:
|
|||
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
||||
stop = delta >= self.patience # stop training if patience exceeded
|
||||
if stop:
|
||||
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
|
||||
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
|
||||
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
|
||||
f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
|
||||
LOGGER.info(
|
||||
f"Stopping training early as no improvement observed in last {self.patience} epochs. "
|
||||
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
|
||||
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
|
||||
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
|
||||
)
|
||||
return stop
|
||||
|
|
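Illustrative sketch (not part of the commit): the patience bookkeeping that produces the early-stopping message above, traced with made-up fitness values.

best_fitness, best_epoch, patience = 0.0, 0, 3
for epoch, fitness in enumerate([0.2, 0.4, 0.39, 0.38, 0.37, 0.36]):
    if fitness >= best_fitness:  # >= so training survives flat early epochs
        best_epoch, best_fitness = epoch, fitness
    if (epoch - best_epoch) >= patience:  # no improvement for `patience` epochs
        print(f"stop at epoch {epoch}, best was epoch {best_epoch}")
        break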
@ -22,7 +22,7 @@ class TritonRemoteModel:
|
|||
output_names (List[str]): The names of the model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
|
||||
def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
|
||||
"""
|
||||
Initialize the TritonRemoteModel.
|
||||
|
||||
|
|
@@ -36,7 +36,7 @@ class TritonRemoteModel:
"""
if not endpoint and not scheme: # Parse all args from URL string
splits = urlsplit(url)
endpoint = splits.path.strip('/').split('/')[0]
endpoint = splits.path.strip("/").split("/")[0]
scheme = splits.scheme
url = splits.netloc
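For reference, urlsplit decomposes a full URL such as "http://localhost:8000/yolov8" into the pieces consumed above (the address and model name are illustrative):

    from urllib.parse import urlsplit

    splits = urlsplit("http://localhost:8000/yolov8")
    # splits.scheme == "http", splits.netloc == "localhost:8000", splits.path == "/yolov8"
    # so endpoint -> "yolov8", scheme -> "http", url -> "localhost:8000"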
@@ -44,26 +44,28 @@ class TritonRemoteModel:
self.url = url

# Choose the Triton client based on the communication scheme
if scheme == 'http':
if scheme == "http":
import tritonclient.http as client # noqa

self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint)
else:
import tritonclient.grpc as client # noqa

self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint, as_json=True)['config']
config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

# Sort output names alphabetically, i.e. 'output0', 'output1', etc.
config['output'] = sorted(config['output'], key=lambda x: x.get('name'))
config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

# Define model attributes
type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
self.InferRequestedOutput = client.InferRequestedOutput
self.InferInput = client.InferInput
self.input_formats = [x['data_type'] for x in config['input']]
self.input_formats = [x["data_type"] for x in config["input"]]
self.np_input_formats = [type_map[x] for x in self.input_formats]
self.input_names = [x['name'] for x in config['input']]
self.output_names = [x['name'] for x in config['output']]
self.input_names = [x["name"] for x in config["input"]]
self.output_names = [x["name"] for x in config["output"]]

def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
"""
@@ -80,7 +82,7 @@ class TritonRemoteModel:
for i, x in enumerate(inputs):
if x.dtype != self.np_input_formats[i]:
x = x.astype(self.np_input_formats[i])
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
infer_input.set_data_from_numpy(x)
infer_inputs.append(infer_input)
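A minimal usage sketch for the client above, assuming a Triton server exposes a model named "yolov8" over HTTP on localhost (both names are illustrative):

    import numpy as np

    model = TritonRemoteModel("http://localhost:8000/yolov8")  # endpoint and scheme parsed from the URL
    outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))  # list of numpy arrays, one per output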
@@ -6,12 +6,9 @@ from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS


def run_ray_tune(model,
space: dict = None,
grace_period: int = 10,
gpu_per_trial: int = None,
max_samples: int = 10,
**train_args):
def run_ray_tune(
model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
"""
Runs hyperparameter tuning using Ray Tune.
@@ -38,12 +35,12 @@ def run_ray_tune(model,
```
"""

LOGGER.info('💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune')
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}

try:
subprocess.run('pip install ray[tune]'.split(), check=True)
subprocess.run("pip install ray[tune]".split(), check=True)

import ray
from ray import tune
@@ -56,33 +53,34 @@ def run_ray_tune(model,
try:
import wandb

assert hasattr(wandb, '__version__')
assert hasattr(wandb, "__version__")
except (ImportError, AssertionError):
wandb = False

default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
'lr0': tune.uniform(1e-5, 1e-1),
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
'box': tune.uniform(0.02, 0.2), # box loss gain
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
"lr0": tune.uniform(1e-5, 1e-1),
"lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
"momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
"weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
"warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
"warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
"box": tune.uniform(0.02, 0.2), # box loss gain
"cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
"hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
"hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
"hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
"degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
"translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
"scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
"shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
"perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
"flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
"fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
"mosaic": tune.uniform(0.0, 1.0), # image mixup (probability)
"mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
"copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
}

# Put the model in ray store
task = model.task
@@ -107,35 +105,39 @@ def run_ray_tune(model,
# Get search space
if not space:
space = default_space
LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")

# Get dataset
data = train_args.get('data', TASK2DATA[task])
space['data'] = data
if 'data' not in train_args:
data = train_args.get("data", TASK2DATA[task])
space["data"] = data
if "data" not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')

# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=TASK2METRIC[task],
mode='max',
max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
grace_period=grace_period,
reduction_factor=3)
asha_scheduler = ASHAScheduler(
time_attr="epoch",
metric=TASK2METRIC[task],
mode="max",
max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
grace_period=grace_period,
reduction_factor=3,
)

# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

# Create the Ray Tune hyperparameter search tuner
tune_dir = get_save_dir(DEFAULT_CFG, name='tune').resolve() # must be absolute dir
tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir
tune_dir.mkdir(parents=True, exist_ok=True)
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir))
tuner = tune.Tuner(
trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
)

# Run the hyperparameter search
tuner.fit()
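A hedged end-to-end sketch of calling this tuner (the import path, dataset, and override values are assumptions for illustration):

    from ray import tune

    from ultralytics import YOLO
    from ultralytics.utils.tuner import run_ray_tune  # assumed import path

    model = YOLO("yolov8n.pt")
    run_ray_tune(
        model,
        space={"lr0": tune.uniform(1e-5, 1e-1)},  # optional; the default space is used when omitted
        grace_period=10,
        gpu_per_trial=1,  # assumes one GPU is available
        max_samples=10,
        epochs=50,  # forwarded to training via **train_args
        data="coco8.yaml",  # forwarded to training via **train_args
    )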