Fixed YOLO heads docstrings (#16822)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Francesco Mattioli 2024-10-11 13:41:40 +02:00 committed by GitHub
parent 75b3b695be
commit 8e92930a60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 75 deletions

View file

@ -1,66 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import PIL
import pytest
from ultralytics import Explorer
from ultralytics.utils import ASSETS
from ultralytics.utils.torch_utils import TORCH_1_13
@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_similarity():
"""Test the correctness and response length of similarity calculations and SQL queries in the Explorer."""
exp = Explorer(data="coco8.yaml")
exp.create_embeddings_table()
similar = exp.get_similar(idx=1)
assert len(similar) == 4
similar = exp.get_similar(img=ASSETS / "bus.jpg")
assert len(similar) == 4
similar = exp.get_similar(idx=[1, 2], limit=2)
assert len(similar) == 2
sim_idx = exp.similarity_index()
assert len(sim_idx) == 4
sql = exp.sql_query("WHERE labels LIKE '%zebra%'")
assert len(sql) == 1
@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_det():
"""Test detection functionalities and verify embedding table includes bounding boxes."""
exp = Explorer(data="coco8.yaml", model="yolo11n.pt")
exp.create_embeddings_table(force=True)
assert len(exp.table.head()["bboxes"]) > 0
similar = exp.get_similar(idx=[1, 2], limit=10)
assert len(similar) > 0
# This is a loose test, just checks errors not correctness
similar = exp.plot_similar(idx=[1, 2], limit=10)
assert isinstance(similar, PIL.Image.Image)
@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_seg():
"""Test segmentation functionalities and ensure the embedding table includes segmentation masks."""
exp = Explorer(data="coco8-seg.yaml", model="yolo11n-seg.pt")
exp.create_embeddings_table(force=True)
assert len(exp.table.head()["masks"]) > 0
similar = exp.get_similar(idx=[1, 2], limit=10)
assert len(similar) > 0
similar = exp.plot_similar(idx=[1, 2], limit=10)
assert isinstance(similar, PIL.Image.Image)
@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_pose():
"""Test pose estimation functionality and verify the embedding table includes keypoints."""
exp = Explorer(data="coco8-pose.yaml", model="yolo11n-pose.pt")
exp.create_embeddings_table(force=True)
assert len(exp.table.head()["keypoints"]) > 0
similar = exp.get_similar(idx=[1, 2], limit=10)
assert len(similar) > 0
similar = exp.plot_similar(idx=[1, 2], limit=10)
assert isinstance(similar, PIL.Image.Image)

View file

@ -19,7 +19,7 @@ __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10D
class Detect(nn.Module): class Detect(nn.Module):
"""YOLOv8 Detect head for detection models.""" """YOLO Detect head for detection models."""
dynamic = False # force grid reconstruction dynamic = False # force grid reconstruction
export = False # export mode export = False # export mode
@ -30,7 +30,7 @@ class Detect(nn.Module):
strides = torch.empty(0) # init strides = torch.empty(0) # init
def __init__(self, nc=80, ch=()): def __init__(self, nc=80, ch=()):
"""Initializes the YOLOv8 detection layer with specified number of classes and channels.""" """Initializes the YOLO detection layer with specified number of classes and channels."""
super().__init__() super().__init__()
self.nc = nc # number of classes self.nc = nc # number of classes
self.nl = len(ch) # number of detection layers self.nl = len(ch) # number of detection layers
@ -162,7 +162,7 @@ class Detect(nn.Module):
class Segment(Detect): class Segment(Detect):
"""YOLOv8 Segment head for segmentation models.""" """YOLO Segment head for segmentation models."""
def __init__(self, nc=80, nm=32, npr=256, ch=()): def __init__(self, nc=80, nm=32, npr=256, ch=()):
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.""" """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
@ -187,7 +187,7 @@ class Segment(Detect):
class OBB(Detect): class OBB(Detect):
"""YOLOv8 OBB detection head for detection with rotation models.""" """YOLO OBB detection head for detection with rotation models."""
def __init__(self, nc=80, ne=1, ch=()): def __init__(self, nc=80, ne=1, ch=()):
"""Initialize OBB with number of classes `nc` and layer channels `ch`.""" """Initialize OBB with number of classes `nc` and layer channels `ch`."""
@ -217,7 +217,7 @@ class OBB(Detect):
class Pose(Detect): class Pose(Detect):
"""YOLOv8 Pose head for keypoints models.""" """YOLO Pose head for keypoints models."""
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
"""Initialize YOLO network with default parameters and Convolutional Layers.""" """Initialize YOLO network with default parameters and Convolutional Layers."""
@ -257,10 +257,10 @@ class Pose(Detect):
class Classify(nn.Module): class Classify(nn.Module):
"""YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2).""" """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
"""Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.""" """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
super().__init__() super().__init__()
c_ = 1280 # efficientnet_b0 size c_ = 1280 # efficientnet_b0 size
self.conv = Conv(c1, c_, k, s, p, g) self.conv = Conv(c1, c_, k, s, p, g)
@ -277,10 +277,10 @@ class Classify(nn.Module):
class WorldDetect(Detect): class WorldDetect(Detect):
"""Head for integrating YOLOv8 detection models with semantic understanding from text embeddings.""" """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
def __init__(self, nc=80, embed=512, with_bn=False, ch=()): def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
"""Initialize YOLOv8 detection layer with nc classes and layer channels ch.""" """Initialize YOLO detection layer with nc classes and layer channels ch."""
super().__init__(nc, ch) super().__init__(nc, ch)
c3 = max(ch[0], min(self.nc, 100)) c3 = max(ch[0], min(self.nc, 100))
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch) self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)