PyCharm Code and Docs Inspect fixes v1 (#18461)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
126867e355
commit
7f1a50e893
26 changed files with 90 additions and 91 deletions
|
|
@ -76,7 +76,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
|
|||
def generate_crop_boxes(
|
||||
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions."""
|
||||
"""Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions."""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
|
|
|||
|
|
@ -502,11 +502,11 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T
|
|||
|
||||
class MultiScaleAttention(nn.Module):
|
||||
"""
|
||||
Implements multi-scale self-attention with optional query pooling for efficient feature extraction.
|
||||
Implements multiscale self-attention with optional query pooling for efficient feature extraction.
|
||||
|
||||
This class provides a flexible implementation of multi-scale attention, allowing for optional
|
||||
This class provides a flexible implementation of multiscale attention, allowing for optional
|
||||
downsampling of query features through pooling. It's designed to enhance the model's ability to
|
||||
capture multi-scale information in visual tasks.
|
||||
capture multiscale information in visual tasks.
|
||||
|
||||
Attributes:
|
||||
dim (int): Input dimension of the feature map.
|
||||
|
|
@ -518,7 +518,7 @@ class MultiScaleAttention(nn.Module):
|
|||
proj (nn.Linear): Output projection.
|
||||
|
||||
Methods:
|
||||
forward: Applies multi-scale attention to the input tensor.
|
||||
forward: Applies multiscale attention to the input tensor.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
|
|
@ -537,7 +537,7 @@ class MultiScaleAttention(nn.Module):
|
|||
num_heads: int,
|
||||
q_pool: nn.Module = None,
|
||||
):
|
||||
"""Initializes multi-scale attention with optional query pooling for efficient feature extraction."""
|
||||
"""Initializes multiscale attention with optional query pooling for efficient feature extraction."""
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
|
|
@ -552,7 +552,7 @@ class MultiScaleAttention(nn.Module):
|
|||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies multi-scale attention with optional query pooling to extract multi-scale features."""
|
||||
"""Applies multiscale attention with optional query pooling to extract multiscale features."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (B, H * W, 3, nHead, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
||||
|
|
@ -582,9 +582,9 @@ class MultiScaleAttention(nn.Module):
|
|||
|
||||
class MultiScaleBlock(nn.Module):
|
||||
"""
|
||||
A multi-scale attention block with window partitioning and query pooling for efficient vision transformers.
|
||||
A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
|
||||
|
||||
This class implements a multi-scale attention mechanism with optional window partitioning and downsampling,
|
||||
This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
|
||||
designed for use in vision transformer architectures.
|
||||
|
||||
Attributes:
|
||||
|
|
@ -601,7 +601,7 @@ class MultiScaleBlock(nn.Module):
|
|||
proj (nn.Linear | None): Projection layer for dimension mismatch.
|
||||
|
||||
Methods:
|
||||
forward: Processes input tensor through the multi-scale block.
|
||||
forward: Processes input tensor through the multiscale block.
|
||||
|
||||
Examples:
|
||||
>>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
|
||||
|
|
@ -623,7 +623,7 @@ class MultiScaleBlock(nn.Module):
|
|||
act_layer: nn.Module = nn.GELU,
|
||||
window_size: int = 0,
|
||||
):
|
||||
"""Initializes a multi-scale attention block with window partitioning and optional query pooling."""
|
||||
"""Initializes a multiscale attention block with window partitioning and optional query pooling."""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(norm_layer, str):
|
||||
|
|
@ -660,7 +660,7 @@ class MultiScaleBlock(nn.Module):
|
|||
self.proj = nn.Linear(dim, dim_out)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes input through multi-scale attention and MLP, with optional windowing and downsampling."""
|
||||
"""Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
|
||||
shortcut = x # B, H, W, C
|
||||
x = self.norm1(x)
|
||||
|
||||
|
|
|
|||
|
|
@ -425,7 +425,7 @@ class SAM2Model(torch.nn.Module):
|
|||
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
|
||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
|
||||
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
|
||||
object_score_logits: Tensor of shape (B,) with object score logits.
|
||||
object_score_logits: Tensor of shape (B) with object score logits.
|
||||
|
||||
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
|
||||
|
||||
|
|
@ -643,7 +643,7 @@ class SAM2Model(torch.nn.Module):
|
|||
if not is_init_cond_frame:
|
||||
# Retrieve the memories encoded with the maskmem backbone
|
||||
to_cat_memory, to_cat_memory_pos_embed = [], []
|
||||
# Add conditioning frames's output first (all cond frames have t_pos=0 for
|
||||
# Add conditioning frame's output first (all cond frames have t_pos=0 for
|
||||
# when getting temporal positional embedding below)
|
||||
assert len(output_dict["cond_frame_outputs"]) > 0
|
||||
# Select a maximum number of temporally closest cond frames for cross attention
|
||||
|
|
|
|||
|
|
@ -1096,7 +1096,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
|
||||
for is_cond in {False, True}:
|
||||
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
|
|
@ -1161,36 +1161,35 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
assert predictor.dataset is not None
|
||||
assert predictor.dataset.mode == "video"
|
||||
|
||||
inference_state = {}
|
||||
inference_state["num_frames"] = predictor.dataset.frames
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
# values that don't change across frames (so we only need to hold one copy of them)
|
||||
inference_state["constants"] = {}
|
||||
# mapping between client-side object id and model-side object index
|
||||
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||
inference_state["obj_ids"] = []
|
||||
# A storage to hold the model's tracking results and states on each frame
|
||||
inference_state["output_dict"] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
inference_state = {
|
||||
"num_frames": predictor.dataset.frames,
|
||||
"point_inputs_per_obj": {}, # inputs points on each frame
|
||||
"mask_inputs_per_obj": {}, # inputs mask on each frame
|
||||
"constants": {}, # values that don't change across frames (so we only need to hold one copy of them)
|
||||
# mapping between client-side object id and model-side object index
|
||||
"obj_id_to_idx": OrderedDict(),
|
||||
"obj_idx_to_id": OrderedDict(),
|
||||
"obj_ids": [],
|
||||
# A storage to hold the model's tracking results and states on each frame
|
||||
"output_dict": {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
},
|
||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||
"output_dict_per_obj": {},
|
||||
# A temporary storage to hold new outputs when user interact with a frame
|
||||
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
||||
"temp_output_dict_per_obj": {},
|
||||
# Frames that already holds consolidated outputs from click or mask inputs
|
||||
# (we directly use their consolidated outputs during tracking)
|
||||
"consolidated_frame_inds": {
|
||||
"cond_frame_outputs": set(), # set containing frame indices
|
||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||
},
|
||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||
"tracking_has_started": False,
|
||||
"frames_already_tracked": [],
|
||||
}
|
||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||
inference_state["output_dict_per_obj"] = {}
|
||||
# A temporary storage to hold new outputs when user interact with a frame
|
||||
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
||||
inference_state["temp_output_dict_per_obj"] = {}
|
||||
# Frames that already holds consolidated outputs from click or mask inputs
|
||||
# (we directly use their consolidated outputs during tracking)
|
||||
inference_state["consolidated_frame_inds"] = {
|
||||
"cond_frame_outputs": set(), # set containing frame indices
|
||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||
}
|
||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"] = []
|
||||
predictor.inference_state = inference_state
|
||||
|
||||
def get_im_features(self, im, batch=1):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue