ultralytics 8.0.134 add MobileSAM support (#3474)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Chaoning Zhang 2023-07-13 20:25:56 +08:00 committed by GitHub
parent c55a98ab8e
commit 201e69e4e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 1472 additions and 841 deletions

View file

@ -4,8 +4,8 @@ SAM model interface
"""
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils.torch_utils import model_info
from ...yolo.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
@ -20,16 +20,16 @@ class SAM:
self.task = 'segment' # required
self.predictor = None # reuse predictor
def predict(self, source, stream=False, **kwargs):
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict')
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = Predictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
def train(self, **kwargs):
"""Function trains models but raises an error as SAM models do not support training."""
@ -39,9 +39,9 @@ class SAM:
"""Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation")
def __call__(self, source=None, stream=False, **kwargs):
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""