Add max_dim==2 argument to check_imgsz() (#789)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
This commit is contained in:
parent
5a80ad98db
commit
0d182e80f1
11 changed files with 96 additions and 52 deletions
|
|
@ -347,23 +347,24 @@ def torch_safe_load(weight):
|
|||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
|
||||
model = Ensemble()
|
||||
ensemble = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch_safe_load(w) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
ckpt.pt_path = weights # attach *.pt file path to model
|
||||
if not hasattr(ckpt, 'stride'):
|
||||
ckpt.stride = torch.tensor([32.])
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
model.pt_path = weights # attach *.pt file path to model
|
||||
model.task = guess_model_task(model)
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
|
||||
# Append
|
||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||
ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
|
||||
|
||||
# Module compatibility updates
|
||||
for m in model.modules():
|
||||
for m in ensemble.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
||||
m.inplace = inplace # torch 1.7.0 compatibility
|
||||
|
|
@ -371,16 +372,16 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model
|
||||
if len(model) == 1:
|
||||
return model[-1]
|
||||
if len(ensemble) == 1:
|
||||
return ensemble[-1]
|
||||
|
||||
# Return ensemble
|
||||
print(f'Ensemble created with {weights}\n')
|
||||
for k in 'names', 'nc', 'yaml':
|
||||
setattr(model, k, getattr(model[0], k))
|
||||
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
||||
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
||||
return model
|
||||
setattr(ensemble, k, getattr(ensemble[0], k))
|
||||
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
||||
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts: {[m.nc for m in ensemble]}'
|
||||
return ensemble
|
||||
|
||||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
|
|
@ -392,6 +393,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||
# Model compatibility updates
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
model.pt_path = weight # attach *.pt file path to model
|
||||
model.task = guess_model_task(model)
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue