ultralytics 8.2.73 Meta SAM2 Refactor (#14867)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Laughing 2024-08-05 08:53:45 +08:00 committed by GitHub
parent bea4c93278
commit 5d9046abda
44 changed files with 4542 additions and 3624 deletions


@@ -20,27 +20,46 @@ from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
 
 from .build import build_sam
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor
 
 
 class SAM(Model):
     """
-    SAM (Segment Anything Model) interface class.
+    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
 
-    SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
-    bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
-    dataset.
+    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+    boxes, points, or labels, and features zero-shot performance capabilities.
+
+    Attributes:
+        model (torch.nn.Module): The loaded SAM model.
+        is_sam2 (bool): Indicates whether the model is a SAM2 variant.
+        task (str): The task type, set to "segment" for SAM models.
+
+    Methods:
+        predict: Performs segmentation prediction on the given image or video source.
+        info: Logs information about the SAM model.
+
+    Examples:
+        >>> sam = SAM('sam_b.pt')
+        >>> results = sam.predict('image.jpg', points=[[500, 375]])
+        >>> for r in results:
+        ...     print(f"Detected {len(r.masks)} masks")
     """
 
     def __init__(self, model="sam_b.pt") -> None:
         """
-        Initializes the SAM model with a pre-trained model file.
+        Initializes the SAM (Segment Anything Model) instance.
 
         Args:
             model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
 
         Raises:
             NotImplementedError: If the model file extension is not .pt or .pth.
+
+        Examples:
+            >>> sam = SAM('sam_b.pt')
+            >>> print(sam.is_sam2)
         """
         if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
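
Note: the hunk above truncates __init__ before self.is_sam2 is assigned, yet the new docstring's `>>> print(sam.is_sam2)` example implies the flag is derived at construction time. A minimal sketch of one plausible derivation, assuming the variant is inferred from the checkpoint filename; the helper below is illustrative, not the library's actual code:

# Hypothetical sketch: infer the SAM2 variant from the weights filename.
from pathlib import Path

def infer_is_sam2(model: str) -> bool:
    """Return True if the checkpoint name looks like a SAM2 file, e.g. 'sam2_b.pt'."""
    return "sam2" in Path(model).stem

print(infer_is_sam2("sam_b.pt"))   # False
print(infer_is_sam2("sam2_b.pt"))  # True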
@@ -51,30 +70,40 @@ class SAM(Model):
         """
         Loads the specified weights into the SAM model.
 
+        This method initializes the SAM model with the provided weights file, setting up the model architecture
+        and loading the pre-trained parameters.
+
         Args:
-            weights (str): Path to the weights file.
-            task (str, optional): Task name. Defaults to None.
-        """
-        if self.is_sam2:
-            from ..sam2.build import build_sam2
-
-            self.model = build_sam2(weights)
-        else:
-            self.model = build_sam(weights)
+            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
+
+        Examples:
+            >>> sam = SAM('sam_b.pt')
+            >>> sam._load('path/to/custom_weights.pt')
+        """
+        self.model = build_sam(weights)
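
The _load rewrite above drops the explicit `if self.is_sam2: ... build_sam2(weights)` branch in favor of a single `build_sam(weights)` call, which implies the SAM-vs-SAM2 dispatch now lives inside build_sam itself. A minimal sketch of that pattern, with hypothetical constructor names standing in for the real internals:

# Illustrative dispatch sketch; the two _build_* helpers are stand-ins, not the library's API.
def _build_sam1_model(ckpt):
    return f"SAM model from {ckpt}"

def _build_sam2_model(ckpt):
    return f"SAM2 model from {ckpt}"

def build_sam(ckpt="sam_b.pt"):
    """Single entry point that routes to the right constructor based on the checkpoint name."""
    builder = _build_sam2_model if "sam2" in str(ckpt) else _build_sam1_model
    return builder(ckpt)

print(build_sam("sam2_b.pt"))  # -> "SAM2 model from sam2_b.pt"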
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
         Performs segmentation prediction on the given image or video source.
 
         Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
-            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
-            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
-            points (list, optional): List of points for prompted segmentation. Defaults to None.
-            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+                a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments for prediction.
 
         Returns:
-            (list): The model predictions.
+            (List): The model predictions.
+
+        Examples:
+            >>> sam = SAM('sam_b.pt')
+            >>> results = sam.predict('image.jpg', points=[[500, 375]])
+            >>> for r in results:
+            ...     print(f"Detected {len(r.masks)} masks")
         """
         overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
         kwargs.update(overrides)
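
One behavioral detail worth noting in the body above: `kwargs.update(overrides)` applies the fixed overrides on top of the caller's keyword arguments, so conf=0.25, task="segment", mode="predict", and imgsz=1024 always take precedence over caller-supplied values. A standalone illustration:

# Demonstrates the precedence of predict()'s hard-coded overrides.
kwargs = {"conf": 0.5, "save": True}  # caller-supplied keyword arguments
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs.update(overrides)  # overrides overwrite the caller's conf=0.5
print(kwargs["conf"])  # 0.25
print(kwargs["save"])  # True -- unrelated caller kwargs survive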
@@ -83,17 +112,27 @@ class SAM(Model):
     def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
-        Alias for the 'predict' method.
+        Performs segmentation prediction on the given image or video source.
+
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+        for segmentation tasks.
 
         Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
-            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
-            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
-            points (list, optional): List of points for prompted segmentation. Defaults to None.
-            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+                object, or a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments to be passed to the predict method.
 
         Returns:
-            (list): The model predictions.
+            (List): The model predictions, typically containing segmentation masks and other relevant information.
+
+        Examples:
+            >>> sam = SAM('sam_b.pt')
+            >>> results = sam('image.jpg', points=[[500, 375]])
+            >>> print(f"Detected {len(results[0].masks)} masks")
         """
         return self.predict(source, stream, bboxes, points, labels, **kwargs)
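
Because __call__ forwards all arguments to predict, the two invocation styles are interchangeable; a quick usage sketch (assumes sam_b.pt and image.jpg are available locally):

from ultralytics import SAM

sam = SAM("sam_b.pt")
# Equivalent calls: __call__ is a thin wrapper around predict().
results_call = sam("image.jpg", points=[[500, 375]])
results_pred = sam.predict("image.jpg", points=[[500, 375]])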
@@ -101,12 +140,20 @@ class SAM(Model):
         """
         Logs information about the SAM model.
 
+        This method provides details about the Segment Anything Model (SAM), including its architecture,
+        parameters, and computational requirements.
+
         Args:
-            detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
-            verbose (bool, optional): If True, displays information on the console. Defaults to True.
+            detailed (bool): If True, displays detailed information about the model layers and operations.
+            verbose (bool): If True, prints the information to the console.
 
         Returns:
-            (tuple): A tuple containing the model's information.
+            (Tuple): A tuple containing the model's information (string representations of the model).
+
+        Examples:
+            >>> sam = SAM('sam_b.pt')
+            >>> info = sam.info()
+            >>> print(info[0])  # Print summary information
         """
         return model_info(self.model, detailed=detailed, verbose=verbose)
@@ -116,8 +163,13 @@ class SAM(Model):
         Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
 
         Returns:
-            (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
-        """
-        from ..sam2.predict import SAM2Predictor
-
+            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+                class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+
+        Examples:
+            >>> sam = SAM('sam_b.pt')
+            >>> task_map = sam.task_map
+            >>> print(task_map)
+            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
+        """
         return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}