ultralytics 8.1.14 new YOLOv8-World models (#8054)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
f9e9cdf2c3
commit
850ca8587f
19 changed files with 683 additions and 32 deletions
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from .rtdetr import RTDETR
|
||||
from .sam import SAM
|
||||
from .yolo import YOLO
|
||||
from .yolo import YOLO, YOLOWorld
|
||||
|
||||
__all__ = "YOLO", "RTDETR", "SAM" # allow simpler import
|
||||
__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld" # allow simpler import
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class FastSAMPrompt:
|
|||
|
||||
# Import and assign clip
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
import clip
|
||||
except ImportError:
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment
|
||||
|
||||
from .model import YOLO
|
||||
from .model import YOLO, YOLOWorld
|
||||
|
||||
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO"
|
||||
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
|
||||
|
|
|
|||
|
|
@ -1,13 +1,27 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
|
||||
from ultralytics.utils import yaml_load, ROOT
|
||||
|
||||
|
||||
class YOLO(Model):
|
||||
"""YOLO (You Only Look Once) object detection model."""
|
||||
|
||||
def __init__(self, model="yolov8n.pt", task=None, verbose=False):
|
||||
"""Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
|
||||
stem = Path(model).stem # filename stem without suffix, i.e. "yolov8n"
|
||||
if "-world" in stem:
|
||||
new_instance = YOLOWorld(model)
|
||||
self.__class__ = type(new_instance)
|
||||
self.__dict__ = new_instance.__dict__
|
||||
else:
|
||||
# Continue with default YOLO initialization
|
||||
super().__init__(model=model, task=task, verbose=verbose)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Map head to model, trainer, validator, and predictor classes."""
|
||||
|
|
@ -43,3 +57,49 @@ class YOLO(Model):
|
|||
"predictor": yolo.obb.OBBPredictor,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class YOLOWorld(Model):
|
||||
"""YOLO-World object detection model."""
|
||||
|
||||
def __init__(self, model="yolov8s-world.pt") -> None:
|
||||
"""
|
||||
Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
|
||||
"""
|
||||
super().__init__(model=model, task="detect")
|
||||
|
||||
# Assign default COCO class names
|
||||
self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Map head to model, validator, and predictor classes."""
|
||||
return {
|
||||
"detect": {
|
||||
"model": WorldModel,
|
||||
"validator": yolo.detect.DetectionValidator,
|
||||
"predictor": yolo.detect.DetectionPredictor,
|
||||
}
|
||||
}
|
||||
|
||||
def set_classes(self, classes):
|
||||
"""
|
||||
Set classes.
|
||||
|
||||
Args:
|
||||
classes (List(str)): A list of categories i.e ["person"].
|
||||
"""
|
||||
self.model.set_classes(classes)
|
||||
# Remove background if it's given
|
||||
background = " "
|
||||
if background in classes:
|
||||
classes.remove(background)
|
||||
self.model.names = classes
|
||||
|
||||
# Reset method class names
|
||||
# self.predictor = None # reset predictor otherwise old names remain
|
||||
if self.predictor:
|
||||
self.predictor.model.names = classes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue