ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
80f699ae21
commit
8648572809
36 changed files with 3276 additions and 77 deletions
|
|
@ -44,6 +44,7 @@ class SAM(Model):
|
|||
"""
|
||||
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
||||
self.is_sam2 = "sam2" in Path(model).stem
|
||||
super().__init__(model=model, task="segment")
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
|
|
@ -54,7 +55,12 @@ class SAM(Model):
|
|||
weights (str): Path to the weights file.
|
||||
task (str, optional): Task name. Defaults to None.
|
||||
"""
|
||||
self.model = build_sam(weights)
|
||||
if self.is_sam2:
|
||||
from ..sam2.build import build_sam2
|
||||
|
||||
self.model = build_sam2(weights)
|
||||
else:
|
||||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
|
|
@ -112,4 +118,6 @@ class SAM(Model):
|
|||
Returns:
|
||||
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
|
||||
"""
|
||||
return {"segment": {"predictor": Predictor}}
|
||||
from ..sam2.predict import SAM2Predictor
|
||||
|
||||
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue