Refactor with simplifications (#19329)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
f4307339ad
commit
82b1ce44cc
4 changed files with 14 additions and 19 deletions
|
|
@ -309,9 +309,8 @@ class Exporter:
|
|||
"WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
|
||||
f"Using default 'data={self.args.data}'."
|
||||
)
|
||||
if tfjs:
|
||||
if ARM64 and LINUX:
|
||||
raise SystemError("TensorFlow.js export not supported on ARM64 Linux")
|
||||
if tfjs and (ARM64 and LINUX):
|
||||
raise SystemError("TensorFlow.js export not supported on ARM64 Linux")
|
||||
|
||||
# Input
|
||||
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
|
||||
|
|
|
|||
|
|
@ -197,12 +197,13 @@ class AutoBackend(nn.Module):
|
|||
import onnxruntime
|
||||
|
||||
providers = ["CPUExecutionProvider"]
|
||||
if cuda and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
providers.insert(0, "CUDAExecutionProvider")
|
||||
elif cuda: # Only log warning if CUDA was requested but unavailable
|
||||
LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. Using CPU...")
|
||||
device = torch.device("cpu")
|
||||
cuda = False
|
||||
if cuda:
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
providers.insert(0, "CUDAExecutionProvider")
|
||||
else: # Only log warning if CUDA was requested but unavailable
|
||||
LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. Using CPU...")
|
||||
device = torch.device("cpu")
|
||||
cuda = False
|
||||
LOGGER.info(f"Using ONNX Runtime {providers[0]}")
|
||||
if onnx:
|
||||
session = onnxruntime.InferenceSession(w, providers=providers)
|
||||
|
|
@ -223,7 +224,7 @@ class AutoBackend(nn.Module):
|
|||
output_names = [x.name for x in session.get_outputs()]
|
||||
metadata = session.get_modelmeta().custom_metadata_map
|
||||
dynamic = isinstance(session.get_outputs()[0].shape[0], str)
|
||||
fp16 = True if "float16" in session.get_inputs()[0].type else False
|
||||
fp16 = "float16" in session.get_inputs()[0].type
|
||||
if not dynamic:
|
||||
io = session.io_binding()
|
||||
bindings = []
|
||||
|
|
|
|||
|
|
@ -317,8 +317,7 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
|
|||
if len(m._parameters):
|
||||
for pn, p in m.named_parameters():
|
||||
LOGGER.info(
|
||||
f"{i:>5g}{mn + '.' + pn:>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}"
|
||||
f"{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
|
||||
f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
|
||||
)
|
||||
else: # layers with no learnable params
|
||||
LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue