Threadpool fixes and CLI improvements (#550)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher 2023-01-22 17:08:08 +01:00 committed by GitHub
parent d9a0fba251
commit 21b701c4ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 338 additions and 251 deletions

View file

@ -126,15 +126,15 @@ class Exporter:
save_dir (Path): Directory to save results.
"""
def __init__(self, config=DEFAULT_CFG, overrides=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
"""
Initializes the Exporter class.
Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(config, overrides)
self.args = get_cfg(cfg, overrides)
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
callbacks.add_integration_callbacks(self)
@ -151,7 +151,7 @@ class Exporter:
# Load PyTorch model
self.device = select_device('cpu' if self.args.device is None else self.args.device)
if self.args.half:
if self.device.type == 'cpu' and not coreml:
if self.device.type == 'cpu' and not coreml and not xml:
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
self.args.half = False
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
@ -184,7 +184,7 @@ class Exporter:
y = None
for _ in range(2):
y = model(im) # dry runs
if self.args.half and not coreml:
if self.args.half and not coreml and not xml:
im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
LOGGER.info(
@ -332,7 +332,7 @@ class Exporter:
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
f_onnx = self.file.with_suffix('.onnx')
cmd = f"mo --input_model {f_onnx} --output_dir {f} --data_type {'FP16' if self.args.half else 'FP32'}"
cmd = f"mo --input_model {f_onnx} --output_dir {f} {'--compress_to_fp16' * self.args.half}"
subprocess.run(cmd.split(), check=True, env=os.environ) # export
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
return f, None