ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-07-30 22:06:49 +08:00 committed by GitHub
parent 80f699ae21
commit 8648572809
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 3276 additions and 77 deletions

View file

@ -44,6 +44,7 @@ class SAM(Model):
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
self.is_sam2 = "sam2" in Path(model).stem
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
@ -54,7 +55,12 @@ class SAM(Model):
weights (str): Path to the weights file.
task (str, optional): Task name. Defaults to None.
"""
self.model = build_sam(weights)
if self.is_sam2:
from ..sam2.build import build_sam2
self.model = build_sam2(weights)
else:
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
@ -112,4 +118,6 @@ class SAM(Model):
Returns:
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
"""
return {"segment": {"predictor": Predictor}}
from ..sam2.predict import SAM2Predictor
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}