ultralytics 8.2.15 get_latest_opset() compat for torch<1.13.0 (#12652)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
IMhx233 2024-05-14 01:28:57 +08:00 committed by GitHub
parent b3dfbdda02
commit 8d17af7e32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions

View file

@ -399,8 +399,13 @@ def copy_attr(a, b, include=(), exclude=()):
def get_latest_opset():
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
"""Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
if TORCH_1_13:
# If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
# Otherwise for PyTorch<=1.12 return the corresponding predefined opset
version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3'
return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
def intersect_dicts(da, db, exclude=()):