ultralytics 8.3.8 replace contextlib with try for speed (#16782)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-10-08 21:02:40 +02:00 committed by GitHub
parent 1e6c454460
commit a6a577961f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 115 additions and 88 deletions

View file

@ -523,10 +523,11 @@ def read_device_model() -> str:
Returns:
(str): Model file contents if read successfully or empty string otherwise.
"""
with contextlib.suppress(Exception):
try:
with open("/proc/device-tree/model") as f:
return f.read()
return ""
except: # noqa E722
return ""
def is_ubuntu() -> bool:
@ -536,10 +537,11 @@ def is_ubuntu() -> bool:
Returns:
(bool): True if OS is Ubuntu, False otherwise.
"""
with contextlib.suppress(FileNotFoundError):
try:
with open("/etc/os-release") as f:
return "ID=ubuntu" in f.read()
return False
except FileNotFoundError:
return False
def is_colab():
@ -569,11 +571,7 @@ def is_jupyter():
Returns:
(bool): True if running inside a Jupyter Notebook, False otherwise.
"""
with contextlib.suppress(Exception):
from IPython import get_ipython
return get_ipython() is not None
return False
return "get_ipython" in locals()
def is_docker() -> bool:
@ -583,10 +581,11 @@ def is_docker() -> bool:
Returns:
(bool): True if the script is running inside a Docker container, False otherwise.
"""
with contextlib.suppress(Exception):
try:
with open("/proc/self/cgroup") as f:
return "docker" in f.read()
return False
except: # noqa E722
return False
def is_raspberrypi() -> bool:
@ -617,14 +616,15 @@ def is_online() -> bool:
Returns:
(bool): True if connection is successful, False otherwise.
"""
with contextlib.suppress(Exception):
try:
assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True"
import socket
for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS
socket.create_connection(address=(dns, 80), timeout=2.0).close()
return True
return False
except: # noqa E722
return False
def is_pip_package(filepath: str = __name__) -> bool:
@ -711,9 +711,11 @@ def get_git_origin_url():
(str | None): The origin URL of the git repository or None if not git directory.
"""
if IS_GIT_DIR:
with contextlib.suppress(subprocess.CalledProcessError):
try:
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
return origin.decode().strip()
except subprocess.CalledProcessError:
return None
def get_git_branch():
@ -724,9 +726,11 @@ def get_git_branch():
(str | None): The current git branch name or None if not a git directory.
"""
if IS_GIT_DIR:
with contextlib.suppress(subprocess.CalledProcessError):
try:
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
except subprocess.CalledProcessError:
return None
def get_default_args(func):
@ -751,9 +755,11 @@ def get_ubuntu_version():
(str): Ubuntu version or None if not an Ubuntu OS.
"""
if is_ubuntu():
with contextlib.suppress(FileNotFoundError, AttributeError):
try:
with open("/etc/os-release") as f:
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
except (FileNotFoundError, AttributeError):
return None
def get_user_config_dir(sub_dir="Ultralytics"):

View file

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
@ -45,26 +44,27 @@ def _log_tensorboard_graph(trainer):
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
try:
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
except: # noqa E722
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):

View file

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import glob
import inspect
import math
@ -271,11 +270,13 @@ def check_latest_pypi_version(package_name="ultralytics"):
Returns:
(str): The latest version of the package.
"""
with contextlib.suppress(Exception):
try:
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()["info"]["version"]
except: # noqa E722
return None
def check_pip_update_available():
@ -286,7 +287,7 @@ def check_pip_update_available():
(bool): True if an update is available, False otherwise.
"""
if ONLINE and IS_PIP_PACKAGE:
with contextlib.suppress(Exception):
try:
from ultralytics import __version__
latest = check_latest_pypi_version()
@ -296,6 +297,8 @@ def check_pip_update_available():
f"Update with 'pip install -U ultralytics'"
)
return True
except: # noqa E722
pass
return False
@ -577,10 +580,12 @@ def check_yolo(verbose=True, device=""):
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
with contextlib.suppress(Exception): # clear display if ipython is installed
try:
from IPython import display
display.clear_output()
display.clear_output() # clear display if notebook
except ImportError:
pass
else:
s = ""
@ -707,9 +712,10 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
with contextlib.suppress(Exception):
try:
return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
return ""
except: # noqa E722
return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):

View file

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import re
import shutil
import subprocess
@ -53,7 +52,7 @@ def is_url(url, check=False):
valid = is_url("https://www.example.com")
```
"""
with contextlib.suppress(Exception):
try:
url = str(url)
result = parse.urlparse(url)
assert all([result.scheme, result.netloc]) # check if is url
@ -61,7 +60,8 @@ def is_url(url, check=False):
with request.urlopen(url) as response:
return response.getcode() == 200 # check if exists online
return True
return False
except: # noqa E722
return False
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):

View file

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import math
import warnings
from pathlib import Path
@ -1115,10 +1114,12 @@ def plot_images(
mask = mask.astype(bool)
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
try:
im[y : y + h, x : x + w, :][mask] = (
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
except: # noqa E722
pass
annotator.fromarray(im)
if not save:
return np.asarray(annotator.im)

View file

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import gc
import math
import os
@ -113,13 +112,15 @@ def get_cpu_info():
from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
if "cpu_info" not in PERSISTENT_CACHE:
with contextlib.suppress(Exception):
try:
import cpuinfo # pip install py-cpuinfo
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
except: # noqa E722
pass
return PERSISTENT_CACHE.get("cpu_info", "unknown")