ultralytics 8.0.229 add model.embed() method (#7098)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2023-12-22 15:32:06 +01:00 committed by GitHub
parent 38eaf5e29f
commit 5b3e20379f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 65 additions and 14 deletions

View file

@ -511,3 +511,13 @@ def test_model_tune():
"""Tune YOLO model for performance."""
YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
def test_model_embeddings():
"""Test YOLO model embeddings."""
model_detect = YOLO(MODEL)
model_segment = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt')
for batch in [SOURCE], [SOURCE, SOURCE]: # test batch size 1 and 2
assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)
assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)