ultralytics 8.3.8 replace contextlib with try for speed (#16782)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-10-08 21:02:40 +02:00 committed by GitHub
parent 1e6c454460
commit a6a577961f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 115 additions and 88 deletions

View file

@ -2,6 +2,7 @@
import contextlib
import pickle
import re
import types
from copy import deepcopy
from pathlib import Path
@ -958,8 +959,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
try:
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
except ValueError:
pass
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
@ -1072,8 +1075,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
def yaml_model_load(path):
"""Load a YOLOv8 model from a YAML file."""
import re
path = Path(path)
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
@ -1100,11 +1101,10 @@ def guess_model_scale(model_path):
Returns:
(str): The size character of the model's scale, which can be n, s, m, l, or x.
"""
with contextlib.suppress(AttributeError):
import re
try:
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
return ""
except AttributeError:
return ""
def guess_model_task(model):
@ -1137,17 +1137,23 @@ def guess_model_task(model):
# Guess from model cfg
if isinstance(model, dict):
with contextlib.suppress(Exception):
try:
return cfg2task(model)
except: # noqa E722
pass
# Guess from PyTorch model
if isinstance(model, nn.Module): # PyTorch model
for x in "model.args", "model.model.args", "model.model.model.args":
with contextlib.suppress(Exception):
try:
return eval(x)["task"]
except: # noqa E722
pass
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
with contextlib.suppress(Exception):
try:
return cfg2task(eval(x))
except: # noqa E722
pass
for m in model.modules():
if isinstance(m, Segment):