Predictor support (#65)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia 2022-12-07 10:33:10 +05:30 committed by GitHub
parent 479992093c
commit e6737f1207
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 916 additions and 48 deletions

View file

@ -36,6 +36,14 @@ def torch_distributed_zero_first(local_rank: int):
dist.barrier(device_ids=[0])
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
def decorate(fn):
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
return decorate
def DDP_model(model):
# Model DDP creation with checks
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
@ -192,14 +200,6 @@ def copy_attr(a, b, include=(), exclude=()):
setattr(a, k, v)
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
def decorate(fn):
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
return decorate
def intersect_state_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}