Refactor with simplifications (#19329)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-02-20 19:43:53 +08:00 committed by GitHub
parent f4307339ad
commit 82b1ce44cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 14 additions and 19 deletions

View file

@ -197,12 +197,13 @@ class AutoBackend(nn.Module):
import onnxruntime
providers = ["CPUExecutionProvider"]
if cuda and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
providers.insert(0, "CUDAExecutionProvider")
elif cuda: # Only log warning if CUDA was requested but unavailable
LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. Using CPU...")
device = torch.device("cpu")
cuda = False
if cuda:
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
providers.insert(0, "CUDAExecutionProvider")
else: # Only log warning if CUDA was requested but unavailable
LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. Using CPU...")
device = torch.device("cpu")
cuda = False
LOGGER.info(f"Using ONNX Runtime {providers[0]}")
if onnx:
session = onnxruntime.InferenceSession(w, providers=providers)
@ -223,7 +224,7 @@ class AutoBackend(nn.Module):
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map
dynamic = isinstance(session.get_outputs()[0].shape[0], str)
fp16 = True if "float16" in session.get_inputs()[0].type else False
fp16 = "float16" in session.get_inputs()[0].type
if not dynamic:
io = session.io_binding()
bindings = []