Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Paula Derrenger 2024-11-21 23:51:37 +01:00 committed by GitHub
parent 77c3c0aaac
commit d670bcc2b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 32 additions and 44 deletions

View file

@ -960,10 +960,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
for j, a in enumerate(args):
if isinstance(a, str):
try:
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
except ValueError:
pass
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
Classify,
@ -1141,24 +1139,16 @@ def guess_model_task(model):
# Guess from model cfg
if isinstance(model, dict):
try:
with contextlib.suppress(Exception):
return cfg2task(model)
except Exception:
pass
# Guess from PyTorch model
if isinstance(model, nn.Module): # PyTorch model
for x in "model.args", "model.model.args", "model.model.model.args":
try:
with contextlib.suppress(Exception):
return eval(x)["task"]
except Exception:
pass
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
try:
with contextlib.suppress(Exception):
return cfg2task(eval(x))
except Exception:
pass
for m in model.modules():
if isinstance(m, Segment):
return "segment"