Update .pre-commit-config.yaml (#1026)
This commit is contained in:
parent
9047d737f4
commit
edd3ff1669
76 changed files with 928 additions and 935 deletions
|
|
@ -144,7 +144,7 @@ class Exporter:
|
|||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, model=None):
|
||||
self.run_callbacks("on_export_start")
|
||||
self.run_callbacks('on_export_start')
|
||||
t = time.time()
|
||||
format = self.args.format.lower() # to lowercase
|
||||
if format in {'tensorrt', 'trt'}: # engine aliases
|
||||
|
|
@ -207,7 +207,7 @@ class Exporter:
|
|||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
||||
self.metadata = {
|
||||
'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}",
|
||||
'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
|
||||
'author': 'Ultralytics',
|
||||
'license': 'GPL-3.0 https://ultralytics.com/license',
|
||||
'version': __version__,
|
||||
|
|
@ -215,7 +215,7 @@ class Exporter:
|
|||
'names': model.names} # model metadata
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||
f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
|
||||
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
|
||||
|
||||
# Exports
|
||||
f = [''] * len(fmts) # exported filenames
|
||||
|
|
@ -259,15 +259,15 @@ class Exporter:
|
|||
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
|
||||
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
||||
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
||||
data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
|
||||
data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
|
||||
LOGGER.info(
|
||||
f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
|
||||
f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
|
||||
f"\nVisualize: https://netron.app")
|
||||
f'\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
|
||||
f'\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
|
||||
f'\nVisualize: https://netron.app')
|
||||
|
||||
self.run_callbacks("on_export_end")
|
||||
self.run_callbacks('on_export_end')
|
||||
return f # return list of exported files/dirs
|
||||
|
||||
@try_export
|
||||
|
|
@ -277,7 +277,7 @@ class Exporter:
|
|||
f = self.file.with_suffix('.torchscript')
|
||||
|
||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||
d = {"shape": self.im.shape, "stride": int(max(self.model.stride)), "names": self.model.names}
|
||||
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
|
||||
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||
LOGGER.info(f'{prefix} optimizing for mobile...')
|
||||
|
|
@ -354,7 +354,7 @@ class Exporter:
|
|||
|
||||
ov_model = mo.convert_model(f_onnx,
|
||||
model_name=self.pretty_name,
|
||||
framework="onnx",
|
||||
framework='onnx',
|
||||
compress_to_fp16=self.args.half) # export
|
||||
ov.serialize(ov_model, f_ov) # save
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
|
@ -471,7 +471,7 @@ class Exporter:
|
|||
if self.args.dynamic:
|
||||
shape = self.im.shape
|
||||
if shape[0] <= 1:
|
||||
LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
|
||||
LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
|
||||
profile = builder.create_optimization_profile()
|
||||
for inp in inputs:
|
||||
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
|
||||
|
|
@ -509,8 +509,8 @@ class Exporter:
|
|||
except ImportError:
|
||||
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
|
||||
cmds="--extra-index-url https://pypi.ngc.nvidia.com")
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
|
||||
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
||||
|
|
@ -632,7 +632,7 @@ class Exporter:
|
|||
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
||||
|
||||
tflite_model = converter.convert()
|
||||
open(f, "wb").write(tflite_model)
|
||||
open(f, 'wb').write(tflite_model)
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
|
|
@ -656,7 +656,7 @@ class Exporter:
|
|||
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
||||
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
||||
|
||||
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
|
||||
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
self._add_tflite_metadata(f)
|
||||
return f, None
|
||||
|
|
@ -707,8 +707,8 @@ class Exporter:
|
|||
|
||||
# Creates input info.
|
||||
input_meta = _metadata_fb.TensorMetadataT()
|
||||
input_meta.name = "image"
|
||||
input_meta.description = "Input image to be detected."
|
||||
input_meta.name = 'image'
|
||||
input_meta.description = 'Input image to be detected.'
|
||||
input_meta.content = _metadata_fb.ContentT()
|
||||
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
|
||||
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
||||
|
|
@ -716,8 +716,8 @@ class Exporter:
|
|||
|
||||
# Creates output info.
|
||||
output_meta = _metadata_fb.TensorMetadataT()
|
||||
output_meta.name = "output"
|
||||
output_meta.description = "Coordinates of detected objects, class labels, and confidence score."
|
||||
output_meta.name = 'output'
|
||||
output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'
|
||||
|
||||
# Label file
|
||||
tmp_file = Path('/tmp/meta.txt')
|
||||
|
|
@ -868,8 +868,8 @@ class Exporter:
|
|||
|
||||
|
||||
def export(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n.yaml"
|
||||
cfg.format = cfg.format or "torchscript"
|
||||
cfg.model = cfg.model or 'yolov8n.yaml'
|
||||
cfg.format = cfg.format or 'torchscript'
|
||||
|
||||
# exporter = Exporter(cfg)
|
||||
#
|
||||
|
|
@ -888,7 +888,7 @@ def export(cfg=DEFAULT_CFG):
|
|||
model.export(**vars(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
CLI:
|
||||
yolo mode=export model=yolov8n.yaml format=onnx
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue