From b0c18b71900148c0598056c7ff79fc560aeab083 Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Tue, 29 Oct 2024 10:31:32 +0100 Subject: [PATCH] Fix arbitrary imgsz for TFLite (#17138) Co-authored-by: UltralyticsAssistant Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> --- ultralytics/engine/exporter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 49e84af9..5104de1c 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -890,8 +890,10 @@ class Exporter: tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file if self.args.data: f.mkdir() - images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)] - images = torch.cat(images, 0).float() + images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)] + images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute( + 0, 2, 3, 1 + ) np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]