Fix TFLite INT8 quant bug (#13082)

This commit is contained in:
Glenn Jocher 2024-05-24 01:00:17 +02:00 committed by GitHub
parent cb99f71728
commit 11623eeb00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 32 additions and 37 deletions

View file

@@ -83,6 +83,7 @@ from ultralytics.utils import (
WINDOWS,
__version__,
callbacks,
checks,
colorstr,
get_default_args,
yaml_save,
@@ -184,6 +185,7 @@ class Exporter:
if sum(flags) != 1:
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
# Device
if fmt == "engine" and self.args.device is None:
@@ -243,7 +245,7 @@
m.dynamic = self.args.dynamic
m.export = True
m.format = self.args.format
elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
elif isinstance(m, C2f) and not is_tf_format:
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
m.forward = m.forward_split
@@ -303,7 +305,7 @@
f[3], _ = self.export_openvino()
if coreml: # CoreML
f[4], _ = self.export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
if is_tf_format: # TensorFlow formats
self.args.int8 |= edgetpu
f[5], keras_model = self.export_saved_model()
if pb or tfjs: # pb prerequisite to tfjs
@@ -777,11 +779,10 @@ class Exporter:
_ = self.cache.write_bytes(cache)
# Load dataset w/ builder (for batching) and calibrate
dataset = self.get_int8_calibration_dataloader(prefix)
config.int8_calibrator = EngineCalibrator(
dataset=dataset,
dataset=self.get_int8_calibration_dataloader(prefix),
batch=2 * self.args.batch,
cache=self.file.with_suffix(".cache"),
cache=str(self.file.with_suffix(".cache")),
)
elif half:
@@ -813,7 +814,7 @@
except ImportError:
suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu"
version = "" if ARM64 else "<=2.13.1"
check_requirements(f"tensorflow{suffix}{version}")
check_requirements((f"tensorflow{suffix}{version}", "keras"))
import tensorflow as tf # noqa
if ARM64:
check_requirements("cmake") # 'cmake' is needed to build onnxsim on aarch64
@@ -855,24 +856,17 @@
f_onnx, _ = self.export_onnx()
# Export to TF
tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
np_data = None
if self.args.int8:
tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
verbosity = "info"
if self.args.data:
# Generate calibration data for integer quantization
dataloader = self.get_int8_calibration_dataloader(prefix)
images = []
for i, batch in enumerate(dataloader):
if i >= 100: # maximum number of calibration images
break
im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
images.append(im)
f.mkdir()
images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)]
images = torch.cat(images, 0).float()
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
np.save(str(tmp_file), images.numpy()) # BHWC
np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC
np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
else:
verbosity = "error"