Fixed YOLO heads docstrings (#16822)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
75b3b695be
commit
8e92930a60
2 changed files with 9 additions and 75 deletions
|
|
@ -1,66 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import PIL
|
||||
import pytest
|
||||
|
||||
from ultralytics import Explorer
|
||||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.utils.torch_utils import TORCH_1_13
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
|
||||
def test_similarity():
|
||||
"""Test the correctness and response length of similarity calculations and SQL queries in the Explorer."""
|
||||
exp = Explorer(data="coco8.yaml")
|
||||
exp.create_embeddings_table()
|
||||
similar = exp.get_similar(idx=1)
|
||||
assert len(similar) == 4
|
||||
similar = exp.get_similar(img=ASSETS / "bus.jpg")
|
||||
assert len(similar) == 4
|
||||
similar = exp.get_similar(idx=[1, 2], limit=2)
|
||||
assert len(similar) == 2
|
||||
sim_idx = exp.similarity_index()
|
||||
assert len(sim_idx) == 4
|
||||
sql = exp.sql_query("WHERE labels LIKE '%zebra%'")
|
||||
assert len(sql) == 1
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
|
||||
def test_det():
|
||||
"""Test detection functionalities and verify embedding table includes bounding boxes."""
|
||||
exp = Explorer(data="coco8.yaml", model="yolo11n.pt")
|
||||
exp.create_embeddings_table(force=True)
|
||||
assert len(exp.table.head()["bboxes"]) > 0
|
||||
similar = exp.get_similar(idx=[1, 2], limit=10)
|
||||
assert len(similar) > 0
|
||||
# This is a loose test, just checks errors not correctness
|
||||
similar = exp.plot_similar(idx=[1, 2], limit=10)
|
||||
assert isinstance(similar, PIL.Image.Image)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
|
||||
def test_seg():
|
||||
"""Test segmentation functionalities and ensure the embedding table includes segmentation masks."""
|
||||
exp = Explorer(data="coco8-seg.yaml", model="yolo11n-seg.pt")
|
||||
exp.create_embeddings_table(force=True)
|
||||
assert len(exp.table.head()["masks"]) > 0
|
||||
similar = exp.get_similar(idx=[1, 2], limit=10)
|
||||
assert len(similar) > 0
|
||||
similar = exp.plot_similar(idx=[1, 2], limit=10)
|
||||
assert isinstance(similar, PIL.Image.Image)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
|
||||
def test_pose():
|
||||
"""Test pose estimation functionality and verify the embedding table includes keypoints."""
|
||||
exp = Explorer(data="coco8-pose.yaml", model="yolo11n-pose.pt")
|
||||
exp.create_embeddings_table(force=True)
|
||||
assert len(exp.table.head()["keypoints"]) > 0
|
||||
similar = exp.get_similar(idx=[1, 2], limit=10)
|
||||
assert len(similar) > 0
|
||||
similar = exp.plot_similar(idx=[1, 2], limit=10)
|
||||
assert isinstance(similar, PIL.Image.Image)
|
||||
|
|
@ -19,7 +19,7 @@ __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10D
|
|||
|
||||
|
||||
class Detect(nn.Module):
|
||||
"""YOLOv8 Detect head for detection models."""
|
||||
"""YOLO Detect head for detection models."""
|
||||
|
||||
dynamic = False # force grid reconstruction
|
||||
export = False # export mode
|
||||
|
|
@ -30,7 +30,7 @@ class Detect(nn.Module):
|
|||
strides = torch.empty(0) # init
|
||||
|
||||
def __init__(self, nc=80, ch=()):
|
||||
"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""
|
||||
"""Initializes the YOLO detection layer with specified number of classes and channels."""
|
||||
super().__init__()
|
||||
self.nc = nc # number of classes
|
||||
self.nl = len(ch) # number of detection layers
|
||||
|
|
@ -162,7 +162,7 @@ class Detect(nn.Module):
|
|||
|
||||
|
||||
class Segment(Detect):
|
||||
"""YOLOv8 Segment head for segmentation models."""
|
||||
"""YOLO Segment head for segmentation models."""
|
||||
|
||||
def __init__(self, nc=80, nm=32, npr=256, ch=()):
|
||||
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
|
||||
|
|
@ -187,7 +187,7 @@ class Segment(Detect):
|
|||
|
||||
|
||||
class OBB(Detect):
|
||||
"""YOLOv8 OBB detection head for detection with rotation models."""
|
||||
"""YOLO OBB detection head for detection with rotation models."""
|
||||
|
||||
def __init__(self, nc=80, ne=1, ch=()):
|
||||
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
|
||||
|
|
@ -217,7 +217,7 @@ class OBB(Detect):
|
|||
|
||||
|
||||
class Pose(Detect):
|
||||
"""YOLOv8 Pose head for keypoints models."""
|
||||
"""YOLO Pose head for keypoints models."""
|
||||
|
||||
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
|
||||
"""Initialize YOLO network with default parameters and Convolutional Layers."""
|
||||
|
|
@ -257,10 +257,10 @@ class Pose(Detect):
|
|||
|
||||
|
||||
class Classify(nn.Module):
|
||||
"""YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
|
||||
"""Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
|
||||
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
|
||||
super().__init__()
|
||||
c_ = 1280 # efficientnet_b0 size
|
||||
self.conv = Conv(c1, c_, k, s, p, g)
|
||||
|
|
@ -277,10 +277,10 @@ class Classify(nn.Module):
|
|||
|
||||
|
||||
class WorldDetect(Detect):
|
||||
"""Head for integrating YOLOv8 detection models with semantic understanding from text embeddings."""
|
||||
"""Head for integrating YOLO detection models with semantic understanding from text embeddings."""
|
||||
|
||||
def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
|
||||
"""Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
|
||||
"""Initialize YOLO detection layer with nc classes and layer channels ch."""
|
||||
super().__init__(nc, ch)
|
||||
c3 = max(ch[0], min(self.nc, 100))
|
||||
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue