ultralytics 8.0.134 add MobileSAM support (#3474)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c55a98ab8e
commit
201e69e4e4
32 changed files with 1472 additions and 841 deletions
|
|
@ -4,8 +4,8 @@ SAM model interface
|
|||
"""
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils.torch_utils import model_info
|
||||
|
||||
from ...yolo.utils.torch_utils import model_info
|
||||
from .build import build_sam
|
||||
from .predict import Predictor
|
||||
|
||||
|
|
@ -20,16 +20,16 @@ class SAM:
|
|||
self.task = 'segment' # required
|
||||
self.predictor = None # reuse predictor
|
||||
|
||||
def predict(self, source, stream=False, **kwargs):
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Predicts and returns segmentation masks for given image or video source."""
|
||||
overrides = dict(conf=0.25, task='segment', mode='predict')
|
||||
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
if not self.predictor:
|
||||
self.predictor = Predictor(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
return self.predictor(source, stream=stream)
|
||||
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""Function trains models but raises an error as SAM models do not support training."""
|
||||
|
|
@ -39,9 +39,9 @@ class SAM:
|
|||
"""Run validation given dataset."""
|
||||
raise NotImplementedError("SAM models don't support validation")
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises error if object has no requested attribute."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue