ultralytics 8.0.54 TFLite export improvements and fixes (#1447)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2023-03-16 15:42:44 +01:00 committed by GitHub
parent 30fc4b537f
commit 701fba4770
30 changed files with 198 additions and 166 deletions


@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 """
-AutoBatch utils
+Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
 """
 
 from copy import deepcopy
@@ -13,18 +13,35 @@ from ultralytics.yolo.utils.torch_utils import profile
 
 
 def check_train_batch_size(model, imgsz=640, amp=True):
-    # Check YOLOv5 training batch size
+    """
+    Check YOLO training batch size using the autobatch() function.
+
+    Args:
+        model (torch.nn.Module): YOLO model to check batch size for.
+        imgsz (int): Image size used for training.
+        amp (bool): If True, use automatic mixed precision (AMP) for training.
+
+    Returns:
+        int: Optimal batch size computed using the autobatch() function.
+    """
+
     with torch.cuda.amp.autocast(amp):
         return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size
 
 
-def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
-    # Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory
-    # Usage:
-    #     import torch
-    #     from utils.autobatch import autobatch
-    #     model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
-    #     print(autobatch(model))
+def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
+    """
+    Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
+
+    Args:
+        model: YOLO model to compute batch size for.
+        imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
+        fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
+        batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
+
+    Returns:
+        int: The optimal batch size.
+    """
 
     # Check device
     prefix = colorstr('AutoBatch: ')
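
For context, here is a minimal usage sketch of the check_train_batch_size() API documented above. It is not part of the commit; the module path ultralytics.yolo.utils.autobatch and the YOLO('yolov8n.pt').model access are assumptions based on the package layout visible in the hunk header, and a CUDA device is assumed since AutoBatch falls back to the default batch size on CPU.

import torch

from ultralytics import YOLO
from ultralytics.yolo.utils.autobatch import check_train_batch_size  # assumed module path

# Load a detection model and move its underlying torch.nn.Module to the GPU
model = YOLO('yolov8n.pt').model
if torch.cuda.is_available():
    model = model.cuda()
    # Deep-copies the model, switches it to train() mode and runs autobatch()
    # under torch.cuda.amp.autocast, as shown in the diff above
    batch = check_train_batch_size(model, imgsz=640, amp=True)
    print(f'AutoBatch selected batch size {batch}')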
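
The YOLOv5-style usage comments removed from autobatch() can be translated to the new package roughly as follows. Again a sketch, not the commit's own example, with the same assumed import path and model access as above; fraction=0.67 simply restates the new default.

import torch

from ultralytics import YOLO
from ultralytics.yolo.utils.autobatch import autobatch  # assumed module path

model = YOLO('yolov8n.pt').model.cuda()  # requires a CUDA device
print(autobatch(model, imgsz=640, fraction=0.67))  # estimated batch size targeting 67% of available CUDA memory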