ultralytics 8.2.9 OpenVINO INT8 fixes and tests (#10423)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Burhan 2024-05-05 09:11:17 -04:00 committed by GitHub
parent 299797ff9e
commit 2583f842b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 250 additions and 206 deletions

View file

@ -64,9 +64,10 @@ from pathlib import Path
import numpy as np
import torch
from ultralytics.cfg import get_cfg
from ultralytics.cfg import TASK2DATA, get_cfg
from ultralytics.data import build_dataloader
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
@ -169,7 +170,7 @@ class Exporter:
callbacks.add_integration_callbacks(self)
@smart_inference_mode()
def __call__(self, model=None):
def __call__(self, model=None) -> str:
"""Returns list of exported files/dirs after running callbacks."""
self.run_callbacks("on_export_start")
t = time.time()
@ -211,7 +212,12 @@ class Exporter:
"(torchscript, onnx, openvino, engine, coreml) formats. "
"See https://docs.ultralytics.com/models/yolo-world for details."
)
if self.args.int8 and not self.args.data:
self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")] # assign default data
LOGGER.warning(
"WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
f"Using default 'data={self.args.data}'."
)
# Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
file = Path(
@ -333,6 +339,23 @@ class Exporter:
self.run_callbacks("on_export_end")
return f # return list of exported files/dirs
def get_int8_calibration_dataloader(self, prefix=""):
"""Build and return a dataloader suitable for calibration of INT8 models."""
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
dataset = YOLODataset(
data[self.args.split or "val"],
data=data,
task=self.model.task,
imgsz=self.imgsz[0],
augment=False,
batch_size=self.args.batch,
)
n = len(dataset)
if n < 300:
LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
return build_dataloader(dataset, batch=self.args.batch, workers=0) # required for batch loading
@try_export
def export_torchscript(self, prefix=colorstr("TorchScript:")):
"""YOLOv8 TorchScript model export."""
@ -442,37 +465,21 @@ class Exporter:
if self.args.int8:
fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
if not self.args.data:
self.args.data = DEFAULT_CFG.data or "coco128.yaml"
LOGGER.warning(
f"{prefix} WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
f"Using default 'data={self.args.data}'."
)
check_requirements("nncf>=2.8.0")
import nncf
def transform_fn(data_item):
def transform_fn(data_item) -> np.ndarray:
"""Quantization transform function."""
assert (
data_item["img"].dtype == torch.uint8
), "Input image must be uint8 for the quantization preprocessing"
im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
im = data_item.numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
return np.expand_dims(im, 0) if im.ndim == 3 else im
# Generate calibration data for integer quantization
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = check_det_dataset(self.args.data)
dataset = YOLODataset(data["val"], data=data, task=self.model.task, imgsz=self.imgsz[0], augment=False)
n = len(dataset)
if n < 300:
LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
quantization_dataset = nncf.Dataset(dataset, transform_fn)
ignored_scope = None
if isinstance(self.model.model[-1], Detect):
# Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
ignored_scope = nncf.IgnoredScope( # ignore operations
patterns=[
f".*{head_module_name}/.*/Add",
@ -485,7 +492,10 @@ class Exporter:
)
quantized_ov_model = nncf.quantize(
ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
model=ov_model,
calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),
preset=nncf.QuantizationPreset.MIXED,
ignored_scope=ignored_scope,
)
serialize(quantized_ov_model, fq_ov)
return fq, None
@ -787,11 +797,9 @@ class Exporter:
verbosity = "info"
if self.args.data:
# Generate calibration data for integer quantization
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = check_det_dataset(self.args.data)
dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
dataloader = self.get_int8_calibration_dataloader(prefix)
images = []
for i, batch in enumerate(dataset):
for i, batch in enumerate(dataloader):
if i >= 100: # maximum number of calibration images
break
im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC