ultralytics 8.2.73 Meta SAM2 Refactor (#14867)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-08-05 08:53:45 +08:00 committed by GitHub
parent bea4c93278
commit 5d9046abda
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 4542 additions and 3624 deletions

View file

@ -11,7 +11,7 @@ import torch
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Return a boolean tensor indicating if boxes are near the crop edge."""
"""Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@ -22,7 +22,7 @@ def is_box_near_crop_edge(
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
"""Yield batches of data from the input arguments."""
"""Yields batches of data from input arguments with specified batch size for efficient processing."""
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
"""
Computes the stability score for a batch of masks.
The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
and low values.
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
high and low values.
Args:
masks (torch.Tensor): Batch of predicted mask logits.
mask_threshold (float): Threshold value for creating binary masks.
threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
Returns:
(torch.Tensor): Stability scores for each mask in the batch.
Notes:
- One mask is always contained inside the other.
- Save memory by preventing unnecessary cast to torch.int64
- Memory is saved by preventing unnecessary cast to torch.int64.
Examples:
>>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
>>> mask_threshold = 0.5
>>> threshold_offset = 0.1
>>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
"""
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generate point grids for all crop layers."""
"""Generates point grids for multiple crop layers with varying scales and densities."""
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes.
Each layer has (2**i)**2 boxes for the ith layer.
"""
"""Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions."""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
@ -99,7 +109,7 @@ def generate_crop_boxes(
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop bounding boxes by adding the crop box offset."""
"""Uncrop bounding boxes by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop points by adding the crop box offset."""
"""Uncrop points by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
"""Uncrop masks by padding them to the original image size."""
"""Uncrop masks by padding them to the original image size, handling coordinate transformations."""
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
@ -130,7 +140,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
"""Removes small disconnected regions or holes in a mask based on area threshold and mode."""
import cv2 # type: ignore
assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks.
Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
"""Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)