ultralytics 8.1.14 new YOLOv8-World models (#8054)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-02-14 18:46:26 +08:00 committed by GitHub
parent f9e9cdf2c3
commit 850ca8587f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 683 additions and 32 deletions

View file

@ -1,13 +1,27 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
from ultralytics.utils import yaml_load, ROOT
class YOLO(Model):
"""YOLO (You Only Look Once) object detection model."""
def __init__(self, model="yolov8n.pt", task=None, verbose=False):
"""Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
stem = Path(model).stem # filename stem without suffix, i.e. "yolov8n"
if "-world" in stem:
new_instance = YOLOWorld(model)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
else:
# Continue with default YOLO initialization
super().__init__(model=model, task=task, verbose=verbose)
@property
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes."""
@ -43,3 +57,49 @@ class YOLO(Model):
"predictor": yolo.obb.OBBPredictor,
},
}
class YOLOWorld(Model):
"""YOLO-World object detection model."""
def __init__(self, model="yolov8s-world.pt") -> None:
"""
Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
Args:
model (str): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
"""
super().__init__(model=model, task="detect")
# Assign default COCO class names
self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
@property
def task_map(self):
"""Map head to model, validator, and predictor classes."""
return {
"detect": {
"model": WorldModel,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
}
}
def set_classes(self, classes):
"""
Set classes.
Args:
classes (List(str)): A list of categories i.e ["person"].
"""
self.model.set_classes(classes)
# Remove background if it's given
background = " "
if background in classes:
classes.remove(background)
self.model.names = classes
# Reset method class names
# self.predictor = None # reset predictor otherwise old names remain
if self.predictor:
self.predictor.model.names = classes