ultralytics 8.2.10 add Classify and OBB Tasks to Results.summary() (#11653)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Glenn Jocher 2024-05-06 12:26:01 +02:00 committed by GitHub
parent f2f1afd269
commit 9d48190e6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 46 additions and 53 deletions

View file

@ -18,6 +18,7 @@ from ultralytics.utils import (
checks, checks,
) )
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13 from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13
from . import MODEL, SOURCE from . import MODEL, SOURCE

View file

@ -12,7 +12,7 @@ import yaml
from PIL import Image from PIL import Image
from ultralytics import RTDETR, YOLO from ultralytics import RTDETR, YOLO
from ultralytics.cfg import TASK2DATA from ultralytics.cfg import MODELS, TASK2DATA
from ultralytics.data.build import load_inference_source from ultralytics.data.build import load_inference_source
from ultralytics.utils import ( from ultralytics.utils import (
ASSETS, ASSETS,
@ -76,42 +76,27 @@ def test_predict_txt():
_ = YOLO(MODEL)(source=txt_file, imgsz=32) _ = YOLO(MODEL)(source=txt_file, imgsz=32)
def test_predict_img(): @pytest.mark.parametrize("model_name", MODELS)
def test_predict_img(model_name):
"""Test YOLO prediction on various types of image sources.""" """Test YOLO prediction on various types of image sources."""
model = YOLO(MODEL) model = YOLO(WEIGHTS_DIR / model_name)
seg_model = YOLO(WEIGHTS_DIR / "yolov8n-seg.pt") im = cv2.imread(str(SOURCE)) # uint8 numpy array
cls_model = YOLO(WEIGHTS_DIR / "yolov8n-cls.pt")
pose_model = YOLO(WEIGHTS_DIR / "yolov8n-pose.pt")
obb_model = YOLO(WEIGHTS_DIR / "yolov8n-obb.pt")
im = cv2.imread(str(SOURCE))
assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL
assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray
assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2 # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2 # batch assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2 # batch
assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2 # stream assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2 # stream
assert len(model(torch.zeros(320, 640, 3).numpy(), imgsz=32)) == 1 # tensor to numpy assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1 # tensor to numpy
batch = [ batch = [
str(SOURCE), # filename str(SOURCE), # filename
Path(SOURCE), # Path Path(SOURCE), # Path
"https://ultralytics.com/images/zidane.jpg" if ONLINE else SOURCE, # URI "https://ultralytics.com/images/zidane.jpg" if ONLINE else SOURCE, # URI
cv2.imread(str(SOURCE)), # OpenCV cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL Image.open(SOURCE), # PIL
np.zeros((320, 640, 3)), np.zeros((320, 640, 3), dtype=np.uint8), # numpy
] # numpy ]
assert len(model(batch, imgsz=32)) == len(batch) # multiple sources in a batch assert len(model(batch, imgsz=32)) == len(batch) # multiple sources in a batch
# Test tensor inference
im = torch.rand((4, 3, 32, 32)) # batch-size 4, FP32 0.0-1.0 RGB order
results = model(im, imgsz=32)
assert len(results) == im.shape[0]
results = seg_model(im, imgsz=32)
assert len(results) == im.shape[0]
results = cls_model(im, imgsz=32)
assert len(results) == im.shape[0]
results = pose_model(im, imgsz=32)
assert len(results) == im.shape[0]
results = obb_model(im, imgsz=32)
assert len(results) == im.shape[0]
def test_predict_grey_and_4ch(): def test_predict_grey_and_4ch():
"""Test YOLO prediction on SOURCE converted to greyscale and 4-channel images.""" """Test YOLO prediction on SOURCE converted to greyscale and 4-channel images."""
@ -236,10 +221,10 @@ def test_predict_callback_and_setup():
print(boxes) print(boxes)
def test_results(): @pytest.mark.parametrize("model", MODELS)
def test_results(model):
"""Test various result formats for the YOLO model.""" """Test various result formats for the YOLO model."""
for m in "yolov8n-pose.pt", "yolov8n-seg.pt", "yolov8n.pt", "yolov8n-cls.pt": results = YOLO(WEIGHTS_DIR / model)([SOURCE, SOURCE], imgsz=160)
results = YOLO(WEIGHTS_DIR / m)([SOURCE, SOURCE], imgsz=160)
for r in results: for r in results:
r = r.cpu().numpy() r = r.cpu().numpy()
r = r.to(device="cpu", dtype=torch.float32) r = r.to(device="cpu", dtype=torch.float32)

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.9" __version__ = "8.2.10"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

View file

@ -53,6 +53,7 @@ TASK2METRIC = {
"pose": "metrics/mAP50-95(P)", "pose": "metrics/mAP50-95(P)",
"obb": "metrics/mAP50-95(B)", "obb": "metrics/mAP50-95(B)",
} }
MODELS = {TASK2MODEL[task] for task in TASKS}
ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
CLI_HELP_MSG = f""" CLI_HELP_MSG = f"""

View file

@ -387,26 +387,32 @@ class Results(SimpleClass):
def summary(self, normalize=False, decimals=5): def summary(self, normalize=False, decimals=5):
"""Convert the results to a summarized format.""" """Convert the results to a summarized format."""
if self.probs is not None:
LOGGER.warning("Warning: Classify results do not support the `summary()` method yet.")
return
# Create list of detection dictionaries # Create list of detection dictionaries
results = [] results = []
data = self.boxes.data.cpu().tolist() if self.probs is not None:
class_id = self.probs.top1
results.append(
{
"name": self.names[class_id],
"class": class_id,
"confidence": round(self.probs.top1conf.item(), decimals),
}
)
return results
data = self.boxes or self.obb
is_obb = self.obb is not None
h, w = self.orig_shape if normalize else (1, 1) h, w = self.orig_shape if normalize else (1, 1)
for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id
box = { class_id, conf = int(row.cls), round(row.conf.item(), decimals)
"x1": round(row[0] / w, decimals), box = (row.xyxyxyxy if is_obb else row.xyxy).squeeze().reshape(-1, 2).tolist()
"y1": round(row[1] / h, decimals), xy = {}
"x2": round(row[2] / w, decimals), for i, b in enumerate(box):
"y2": round(row[3] / h, decimals), xy[f"x{i + 1}"] = round(b[0] / w, decimals)
} xy[f"y{i + 1}"] = round(b[1] / h, decimals)
conf = round(row[-2], decimals) result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": xy}
class_id = int(row[-1]) if data.is_track:
result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": box} result["track_id"] = int(row.id.item()) # track ID
if self.boxes.is_track:
result["track_id"] = int(row[-3]) # track ID
if self.masks: if self.masks:
result["segments"] = { result["segments"] = {
"x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(), "x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(),