Explorer tests require torch>=1.13 (#15930)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-31 17:30:44 +02:00 committed by GitHub
parent 59132f23bd
commit 885a56837f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5,9 +5,11 @@ import pytest
from ultralytics import Explorer from ultralytics import Explorer
from ultralytics.utils import ASSETS from ultralytics.utils import ASSETS
from ultralytics.utils.torch_utils import TORCH_1_13
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_similarity(): def test_similarity():
"""Test the correctness and response length of similarity calculations and SQL queries in the Explorer.""" """Test the correctness and response length of similarity calculations and SQL queries in the Explorer."""
exp = Explorer(data="coco8.yaml") exp = Explorer(data="coco8.yaml")
@ -25,6 +27,7 @@ def test_similarity():
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_det(): def test_det():
"""Test detection functionalities and verify embedding table includes bounding boxes.""" """Test detection functionalities and verify embedding table includes bounding boxes."""
exp = Explorer(data="coco8.yaml", model="yolov8n.pt") exp = Explorer(data="coco8.yaml", model="yolov8n.pt")
@ -38,6 +41,7 @@ def test_det():
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_seg(): def test_seg():
"""Test segmentation functionalities and ensure the embedding table includes segmentation masks.""" """Test segmentation functionalities and ensure the embedding table includes segmentation masks."""
exp = Explorer(data="coco8-seg.yaml", model="yolov8n-seg.pt") exp = Explorer(data="coco8-seg.yaml", model="yolov8n-seg.pt")
@ -50,6 +54,7 @@ def test_seg():
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_pose(): def test_pose():
"""Test pose estimation functionality and verify the embedding table includes keypoints.""" """Test pose estimation functionality and verify the embedding table includes keypoints."""
exp = Explorer(data="coco8-pose.yaml", model="yolov8n-pose.pt") exp = Explorer(data="coco8-pose.yaml", model="yolov8n-pose.pt")