From 8e92930a60ae6554b68fc529b7e7163452dfadf1 Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Fri, 11 Oct 2024 13:41:40 +0200 Subject: [PATCH] Fixed YOLO heads docstrings (#16822) Co-authored-by: Glenn Jocher --- tests/test_explorer.py | 66 ---------------------------------- ultralytics/nn/modules/head.py | 18 +++++----- 2 files changed, 9 insertions(+), 75 deletions(-) delete mode 100644 tests/test_explorer.py diff --git a/tests/test_explorer.py b/tests/test_explorer.py deleted file mode 100644 index 45b0a31e..00000000 --- a/tests/test_explorer.py +++ /dev/null @@ -1,66 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -import PIL -import pytest - -from ultralytics import Explorer -from ultralytics.utils import ASSETS -from ultralytics.utils.torch_utils import TORCH_1_13 - - -@pytest.mark.slow -@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") -def test_similarity(): - """Test the correctness and response length of similarity calculations and SQL queries in the Explorer.""" - exp = Explorer(data="coco8.yaml") - exp.create_embeddings_table() - similar = exp.get_similar(idx=1) - assert len(similar) == 4 - similar = exp.get_similar(img=ASSETS / "bus.jpg") - assert len(similar) == 4 - similar = exp.get_similar(idx=[1, 2], limit=2) - assert len(similar) == 2 - sim_idx = exp.similarity_index() - assert len(sim_idx) == 4 - sql = exp.sql_query("WHERE labels LIKE '%zebra%'") - assert len(sql) == 1 - - -@pytest.mark.slow -@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") -def test_det(): - """Test detection functionalities and verify embedding table includes bounding boxes.""" - exp = Explorer(data="coco8.yaml", model="yolo11n.pt") - exp.create_embeddings_table(force=True) - assert len(exp.table.head()["bboxes"]) > 0 - similar = exp.get_similar(idx=[1, 2], limit=10) - assert len(similar) > 0 - # This is a loose test, just checks errors not correctness - similar = exp.plot_similar(idx=[1, 2], limit=10) - assert isinstance(similar, PIL.Image.Image) - - -@pytest.mark.slow -@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") -def test_seg(): - """Test segmentation functionalities and ensure the embedding table includes segmentation masks.""" - exp = Explorer(data="coco8-seg.yaml", model="yolo11n-seg.pt") - exp.create_embeddings_table(force=True) - assert len(exp.table.head()["masks"]) > 0 - similar = exp.get_similar(idx=[1, 2], limit=10) - assert len(similar) > 0 - similar = exp.plot_similar(idx=[1, 2], limit=10) - assert isinstance(similar, PIL.Image.Image) - - -@pytest.mark.slow -@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") -def test_pose(): - """Test pose estimation functionality and verify the embedding table includes keypoints.""" - exp = Explorer(data="coco8-pose.yaml", model="yolo11n-pose.pt") - exp.create_embeddings_table(force=True) - assert len(exp.table.head()["keypoints"]) > 0 - similar = exp.get_similar(idx=[1, 2], limit=10) - assert len(similar) > 0 - similar = exp.plot_similar(idx=[1, 2], limit=10) - assert isinstance(similar, PIL.Image.Image) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 5f293177..60911e77 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -19,7 +19,7 @@ __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10D class Detect(nn.Module): - """YOLOv8 Detect head for detection models.""" + """YOLO Detect head for detection models.""" dynamic = False # force grid reconstruction export = False # export mode @@ -30,7 +30,7 @@ class Detect(nn.Module): strides = torch.empty(0) # init def __init__(self, nc=80, ch=()): - """Initializes the YOLOv8 detection layer with specified number of classes and channels.""" + """Initializes the YOLO detection layer with specified number of classes and channels.""" super().__init__() self.nc = nc # number of classes self.nl = len(ch) # number of detection layers @@ -162,7 +162,7 @@ class Detect(nn.Module): class Segment(Detect): - """YOLOv8 Segment head for segmentation models.""" + """YOLO Segment head for segmentation models.""" def __init__(self, nc=80, nm=32, npr=256, ch=()): """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.""" @@ -187,7 +187,7 @@ class Segment(Detect): class OBB(Detect): - """YOLOv8 OBB detection head for detection with rotation models.""" + """YOLO OBB detection head for detection with rotation models.""" def __init__(self, nc=80, ne=1, ch=()): """Initialize OBB with number of classes `nc` and layer channels `ch`.""" @@ -217,7 +217,7 @@ class OBB(Detect): class Pose(Detect): - """YOLOv8 Pose head for keypoints models.""" + """YOLO Pose head for keypoints models.""" def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): """Initialize YOLO network with default parameters and Convolutional Layers.""" @@ -257,10 +257,10 @@ class Pose(Detect): class Classify(nn.Module): - """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2).""" + """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).""" def __init__(self, c1, c2, k=1, s=1, p=None, g=1): - """Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.""" + """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.""" super().__init__() c_ = 1280 # efficientnet_b0 size self.conv = Conv(c1, c_, k, s, p, g) @@ -277,10 +277,10 @@ class Classify(nn.Module): class WorldDetect(Detect): - """Head for integrating YOLOv8 detection models with semantic understanding from text embeddings.""" + """Head for integrating YOLO detection models with semantic understanding from text embeddings.""" def __init__(self, nc=80, embed=512, with_bn=False, ch=()): - """Initialize YOLOv8 detection layer with nc classes and layer channels ch.""" + """Initialize YOLO detection layer with nc classes and layer channels ch.""" super().__init__(nc, ch) c3 = max(ch[0], min(self.nc, 100)) self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)