Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
d9a0fba251
commit
21b701c4ea
22 changed files with 338 additions and 251 deletions
|
|
@ -126,15 +126,15 @@ class Exporter:
|
|||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
"""
|
||||
Initializes the Exporter class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(config, overrides)
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ class Exporter:
|
|||
# Load PyTorch model
|
||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||
if self.args.half:
|
||||
if self.device.type == 'cpu' and not coreml:
|
||||
if self.device.type == 'cpu' and not coreml and not xml:
|
||||
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
|
|
@ -184,7 +184,7 @@ class Exporter:
|
|||
y = None
|
||||
for _ in range(2):
|
||||
y = model(im) # dry runs
|
||||
if self.args.half and not coreml:
|
||||
if self.args.half and not coreml and not xml:
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
|
||||
LOGGER.info(
|
||||
|
|
@ -332,7 +332,7 @@ class Exporter:
|
|||
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
|
||||
f_onnx = self.file.with_suffix('.onnx')
|
||||
|
||||
cmd = f"mo --input_model {f_onnx} --output_dir {f} --data_type {'FP16' if self.args.half else 'FP32'}"
|
||||
cmd = f"mo --input_model {f_onnx} --output_dir {f} {'--compress_to_fp16' * self.args.half}"
|
||||
subprocess.run(cmd.split(), check=True, env=os.environ) # export
|
||||
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue