ultralytics 8.2.10 add Classify and OBB Tasks to Results.summary() (#11653)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
parent f2f1afd269
commit 9d48190e6d
5 changed files with 46 additions and 53 deletions
@@ -18,6 +18,7 @@ from ultralytics.utils import (
     checks,
 )
 from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13

 from . import MODEL, SOURCE

@@ -12,7 +12,7 @@ import yaml
 from PIL import Image

 from ultralytics import RTDETR, YOLO
-from ultralytics.cfg import TASK2DATA
+from ultralytics.cfg import MODELS, TASK2DATA
 from ultralytics.data.build import load_inference_source
 from ultralytics.utils import (
     ASSETS,
@@ -76,42 +76,27 @@ def test_predict_txt():
     _ = YOLO(MODEL)(source=txt_file, imgsz=32)


-def test_predict_img():
+@pytest.mark.parametrize("model_name", MODELS)
+def test_predict_img(model_name):
     """Test YOLO prediction on various types of image sources."""
-    model = YOLO(MODEL)
-    seg_model = YOLO(WEIGHTS_DIR / "yolov8n-seg.pt")
-    cls_model = YOLO(WEIGHTS_DIR / "yolov8n-cls.pt")
-    pose_model = YOLO(WEIGHTS_DIR / "yolov8n-pose.pt")
-    obb_model = YOLO(WEIGHTS_DIR / "yolov8n-obb.pt")
-    im = cv2.imread(str(SOURCE))
+    model = YOLO(WEIGHTS_DIR / model_name)
+    im = cv2.imread(str(SOURCE))  # uint8 numpy array
     assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1  # PIL
     assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1  # ndarray
+    assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2  # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
     assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2  # batch
     assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2  # stream
-    assert len(model(torch.zeros(320, 640, 3).numpy(), imgsz=32)) == 1  # tensor to numpy
+    assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor to numpy
     batch = [
         str(SOURCE),  # filename
         Path(SOURCE),  # Path
         "https://ultralytics.com/images/zidane.jpg" if ONLINE else SOURCE,  # URI
         cv2.imread(str(SOURCE)),  # OpenCV
         Image.open(SOURCE),  # PIL
-        np.zeros((320, 640, 3)),
-    ]  # numpy
+        np.zeros((320, 640, 3), dtype=np.uint8),  # numpy
+    ]
     assert len(model(batch, imgsz=32)) == len(batch)  # multiple sources in a batch

-    # Test tensor inference
-    im = torch.rand((4, 3, 32, 32))  # batch-size 4, FP32 0.0-1.0 RGB order
-    results = model(im, imgsz=32)
-    assert len(results) == im.shape[0]
-    results = seg_model(im, imgsz=32)
-    assert len(results) == im.shape[0]
-    results = cls_model(im, imgsz=32)
-    assert len(results) == im.shape[0]
-    results = pose_model(im, imgsz=32)
-    assert len(results) == im.shape[0]
-    results = obb_model(im, imgsz=32)
-    assert len(results) == im.shape[0]
-

 def test_predict_grey_and_4ch():
     """Test YOLO prediction on SOURCE converted to greyscale and 4-channel images."""
@@ -236,10 +221,10 @@ def test_predict_callback_and_setup():
         print(boxes)


-def test_results():
+@pytest.mark.parametrize("model", MODELS)
+def test_results(model):
     """Test various result formats for the YOLO model."""
-    for m in "yolov8n-pose.pt", "yolov8n-seg.pt", "yolov8n.pt", "yolov8n-cls.pt":
-        results = YOLO(WEIGHTS_DIR / m)([SOURCE, SOURCE], imgsz=160)
+    results = YOLO(WEIGHTS_DIR / model)([SOURCE, SOURCE], imgsz=160)
     for r in results:
         r = r.cpu().numpy()
         r = r.to(device="cpu", dtype=torch.float32)
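With test_predict_img and test_results parametrized over MODELS, each task checkpoint now runs as its own test case. A minimal way to select just these cases from Python, as a sketch only (the tests/test_python.py path is assumed from the imports above, not stated in this diff):

import pytest

# Run only the newly parametrized tests, one case per entry in MODELS
pytest.main(["-q", "-k", "test_predict_img or test_results", "tests/test_python.py"])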
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.9"
+__version__ = "8.2.10"

 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
@@ -53,6 +53,7 @@ TASK2METRIC = {
     "pose": "metrics/mAP50-95(P)",
     "obb": "metrics/mAP50-95(B)",
 }
+MODELS = {TASK2MODEL[task] for task in TASKS}

 ARGV = sys.argv or ["", ""]  # sometimes sys.argv = []
 CLI_HELP_MSG = f"""
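For orientation, MODELS is a set holding one default checkpoint name per task. A rough REPL sketch of what it resolves to, assuming the usual 8.2.x TASK2MODEL defaults (yolov8n weights per task; the exact contents are an assumption, not shown in this diff):

>>> from ultralytics.cfg import MODELS
>>> sorted(MODELS)
['yolov8n-cls.pt', 'yolov8n-obb.pt', 'yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt']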
@@ -387,26 +387,32 @@ class Results(SimpleClass):

     def summary(self, normalize=False, decimals=5):
         """Convert the results to a summarized format."""
-        if self.probs is not None:
-            LOGGER.warning("Warning: Classify results do not support the `summary()` method yet.")
-            return

         # Create list of detection dictionaries
         results = []
-        data = self.boxes.data.cpu().tolist()
+        if self.probs is not None:
+            class_id = self.probs.top1
+            results.append(
+                {
+                    "name": self.names[class_id],
+                    "class": class_id,
+                    "confidence": round(self.probs.top1conf.item(), decimals),
+                }
+            )
+            return results
+
+        data = self.boxes or self.obb
+        is_obb = self.obb is not None
         h, w = self.orig_shape if normalize else (1, 1)
         for i, row in enumerate(data):  # xyxy, track_id if tracking, conf, class_id
-            box = {
-                "x1": round(row[0] / w, decimals),
-                "y1": round(row[1] / h, decimals),
-                "x2": round(row[2] / w, decimals),
-                "y2": round(row[3] / h, decimals),
-            }
-            conf = round(row[-2], decimals)
-            class_id = int(row[-1])
-            result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": box}
-            if self.boxes.is_track:
-                result["track_id"] = int(row[-3])  # track ID
+            class_id, conf = int(row.cls), round(row.conf.item(), decimals)
+            box = (row.xyxyxyxy if is_obb else row.xyxy).squeeze().reshape(-1, 2).tolist()
+            xy = {}
+            for j, b in enumerate(box):
+                xy[f"x{j + 1}"] = round(b[0] / w, decimals)
+                xy[f"y{j + 1}"] = round(b[1] / h, decimals)
+            result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": xy}
+            if data.is_track:
+                result["track_id"] = int(row.id.item())  # track ID
             if self.masks:
                 result["segments"] = {
                     "x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(),
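A short usage sketch of the extended Results.summary() for the two newly supported tasks (weight files and image paths below are placeholders, and the printed values are illustrative only):

from ultralytics import YOLO

# Classify: summary() now returns one dict with the top-1 class per image
cls_result = YOLO("yolov8n-cls.pt")("bus.jpg")[0]
print(cls_result.summary())  # [{'name': ..., 'class': ..., 'confidence': ...}]

# OBB: 'box' holds the four polygon corners as x1..x4 / y1..y4, scaled by image size when normalize=True
obb_result = YOLO("yolov8n-obb.pt")("boats.jpg")[0]
for det in obb_result.summary(normalize=True):
    print(det["name"], det["confidence"], det["box"])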