ultralytics 8.1.43 40% faster ultralytics imports (#9547)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
99c61d6f7b
commit
a2628657a1
21 changed files with 240 additions and 225 deletions
|
|
@ -1,6 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import logging.config
|
||||
import os
|
||||
|
|
@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
|
|||
LOGGING_NAME = "ultralytics"
|
||||
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
||||
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
||||
PYTHON_VERSION = platform.python_version()
|
||||
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
||||
HELP_MSG = """
|
||||
Usage examples for running YOLOv8:
|
||||
|
||||
|
|
@ -476,7 +479,7 @@ def is_online() -> bool:
|
|||
|
||||
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
||||
try:
|
||||
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
||||
test_connection = socket.create_connection(address=(host, 80), timeout=2)
|
||||
except (socket.timeout, socket.gaierror, OSError):
|
||||
continue
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -69,8 +69,7 @@ def benchmark(
|
|||
benchmark(model='yolov8n.pt', imgsz=640)
|
||||
```
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ try:
|
|||
assert SETTINGS["clearml"] is True # verify integration is enabled
|
||||
import clearml
|
||||
from clearml import Task
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
assert hasattr(clearml, "__version__") # verify package is not directory
|
||||
|
||||
|
|
@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
|
|||
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
||||
try:
|
||||
if task := Task.current_task():
|
||||
# Make sure the automatic pytorch and matplotlib bindings are disabled!
|
||||
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
|
||||
# We are logging these plots and model files manually in the integration
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
PatchPyTorchModelIO.update_current_task(None)
|
||||
PatchedMatplotlib.update_current_task(None)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -9,10 +9,6 @@ try:
|
|||
import wandb as wb
|
||||
|
||||
assert hasattr(wb, "__version__") # verify package is not directory
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
_processed_plots = {}
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
|
|
@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
|
|||
Returns:
|
||||
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
|
||||
"""
|
||||
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
||||
fields = {"x": "x", "y": "y", "class": "class"}
|
||||
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
|
||||
return wb.plot_table(
|
||||
|
|
@ -77,6 +75,8 @@ def _plot_curve(
|
|||
Note:
|
||||
The function leverages the '_custom_table' function to generate the actual visualization.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Create new x
|
||||
if names is None:
|
||||
names = []
|
||||
|
|
|
|||
|
|
@ -18,15 +18,16 @@ import cv2
|
|||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from matplotlib import font_manager
|
||||
|
||||
from ultralytics.utils import (
|
||||
ASSETS,
|
||||
AUTOINSTALL,
|
||||
LINUX,
|
||||
LOGGER,
|
||||
PYTHON_VERSION,
|
||||
ONLINE,
|
||||
ROOT,
|
||||
TORCHVISION_VERSION,
|
||||
USER_CONFIG_DIR,
|
||||
Retry,
|
||||
SimpleNamespace,
|
||||
|
|
@ -41,13 +42,10 @@ from ultralytics.utils import (
|
|||
is_github_action_running,
|
||||
is_jupyter,
|
||||
is_kaggle,
|
||||
is_online,
|
||||
is_pip_package,
|
||||
url2file,
|
||||
)
|
||||
|
||||
PYTHON_VERSION = platform.python_version()
|
||||
|
||||
|
||||
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
||||
"""
|
||||
|
|
@ -304,9 +302,10 @@ def check_font(font="Arial.ttf"):
|
|||
Returns:
|
||||
file (Path): Resolved font file path.
|
||||
"""
|
||||
name = Path(font).name
|
||||
from matplotlib import font_manager
|
||||
|
||||
# Check USER_CONFIG_DIR
|
||||
name = Path(font).name
|
||||
file = USER_CONFIG_DIR / name
|
||||
if file.exists():
|
||||
return file
|
||||
|
|
@ -390,7 +389,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
||||
try:
|
||||
t = time.time()
|
||||
assert is_online(), "AutoUpdate skipped (offline)"
|
||||
assert ONLINE, "AutoUpdate skipped (offline)"
|
||||
with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
|
||||
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
|
||||
dt = time.time() - t
|
||||
|
|
@ -419,14 +418,12 @@ def check_torchvision():
|
|||
Torchvision versions.
|
||||
"""
|
||||
|
||||
import torchvision
|
||||
|
||||
# Compatibility table
|
||||
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
|
||||
|
||||
# Extract only the major and minor versions
|
||||
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
|
||||
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
|
||||
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
|
||||
|
||||
if v_torch in compatibility_table:
|
||||
compatible_versions = compatibility_table[v_torch]
|
||||
|
|
|
|||
|
|
@ -395,19 +395,19 @@ class ConfusionMatrix:
|
|||
names (tuple): Names of classes, used as labels on the plot.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
"""
|
||||
import seaborn as sn
|
||||
import seaborn # scope for faster 'import ultralytics'
|
||||
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
|
||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||
nc, nn = self.nc, len(names) # number of classes, names
|
||||
sn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||
ticklabels = (list(names) + ["background"]) if labels else "auto"
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(
|
||||
seaborn.heatmap(
|
||||
array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import cv2
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.metrics import batch_probiou
|
||||
|
|
@ -206,6 +205,7 @@ def non_max_suppression(
|
|||
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
||||
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
# Checks
|
||||
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
||||
|
|
|
|||
|
|
@ -671,8 +671,8 @@ class Annotator:
|
|||
@plt_settings()
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
||||
"""Plot training labels including class histograms and box statistics."""
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
import seaborn # scope for faster 'import ultralytics'
|
||||
|
||||
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
|
||||
|
|
@ -682,10 +682,10 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||
nc = int(cls.max() + 1) # number of classes
|
||||
boxes = boxes[:1000000] # limit to 1M boxes
|
||||
x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
|
||||
x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
|
||||
|
||||
# Seaborn correlogram
|
||||
sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
|
||||
plt.close()
|
||||
|
||||
|
|
@ -700,8 +700,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|||
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
||||
else:
|
||||
ax[0].set_xlabel("classes")
|
||||
sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
|
||||
sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
|
||||
seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
|
||||
seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
|
||||
|
||||
# Rectangles
|
||||
boxes[:, 0:2] = 0.5 # center
|
||||
|
|
@ -933,7 +933,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|||
plot_results('path/to/results.csv', segment=True)
|
||||
```
|
||||
"""
|
||||
import pandas as pd
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
save_dir = Path(file).parent if file else Path(dir)
|
||||
|
|
@ -1019,7 +1019,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|||
>>> plot_tune_results('path/to/tune_results.csv')
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
# Scatter plots for each hyperparameter
|
||||
|
|
|
|||
|
|
@ -14,10 +14,17 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, __version__
|
||||
from ultralytics.utils.checks import PYTHON_VERSION, check_version
|
||||
from ultralytics.utils import (
|
||||
DEFAULT_CFG_DICT,
|
||||
DEFAULT_CFG_KEYS,
|
||||
LOGGER,
|
||||
PYTHON_VERSION,
|
||||
TORCHVISION_VERSION,
|
||||
colorstr,
|
||||
__version__,
|
||||
)
|
||||
from ultralytics.utils.checks import check_version
|
||||
|
||||
try:
|
||||
import thop
|
||||
|
|
@ -28,9 +35,9 @@ except ImportError:
|
|||
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
||||
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
|
||||
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
||||
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
|
||||
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
|
||||
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
|
||||
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
|
||||
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
|
||||
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue