ultralytics 8.0.41 TF SavedModel and EdgeTPU export (#1034)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
4b866c9718
commit
f6e393c1d2
64 changed files with 604 additions and 351 deletions
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
|
@ -338,7 +337,7 @@ def torch_safe_load(weight):
|
|||
|
||||
file = attempt_download_asset(weight) # search online if missing locally
|
||||
try:
|
||||
return torch.load(file, map_location='cpu') # load
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name == 'omegaconf': # e.name is missing module name
|
||||
LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
|
||||
|
|
@ -347,7 +346,7 @@ def torch_safe_load(weight):
|
|||
f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
|
||||
if e.name != 'models':
|
||||
check_requirements(e.name) # install missing module
|
||||
return torch.load(file, map_location='cpu') # load
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
|
|
@ -355,13 +354,13 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
|
||||
ensemble = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch_safe_load(w) # load ckpt
|
||||
ckpt, w = torch_safe_load(w) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
model.args = args # attach args to model
|
||||
model.pt_path = weights # attach *.pt file path to model
|
||||
model.pt_path = w # attach *.pt file path to model
|
||||
model.task = guess_model_task(model)
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
|
|
@ -392,7 +391,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
# Loads a single model weights
|
||||
ckpt = torch_safe_load(weight) # load ckpt
|
||||
ckpt, weight = torch_safe_load(weight) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue