ultralytics 8.3.8 replace contextlib with try for speed (#16782)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
1e6c454460
commit
a6a577961f
12 changed files with 115 additions and 88 deletions
|
|
@ -1,7 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import json
|
||||
import platform
|
||||
import zipfile
|
||||
|
|
@ -45,8 +44,10 @@ def check_class_names(names):
|
|||
def default_class_names(data=None):
|
||||
"""Applies default class names to an input YAML file or returns numerical class names."""
|
||||
if data:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return yaml_load(check_yaml(data))["names"]
|
||||
except: # noqa E722
|
||||
pass
|
||||
return {i: f"class{i}" for i in range(999)} # return default if above errors
|
||||
|
||||
|
||||
|
|
@ -321,8 +322,10 @@ class AutoBackend(nn.Module):
|
|||
with open(w, "rb") as f:
|
||||
gd.ParseFromString(f.read())
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||
with contextlib.suppress(StopIteration): # find metadata in SavedModel alongside GraphDef
|
||||
try: # find metadata in SavedModel alongside GraphDef
|
||||
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
# TFLite or TFLite Edge TPU
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
|
|
@ -345,10 +348,12 @@ class AutoBackend(nn.Module):
|
|||
input_details = interpreter.get_input_details() # inputs
|
||||
output_details = interpreter.get_output_details() # outputs
|
||||
# Load metadata
|
||||
with contextlib.suppress(zipfile.BadZipFile):
|
||||
try:
|
||||
with zipfile.ZipFile(w, "r") as model:
|
||||
meta_file = model.namelist()[0]
|
||||
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
||||
except zipfile.BadZipFile:
|
||||
pass
|
||||
|
||||
# TF.js
|
||||
elif tfjs:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import contextlib
|
||||
import pickle
|
||||
import re
|
||||
import types
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
|
@ -958,8 +959,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
|
||||
for j, a in enumerate(args):
|
||||
if isinstance(a, str):
|
||||
with contextlib.suppress(ValueError):
|
||||
try:
|
||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in {
|
||||
|
|
@ -1072,8 +1075,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
|
||||
def yaml_model_load(path):
|
||||
"""Load a YOLOv8 model from a YAML file."""
|
||||
import re
|
||||
|
||||
path = Path(path)
|
||||
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
|
||||
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
|
||||
|
|
@ -1100,11 +1101,10 @@ def guess_model_scale(model_path):
|
|||
Returns:
|
||||
(str): The size character of the model's scale, which can be n, s, m, l, or x.
|
||||
"""
|
||||
with contextlib.suppress(AttributeError):
|
||||
import re
|
||||
|
||||
try:
|
||||
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
|
||||
return ""
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
|
||||
def guess_model_task(model):
|
||||
|
|
@ -1137,17 +1137,23 @@ def guess_model_task(model):
|
|||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return cfg2task(model)
|
||||
except: # noqa E722
|
||||
pass
|
||||
|
||||
# Guess from PyTorch model
|
||||
if isinstance(model, nn.Module): # PyTorch model
|
||||
for x in "model.args", "model.model.args", "model.model.model.args":
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return eval(x)["task"]
|
||||
except: # noqa E722
|
||||
pass
|
||||
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
return cfg2task(eval(x))
|
||||
except: # noqa E722
|
||||
pass
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, Segment):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue