Restrict ONNX ExecutionProviders (#18400)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
parent 9935b45377
commit c9a48b281e
6 changed files with 23 additions and 20 deletions
@@ -192,14 +192,14 @@ class AutoBackend(nn.Module):
                 check_requirements("numpy==1.23.5")
             import onnxruntime
 
-            providers = onnxruntime.get_available_providers()
-            if not cuda and "CUDAExecutionProvider" in providers:
-                providers.remove("CUDAExecutionProvider")
-            elif cuda and "CUDAExecutionProvider" not in providers:
-                LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. Falling back to CPU...")
+            providers = ["CPUExecutionProvider"]
+            if cuda and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
+                providers.insert(0, "CUDAExecutionProvider")
+            elif cuda:  # Only log warning if CUDA was requested but unavailable
+                LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. Using CPU...")
                 device = torch.device("cpu")
                 cuda = False
-            LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
+            LOGGER.info(f"Using ONNX Runtime {providers[0]}")
             if onnx:
                 session = onnxruntime.InferenceSession(w, providers=providers)
             else:
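For context, a minimal standalone sketch of the new provider-selection behavior, assuming a hypothetical model path "model.onnx" and a boolean cuda flag computed elsewhere (names here are illustrative, not the AutoBackend API): the session is always given an explicit provider list, and CUDAExecutionProvider is inserted ahead of CPUExecutionProvider only when CUDA was requested and is actually available.

import onnxruntime

def make_ort_session(w: str, cuda: bool) -> onnxruntime.InferenceSession:
    # Restrict the session to an explicit provider list instead of passing everything
    # returned by onnxruntime.get_available_providers().
    providers = ["CPUExecutionProvider"]  # CPU is always kept as the fallback provider
    if cuda and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
        providers.insert(0, "CUDAExecutionProvider")  # prefer CUDA when it is actually usable
    elif cuda:
        print("CUDA requested but CUDAExecutionProvider is unavailable; falling back to CPU")
    return onnxruntime.InferenceSession(w, providers=providers)

session = make_ort_session("model.onnx", cuda=False)  # hypothetical model path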
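To confirm what a session actually ended up with, ONNX Runtime exposes InferenceSession.get_providers(), which lists the providers registered for that session in priority order. A short usage note building on the sketch above (session is the object returned by the illustrative make_ort_session helper):

# Providers are reported in priority order; the first entry is the one ONNX Runtime prefers.
print(session.get_providers())
# Typical output: ['CUDAExecutionProvider', 'CPUExecutionProvider'] with onnxruntime-gpu and a usable GPU,
# or ['CPUExecutionProvider'] otherwise.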