ultralytics 8.1.43 40% faster ultralytics imports (#9547)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-04-05 15:29:09 +02:00 committed by GitHub
parent 99c61d6f7b
commit a2628657a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 240 additions and 225 deletions

View file

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import importlib.metadata
import inspect
import logging.config
import os
@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
LOGGING_NAME = "ultralytics"
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
PYTHON_VERSION = platform.python_version()
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
HELP_MSG = """
Usage examples for running YOLOv8:
@ -476,7 +479,7 @@ def is_online() -> bool:
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
try:
test_connection = socket.create_connection(address=(host, 53), timeout=2)
test_connection = socket.create_connection(address=(host, 80), timeout=2)
except (socket.timeout, socket.gaierror, OSError):
continue
else:

View file

@ -69,8 +69,7 @@ def benchmark(
benchmark(model='yolov8n.pt', imgsz=640)
```
"""
import pandas as pd
import pandas as pd # scope for faster 'import ultralytics'
pd.options.display.max_columns = 10
pd.options.display.width = 120

View file

@ -7,8 +7,6 @@ try:
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, "__version__") # verify package is not directory
@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
try:
if task := Task.current_task():
# Make sure the automatic pytorch and matplotlib bindings are disabled!
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
# We are logging these plots and model files manually in the integration
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:

View file

@ -9,10 +9,6 @@ try:
import wandb as wb
assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
_processed_plots = {}
except (ImportError, AssertionError):
@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
"""
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
import pandas # scope for faster 'import ultralytics'
df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
fields = {"x": "x", "y": "y", "class": "class"}
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
return wb.plot_table(
@ -77,6 +75,8 @@ def _plot_curve(
Note:
The function leverages the '_custom_table' function to generate the actual visualization.
"""
import numpy as np
# Create new x
if names is None:
names = []

View file

@ -18,15 +18,16 @@ import cv2
import numpy as np
import requests
import torch
from matplotlib import font_manager
from ultralytics.utils import (
ASSETS,
AUTOINSTALL,
LINUX,
LOGGER,
PYTHON_VERSION,
ONLINE,
ROOT,
TORCHVISION_VERSION,
USER_CONFIG_DIR,
Retry,
SimpleNamespace,
@ -41,13 +42,10 @@ from ultralytics.utils import (
is_github_action_running,
is_jupyter,
is_kaggle,
is_online,
is_pip_package,
url2file,
)
PYTHON_VERSION = platform.python_version()
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
"""
@ -304,9 +302,10 @@ def check_font(font="Arial.ttf"):
Returns:
file (Path): Resolved font file path.
"""
name = Path(font).name
from matplotlib import font_manager
# Check USER_CONFIG_DIR
name = Path(font).name
file = USER_CONFIG_DIR / name
if file.exists():
return file
@ -390,7 +389,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
t = time.time()
assert is_online(), "AutoUpdate skipped (offline)"
assert ONLINE, "AutoUpdate skipped (offline)"
with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
dt = time.time() - t
@ -419,14 +418,12 @@ def check_torchvision():
Torchvision versions.
"""
import torchvision
# Compatibility table
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
# Extract only the major and minor versions
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]

View file

@ -395,19 +395,19 @@ class ConfusionMatrix:
names (tuple): Names of classes, used as labels on the plot.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
"""
import seaborn as sn
import seaborn # scope for faster 'import ultralytics'
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (list(names) + ["background"]) if labels else "auto"
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(
seaborn.heatmap(
array,
ax=ax,
annot=nc < 30,

View file

@ -9,7 +9,6 @@ import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou
@ -206,6 +205,7 @@ def non_max_suppression(
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
"""
import torchvision # scope for faster 'import ultralytics'
# Checks
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"

View file

@ -671,8 +671,8 @@ class Annotator:
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
"""Plot training labels including class histograms and box statistics."""
import pandas as pd
import seaborn as sn
import pandas # scope for faster 'import ultralytics'
import seaborn # scope for faster 'import ultralytics'
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
@ -682,10 +682,10 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
nc = int(cls.max() + 1) # number of classes
boxes = boxes[:1000000] # limit to 1M boxes
x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
# Seaborn correlogram
sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
plt.close()
@ -700,8 +700,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
else:
ax[0].set_xlabel("classes")
sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
# Rectangles
boxes[:, 0:2] = 0.5 # center
@ -933,7 +933,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
plot_results('path/to/results.csv', segment=True)
```
"""
import pandas as pd
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d
save_dir = Path(file).parent if file else Path(dir)
@ -1019,7 +1019,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
>>> plot_tune_results('path/to/tune_results.csv')
"""
import pandas as pd
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d
# Scatter plots for each hyperparameter

View file

@ -14,10 +14,17 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, __version__
from ultralytics.utils.checks import PYTHON_VERSION, check_version
from ultralytics.utils import (
DEFAULT_CFG_DICT,
DEFAULT_CFG_KEYS,
LOGGER,
PYTHON_VERSION,
TORCHVISION_VERSION,
colorstr,
__version__,
)
from ultralytics.utils.checks import check_version
try:
import thop
@ -28,9 +35,9 @@ except ImportError:
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
@contextmanager