ultralytics 8.2.11 new TensorRT INT8 export feature (#10165)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1d9745182d
commit
fcfc44ea9c
15 changed files with 601 additions and 176 deletions
|
|
@ -1,10 +1,14 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from pathlib import Path
|
||||
from itertools import product
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.utils import ASSETS, WEIGHTS_DIR
|
||||
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
|
||||
|
||||
from . import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODEL, SOURCE
|
||||
|
||||
|
|
@ -23,6 +27,34 @@ def test_export_engine():
|
|||
YOLO(f)(SOURCE, device=0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
|
||||
@pytest.mark.parametrize(
|
||||
"task, dynamic, int8, half, batch",
|
||||
[ # generate all combinations but exclude those where both int8 and half are True
|
||||
(task, dynamic, int8, half, batch)
|
||||
# Note: tests reduced below pending compute availability expansion as GPU CI runner utilization is high
|
||||
# for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2])
|
||||
for task, dynamic, int8, half, batch in product(TASKS, [True], [True], [False], [2])
|
||||
if not (int8 and half) # exclude cases where both int8 and half are True
|
||||
],
|
||||
)
|
||||
def test_export_engine_matrix(task, dynamic, int8, half, batch):
|
||||
"""Test YOLO exports to TensorRT format."""
|
||||
file = YOLO(TASK2MODEL[task]).export(
|
||||
format="engine",
|
||||
imgsz=32,
|
||||
dynamic=dynamic,
|
||||
int8=int8,
|
||||
half=half,
|
||||
batch=batch,
|
||||
data=TASK2DATA[task],
|
||||
)
|
||||
YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference
|
||||
Path(file).unlink() # cleanup
|
||||
Path(file).with_suffix(".cache").unlink() if int8 else None # cleanup INT8 cache
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
|
||||
def test_train():
|
||||
"""Test model training on a minimal dataset."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue