ultralytics 8.3.8 replace contextlib with try for speed (#16782)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
1e6c454460
commit
a6a577961f
12 changed files with 115 additions and 88 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.3.7"
|
||||
__version__ = "8.3.8"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -669,8 +669,9 @@ def smart_value(v):
|
|||
elif v_lower == "false":
|
||||
return False
|
||||
else:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return eval(v)
|
||||
except: # noqa E722
|
||||
return v
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from itertools import repeat
|
||||
|
|
@ -483,7 +482,7 @@ class ClassificationDataset:
|
|||
desc = f"{self.prefix}Scanning {self.root}..."
|
||||
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
||||
|
||||
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
|
||||
try:
|
||||
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
|
|
@ -495,6 +494,7 @@ class ClassificationDataset:
|
|||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
return samples
|
||||
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
# Run scan if *.cache retrieval failed
|
||||
nf, nc, msgs, samples, x = 0, 0, [], [], {}
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
|
@ -60,12 +59,14 @@ def exif_size(img: Image.Image):
|
|||
"""Returns exif-corrected PIL size."""
|
||||
s = img.size # (width, height)
|
||||
if img.format == "JPEG": # only support JPEG images
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
exif = img.getexif()
|
||||
if exif:
|
||||
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
|
||||
if rotation in {6, 8}: # rotation 270 or 90
|
||||
s = s[1], s[0]
|
||||
except: # noqa E722
|
||||
pass
|
||||
return s
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import json
|
||||
import platform
|
||||
import zipfile
|
||||
|
|
@ -45,8 +44,10 @@ def check_class_names(names):
|
|||
def default_class_names(data=None):
|
||||
"""Applies default class names to an input YAML file or returns numerical class names."""
|
||||
if data:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return yaml_load(check_yaml(data))["names"]
|
||||
except: # noqa E722
|
||||
pass
|
||||
return {i: f"class{i}" for i in range(999)} # return default if above errors
|
||||
|
||||
|
||||
|
|
@ -321,8 +322,10 @@ class AutoBackend(nn.Module):
|
|||
with open(w, "rb") as f:
|
||||
gd.ParseFromString(f.read())
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||
with contextlib.suppress(StopIteration): # find metadata in SavedModel alongside GraphDef
|
||||
try: # find metadata in SavedModel alongside GraphDef
|
||||
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
# TFLite or TFLite Edge TPU
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
|
|
@ -345,10 +348,12 @@ class AutoBackend(nn.Module):
|
|||
input_details = interpreter.get_input_details() # inputs
|
||||
output_details = interpreter.get_output_details() # outputs
|
||||
# Load metadata
|
||||
with contextlib.suppress(zipfile.BadZipFile):
|
||||
try:
|
||||
with zipfile.ZipFile(w, "r") as model:
|
||||
meta_file = model.namelist()[0]
|
||||
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
||||
except zipfile.BadZipFile:
|
||||
pass
|
||||
|
||||
# TF.js
|
||||
elif tfjs:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import contextlib
|
||||
import pickle
|
||||
import re
|
||||
import types
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
|
@ -958,8 +959,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
|
||||
for j, a in enumerate(args):
|
||||
if isinstance(a, str):
|
||||
with contextlib.suppress(ValueError):
|
||||
try:
|
||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in {
|
||||
|
|
@ -1072,8 +1075,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
|
||||
def yaml_model_load(path):
|
||||
"""Load a YOLOv8 model from a YAML file."""
|
||||
import re
|
||||
|
||||
path = Path(path)
|
||||
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
|
||||
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
|
||||
|
|
@ -1100,10 +1101,9 @@ def guess_model_scale(model_path):
|
|||
Returns:
|
||||
(str): The size character of the model's scale, which can be n, s, m, l, or x.
|
||||
"""
|
||||
with contextlib.suppress(AttributeError):
|
||||
import re
|
||||
|
||||
try:
|
||||
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -1137,17 +1137,23 @@ def guess_model_task(model):
|
|||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return cfg2task(model)
|
||||
except: # noqa E722
|
||||
pass
|
||||
|
||||
# Guess from PyTorch model
|
||||
if isinstance(model, nn.Module): # PyTorch model
|
||||
for x in "model.args", "model.model.args", "model.model.model.args":
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return eval(x)["task"]
|
||||
except: # noqa E722
|
||||
pass
|
||||
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return cfg2task(eval(x))
|
||||
except: # noqa E722
|
||||
pass
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, Segment):
|
||||
|
|
|
|||
|
|
@ -523,9 +523,10 @@ def read_device_model() -> str:
|
|||
Returns:
|
||||
(str): Model file contents if read successfully or empty string otherwise.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
with open("/proc/device-tree/model") as f:
|
||||
return f.read()
|
||||
except: # noqa E722
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -536,9 +537,10 @@ def is_ubuntu() -> bool:
|
|||
Returns:
|
||||
(bool): True if OS is Ubuntu, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
try:
|
||||
with open("/etc/os-release") as f:
|
||||
return "ID=ubuntu" in f.read()
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -569,11 +571,7 @@ def is_jupyter():
|
|||
Returns:
|
||||
(bool): True if running inside a Jupyter Notebook, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
from IPython import get_ipython
|
||||
|
||||
return get_ipython() is not None
|
||||
return False
|
||||
return "get_ipython" in locals()
|
||||
|
||||
|
||||
def is_docker() -> bool:
|
||||
|
|
@ -583,9 +581,10 @@ def is_docker() -> bool:
|
|||
Returns:
|
||||
(bool): True if the script is running inside a Docker container, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
with open("/proc/self/cgroup") as f:
|
||||
return "docker" in f.read()
|
||||
except: # noqa E722
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -617,13 +616,14 @@ def is_online() -> bool:
|
|||
Returns:
|
||||
(bool): True if connection is successful, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True"
|
||||
import socket
|
||||
|
||||
for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS
|
||||
socket.create_connection(address=(dns, 80), timeout=2.0).close()
|
||||
return True
|
||||
except: # noqa E722
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -711,9 +711,11 @@ def get_git_origin_url():
|
|||
(str | None): The origin URL of the git repository or None if not git directory.
|
||||
"""
|
||||
if IS_GIT_DIR:
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
try:
|
||||
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
|
||||
return origin.decode().strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
|
|
@ -724,9 +726,11 @@ def get_git_branch():
|
|||
(str | None): The current git branch name or None if not a git directory.
|
||||
"""
|
||||
if IS_GIT_DIR:
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
try:
|
||||
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
return origin.decode().strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
def get_default_args(func):
|
||||
|
|
@ -751,9 +755,11 @@ def get_ubuntu_version():
|
|||
(str): Ubuntu version or None if not an Ubuntu OS.
|
||||
"""
|
||||
if is_ubuntu():
|
||||
with contextlib.suppress(FileNotFoundError, AttributeError):
|
||||
try:
|
||||
with open("/etc/os-release") as f:
|
||||
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
|
||||
except (FileNotFoundError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def get_user_config_dir(sub_dir="Ultralytics"):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
|
||||
|
||||
|
|
@ -45,12 +44,13 @@ def _log_tensorboard_graph(trainer):
|
|||
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
|
||||
|
||||
# Try simple method first (YOLO)
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
|
||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||
return
|
||||
|
||||
except: # noqa E722
|
||||
# Fallback to TorchScript export steps (RTDETR)
|
||||
try:
|
||||
model = deepcopy(de_parallel(trainer.model))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
|
|
@ -271,11 +270,13 @@ def check_latest_pypi_version(package_name="ultralytics"):
|
|||
Returns:
|
||||
(str): The latest version of the package.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
|
||||
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
|
||||
if response.status_code == 200:
|
||||
return response.json()["info"]["version"]
|
||||
except: # noqa E722
|
||||
return None
|
||||
|
||||
|
||||
def check_pip_update_available():
|
||||
|
|
@ -286,7 +287,7 @@ def check_pip_update_available():
|
|||
(bool): True if an update is available, False otherwise.
|
||||
"""
|
||||
if ONLINE and IS_PIP_PACKAGE:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
from ultralytics import __version__
|
||||
|
||||
latest = check_latest_pypi_version()
|
||||
|
|
@ -296,6 +297,8 @@ def check_pip_update_available():
|
|||
f"Update with 'pip install -U ultralytics'"
|
||||
)
|
||||
return True
|
||||
except: # noqa E722
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -577,10 +580,12 @@ def check_yolo(verbose=True, device=""):
|
|||
ram = psutil.virtual_memory().total
|
||||
total, used, free = shutil.disk_usage("/")
|
||||
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
|
||||
with contextlib.suppress(Exception): # clear display if ipython is installed
|
||||
try:
|
||||
from IPython import display
|
||||
|
||||
display.clear_output()
|
||||
display.clear_output() # clear display if notebook
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
s = ""
|
||||
|
||||
|
|
@ -707,8 +712,9 @@ def check_amp(model):
|
|||
|
||||
def git_describe(path=ROOT): # path must be a directory
|
||||
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
|
||||
except: # noqa E722
|
||||
return ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
|
|
@ -53,7 +52,7 @@ def is_url(url, check=False):
|
|||
valid = is_url("https://www.example.com")
|
||||
```
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
url = str(url)
|
||||
result = parse.urlparse(url)
|
||||
assert all([result.scheme, result.netloc]) # check if is url
|
||||
|
|
@ -61,6 +60,7 @@ def is_url(url, check=False):
|
|||
with request.urlopen(url) as response:
|
||||
return response.getcode() == 200 # check if exists online
|
||||
return True
|
||||
except: # noqa E722
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
|
@ -1115,10 +1114,12 @@ def plot_images(
|
|||
mask = mask.astype(bool)
|
||||
else:
|
||||
mask = image_masks[j].astype(bool)
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
im[y : y + h, x : x + w, :][mask] = (
|
||||
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||
)
|
||||
except: # noqa E722
|
||||
pass
|
||||
annotator.fromarray(im)
|
||||
if not save:
|
||||
return np.asarray(annotator.im)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
|
|
@ -113,13 +112,15 @@ def get_cpu_info():
|
|||
from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
|
||||
|
||||
if "cpu_info" not in PERSISTENT_CACHE:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
import cpuinfo # pip install py-cpuinfo
|
||||
|
||||
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
|
||||
info = cpuinfo.get_cpu_info() # info dict
|
||||
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
|
||||
PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
|
||||
except: # noqa E722
|
||||
pass
|
||||
return PERSISTENT_CACHE.get("cpu_info", "unknown")
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue