Implement all missing docstrings (#5298)
Co-authored-by: snyk-bot <snyk-bot@snyk.io> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e7f0658744
commit
7fd5dcbd86
26 changed files with 649 additions and 79 deletions
|
|
@ -9,14 +9,45 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|||
|
||||
|
||||
class FastSAMPredictor(DetectionPredictor):
|
||||
"""
|
||||
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
|
||||
YOLO framework.
|
||||
|
||||
This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
|
||||
It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
|
||||
for single-class segmentation.
|
||||
|
||||
Attributes:
|
||||
cfg (dict): Configuration parameters for prediction.
|
||||
overrides (dict, optional): Optional parameter overrides for custom behavior.
|
||||
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initializes FastSAMPredictor class by inheriting from DetectionPredictor and setting task to 'segment'."""
|
||||
"""
|
||||
Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.
|
||||
|
||||
Args:
|
||||
cfg (dict): Configuration parameters for prediction.
|
||||
overrides (dict, optional): Optional parameter overrides for custom behavior.
|
||||
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses the predictions, applies non-max suppression, scales the boxes, and returns the results."""
|
||||
"""
|
||||
Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
|
||||
size, and returns the final results.
|
||||
|
||||
Args:
|
||||
preds (list): The raw output predictions from the model.
|
||||
img (torch.Tensor): The processed image tensor.
|
||||
orig_imgs (list | torch.Tensor): The original image or list of images.
|
||||
|
||||
Returns:
|
||||
(list): A list of Results objects, each containing processed boxes, masks, and other metadata.
|
||||
"""
|
||||
p = ops.non_max_suppression(
|
||||
preds[0],
|
||||
self.args.conf,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,15 @@ from ultralytics.utils import TQDM
|
|||
|
||||
|
||||
class FastSAMPrompt:
|
||||
"""
|
||||
Fast Segment Anything Model class for image annotation and visualization.
|
||||
|
||||
Attributes:
|
||||
device (str): Computing device ('cuda' or 'cpu').
|
||||
results: Object detection or segmentation results.
|
||||
source: Source image or image path.
|
||||
clip: CLIP model for linear assignment.
|
||||
"""
|
||||
|
||||
def __init__(self, source, results, device='cuda') -> None:
|
||||
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
|
||||
|
|
@ -92,6 +101,20 @@ class FastSAMPrompt:
|
|||
better_quality=True,
|
||||
retina=False,
|
||||
with_contours=True):
|
||||
"""
|
||||
Plots annotations, bounding boxes, and points on images and saves the output.
|
||||
|
||||
Args:
|
||||
annotations (list): Annotations to be plotted.
|
||||
output (str or Path): Output directory for saving the plots.
|
||||
bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
|
||||
points (list, optional): Points to be plotted. Defaults to None.
|
||||
point_label (list, optional): Labels for the points. Defaults to None.
|
||||
mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
|
||||
better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True.
|
||||
retina (bool, optional): Whether to use retina mask. Defaults to False.
|
||||
with_contours (bool, optional): Whether to plot contours. Defaults to True.
|
||||
"""
|
||||
pbar = TQDM(annotations, total=len(annotations))
|
||||
for ann in pbar:
|
||||
result_name = os.path.basename(ann.path)
|
||||
|
|
@ -160,6 +183,20 @@ class FastSAMPrompt:
|
|||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
"""
|
||||
Quickly shows the mask annotations on the given matplotlib axis.
|
||||
|
||||
Args:
|
||||
annotation (array-like): Mask annotation.
|
||||
ax (matplotlib.axes.Axes): Matplotlib axis.
|
||||
random_color (bool, optional): Whether to use random color for masks. Defaults to False.
|
||||
bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
|
||||
points (list, optional): Points to be plotted. Defaults to None.
|
||||
pointlabel (list, optional): Labels for the points. Defaults to None.
|
||||
retinamask (bool, optional): Whether to use retina mask. Defaults to True.
|
||||
target_height (int, optional): Target height for resizing. Defaults to 960.
|
||||
target_width (int, optional): Target width for resizing. Defaults to 960.
|
||||
"""
|
||||
n, h, w = annotation.shape # batch, height, width
|
||||
|
||||
areas = np.sum(annotation, axis=(1, 2))
|
||||
|
|
|
|||
|
|
@ -5,9 +5,35 @@ from ultralytics.utils.metrics import SegmentMetrics
|
|||
|
||||
|
||||
class FastSAMValidator(SegmentationValidator):
|
||||
"""
|
||||
Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
|
||||
|
||||
Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
|
||||
sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
|
||||
to avoid errors during validation.
|
||||
|
||||
Attributes:
|
||||
dataloader: The data loader object used for validation.
|
||||
save_dir (str): The directory where validation results will be saved.
|
||||
pbar: A progress bar object.
|
||||
args: Additional arguments for customization.
|
||||
_callbacks: List of callback functions to be invoked during validation.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
|
||||
"""
|
||||
Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
||||
save_dir (Path, optional): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
_callbacks (dict): Dictionary to store various callback functions.
|
||||
|
||||
Notes:
|
||||
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue