ultralytics 8.0.37 add TFLite metadata in AutoBackend (#953)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Aarni Koskela <akx@iki.fi>
This commit is contained in:
parent
20fe708f31
commit
bdc6cd4d8b
18 changed files with 86 additions and 46 deletions
|
|
@ -1,5 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
|
@ -427,6 +428,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
||||
m = eval(m) if isinstance(m, str) else m # eval strings
|
||||
for j, a in enumerate(args):
|
||||
# TODO: re-implement with eval() removal if possible
|
||||
# args[j] = (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else a
|
||||
with contextlib.suppress(NameError):
|
||||
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
||||
|
||||
|
|
@ -480,28 +483,9 @@ def guess_model_task(model):
|
|||
Raises:
|
||||
SyntaxError: If the task of the model could not be determined.
|
||||
"""
|
||||
cfg = None
|
||||
if isinstance(model, dict):
|
||||
cfg = model
|
||||
elif isinstance(model, nn.Module): # PyTorch model
|
||||
for x in 'model.args', 'model.model.args', 'model.model.model.args':
|
||||
with contextlib.suppress(Exception):
|
||||
return eval(x)['task']
|
||||
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
||||
with contextlib.suppress(Exception):
|
||||
cfg = eval(x)
|
||||
break
|
||||
elif isinstance(model, (str, Path)):
|
||||
model = str(model)
|
||||
if '-seg' in model:
|
||||
return "segment"
|
||||
elif '-cls' in model:
|
||||
return "classify"
|
||||
else:
|
||||
return "detect"
|
||||
|
||||
# Guess from YAML dictionary
|
||||
if cfg:
|
||||
def cfg2task(cfg):
|
||||
# Guess from YAML dictionary
|
||||
m = cfg["head"][-1][-2].lower() # output module name
|
||||
if m in ["classify", "classifier", "cls", "fc"]:
|
||||
return "classify"
|
||||
|
|
@ -510,8 +494,20 @@ def guess_model_task(model):
|
|||
if m in ["segment"]:
|
||||
return "segment"
|
||||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
with contextlib.suppress(Exception):
|
||||
return cfg2task(model)
|
||||
|
||||
# Guess from PyTorch model
|
||||
if isinstance(model, nn.Module):
|
||||
if isinstance(model, nn.Module): # PyTorch model
|
||||
for x in 'model.args', 'model.model.args', 'model.model.model.args':
|
||||
with contextlib.suppress(Exception):
|
||||
return eval(x)['task']
|
||||
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
||||
with contextlib.suppress(Exception):
|
||||
return cfg2task(eval(x))
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, Detect):
|
||||
return "detect"
|
||||
|
|
@ -520,6 +516,16 @@ def guess_model_task(model):
|
|||
elif isinstance(m, Classify):
|
||||
return "classify"
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
model = Path(model).stem
|
||||
if '-seg' in model:
|
||||
return "segment"
|
||||
elif '-cls' in model:
|
||||
return "classify"
|
||||
else:
|
||||
return "detect"
|
||||
|
||||
# Unable to determine task from model
|
||||
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
|
||||
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue