Add NNCF support in OpenVINO export (#4671)
Co-authored-by: 下北泽miHoMo红茶坊 <39751846+kisaragychihaya@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
d2cf7acce0
commit
45ba99973d
4 changed files with 72 additions and 45 deletions
|
|
@ -65,12 +65,7 @@ class Compose:
|
|||
|
||||
def __repr__(self):
|
||||
"""Return string representation of object."""
|
||||
format_string = f'{self.__class__.__name__}('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += f' {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
return f"{self.__class__.__name__}({', '.join([f'{t}' for t in self.transforms])})"
|
||||
|
||||
|
||||
class BaseMixTransform:
|
||||
|
|
|
|||
|
|
@ -58,9 +58,12 @@ from copy import deepcopy
|
|||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
|
|
@ -275,10 +278,11 @@ class Exporter:
|
|||
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
||||
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
||||
predict_data = f'data={data}' if model.task == 'segment' and format == 'pb' else ''
|
||||
q = 'int8' if self.args.int8 else 'half' if self.args.half else '' # quantization
|
||||
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {predict_data}'
|
||||
f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {s}'
|
||||
f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
|
||||
f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
|
||||
f'\nVisualize: https://netron.app')
|
||||
|
||||
self.run_callbacks('on_export_end')
|
||||
|
|
@ -367,27 +371,54 @@ class Exporter:
|
|||
|
||||
LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
|
||||
fq = str(self.file).replace(self.file.suffix, f'_int8_openvino_model{os.sep}')
|
||||
f_onnx = self.file.with_suffix('.onnx')
|
||||
f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
|
||||
fq_ov = str(Path(fq) / self.file.with_suffix('.xml').name)
|
||||
|
||||
def serialize(ov_model, file):
|
||||
"""Set RT info, serialize and save metadata YAML."""
|
||||
ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
|
||||
ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
|
||||
ov_model.set_rt_info(114, ['model_info', 'pad_value'])
|
||||
ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
|
||||
ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
|
||||
ov_model.set_rt_info([v.replace(' ', '_') for v in self.model.names.values()], ['model_info', 'labels'])
|
||||
if self.model.task != 'classify':
|
||||
ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
|
||||
|
||||
ov.serialize(ov_model, file) # save
|
||||
yaml_save(Path(file).parent / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
ov_model = mo.convert_model(f_onnx,
|
||||
model_name=self.pretty_name,
|
||||
framework='onnx',
|
||||
compress_to_fp16=self.args.half) # export
|
||||
|
||||
# Set RT info
|
||||
ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
|
||||
ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
|
||||
ov_model.set_rt_info(114, ['model_info', 'pad_value'])
|
||||
ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
|
||||
ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
|
||||
ov_model.set_rt_info([v.replace(' ', '_') for k, v in sorted(self.model.names.items())],
|
||||
['model_info', 'labels'])
|
||||
if self.model.task != 'classify':
|
||||
ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
|
||||
if self.args.int8:
|
||||
assert self.args.data, "INT8 export requires a data argument for calibration, i.e. 'data=coco8.yaml'"
|
||||
check_requirements('nncf>=2.5.0')
|
||||
import nncf
|
||||
|
||||
ov.serialize(ov_model, f_ov) # save
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
def transform_fn(data_item):
|
||||
"""Quantization transform function."""
|
||||
im = data_item['img'].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
|
||||
return np.expand_dims(im, 0) if im.ndim == 3 else im
|
||||
|
||||
# Generate calibration data for integer quantization
|
||||
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
||||
data = check_det_dataset(self.args.data)
|
||||
dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
|
||||
quantization_dataset = nncf.Dataset(dataset, transform_fn)
|
||||
ignored_scope = nncf.IgnoredScope(types=['Multiply', 'Subtract', 'Sigmoid']) # ignore operation
|
||||
quantized_ov_model = nncf.quantize(ov_model,
|
||||
quantization_dataset,
|
||||
preset=nncf.QuantizationPreset.MIXED,
|
||||
ignored_scope=ignored_scope)
|
||||
serialize(quantized_ov_model, fq_ov)
|
||||
return fq, None
|
||||
|
||||
serialize(ov_model, f_ov)
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
|
|
@ -633,19 +664,13 @@ class Exporter:
|
|||
if self.args.int8:
|
||||
verbosity = '--verbosity info'
|
||||
if self.args.data:
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
|
||||
# Generate calibration data for integer quantization
|
||||
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
||||
data = check_det_dataset(self.args.data)
|
||||
dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
|
||||
images = []
|
||||
n_images = 100 # maximum number of images
|
||||
for n, batch in enumerate(dataset):
|
||||
if n >= n_images:
|
||||
for i, batch in enumerate(dataset):
|
||||
if i >= 100: # maximum number of calibration images
|
||||
break
|
||||
im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
|
||||
images.append(im)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue