Ultralytics Code Refactor https://ultralytics.com/actions (#16047)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent 95d54828bb
commit ac2c2be8f3
12 changed files with 45 additions and 62 deletions
@@ -736,7 +736,7 @@ class PositionEmbeddingSine(nn.Module):
         self.num_pos_feats = num_pos_feats // 2
         self.temperature = temperature
         self.normalize = normalize
-        if scale is not None and normalize is False:
+        if scale is not None and not normalize:
             raise ValueError("normalize should be True if scale is passed")
         if scale is None:
             scale = 2 * math.pi
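Worth noting for reviewers: `not normalize` is slightly broader than `normalize is False`. The old guard fired only on the literal `False`, while the new one also fires on any falsy value such as `None` or `0`. A quick illustrative check (not part of the commit):

for normalize in (False, None, 0):
    print(normalize is False, not normalize)
# False -> True True; None -> False True; 0 -> False True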
@@ -763,8 +763,7 @@ class PositionEmbeddingSine(nn.Module):
     def encode_boxes(self, x, y, w, h):
         """Encodes box coordinates and dimensions into positional embeddings for detection."""
         pos_x, pos_y = self._encode_xy(x, y)
-        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
-        return pos
+        return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)

     encode = encode_boxes  # Backwards compatibility

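The `encode = encode_boxes` context line shows the class-level aliasing that keeps the old `encode()` API working. A minimal sketch with hypothetical names:

class PE:
    def encode_boxes(self, x):
        return x * 2

    encode = encode_boxes  # old callers of .encode() get the same method

pe = PE()
assert pe.encode(3) == pe.encode_boxes(3) == 6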
@@ -775,8 +774,7 @@ class PositionEmbeddingSine(nn.Module):
         assert bx == by and nx == ny and bx == bl and nx == nl
         pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
         pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
-        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
-        return pos
+        return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)

     @torch.no_grad()
     def forward(self, x: torch.Tensor):
@@ -435,9 +435,9 @@ class SAM2MaskDecoder(nn.Module):
         upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
         upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

-        hyper_in_list: List[torch.Tensor] = []
-        for i in range(self.num_mask_tokens):
-            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+        hyper_in_list: List[torch.Tensor] = [
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
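The surrounding lines compute per-token masks as a batched matmul over flattened spatial dimensions. A shape sketch with toy sizes (names assumed, not from the commit):

import torch

b, n_tokens, c, h, w = 2, 4, 8, 16, 16
hyper_in = torch.randn(b, n_tokens, c)     # one hypernetwork output per mask token
upscaled = torch.randn(b, c, h, w)         # upscaled image embedding
masks = (hyper_in @ upscaled.view(b, c, h * w)).view(b, -1, h, w)
assert masks.shape == (b, n_tokens, h, w)  # one low-res mask per token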
@@ -459,8 +459,7 @@ class SAM2MaskDecoder(nn.Module):
         stability_delta = self.dynamic_multimask_stability_delta
         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
         area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
-        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
-        return stability_scores
+        return torch.where(area_u > 0, area_i / area_u, 1.0)

     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
         """
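The stability score here is the ratio of the mask area at the strict threshold (+delta) to the area at the loose threshold (-delta); a mask whose boundary barely moves between the two thresholds scores near 1. An illustrative computation with toy logits:

import torch

mask_logits = torch.tensor([[3.0, 2.5, 0.5, -0.5, -3.0]])  # one mask, flattened
delta = 1.0
area_i = torch.sum(mask_logits > delta, dim=-1).float()    # 2 pixels above +delta
area_u = torch.sum(mask_logits > -delta, dim=-1).float()   # 4 pixels above -delta
print(torch.where(area_u > 0, area_i / area_u, torch.tensor(1.0)))  # tensor([0.5000])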
@@ -491,12 +491,11 @@ class ImageEncoder(nn.Module):
             features, pos = features[: -self.scalp], pos[: -self.scalp]

         src = features[-1]
-        output = {
+        return {
             "vision_features": src,
             "vision_pos_enc": pos,
             "backbone_fpn": features,
         }
-        return output


 class FpnNeck(nn.Module):
@@ -577,7 +576,7 @@ class FpnNeck(nn.Module):

             self.convs.append(current)
         self.fpn_interp_model = fpn_interp_model
-        assert fuse_type in ["sum", "avg"]
+        assert fuse_type in {"sum", "avg"}
         self.fuse_type = fuse_type

         # levels to have top-down features in its outputs
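The switch from a list to a set literal in the membership test is the idiomatic form; on recent CPython versions the peephole optimizer folds `{"sum", "avg"}` into a frozenset constant, so the check becomes a hash lookup rather than a list scan. Illustrative only:

import dis

dis.dis(compile('fuse_type in {"sum", "avg"}', "<example>", "eval"))
# the disassembly shows a single frozenset({'avg', 'sum'}) constant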
@@ -671,26 +671,19 @@ class SAM2Model(torch.nn.Module):
                 t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                 if t_rel == 1:
                     # for t_rel == 1, we take the last frame (regardless of r)
-                    if not track_in_reverse:
-                        # the frame immediately before this frame (i.e. frame_idx - 1)
-                        prev_frame_idx = frame_idx - t_rel
-                    else:
-                        # the frame immediately after this frame (i.e. frame_idx + 1)
-                        prev_frame_idx = frame_idx + t_rel
+                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
+                elif not track_in_reverse:
+                    # first find the nearest frame among every r-th frames before this frame
+                    # for r=1, this would be (frame_idx - 2)
+                    prev_frame_idx = ((frame_idx - 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                 else:
-                    # for t_rel >= 2, we take the memory frame from every r-th frames
-                    if not track_in_reverse:
-                        # first find the nearest frame among every r-th frames before this frame
-                        # for r=1, this would be (frame_idx - 2)
-                        prev_frame_idx = ((frame_idx - 2) // r) * r
-                        # then seek further among every r-th frames
-                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
-                    else:
-                        # first find the nearest frame among every r-th frames after this frame
-                        # for r=1, this would be (frame_idx + 2)
-                        prev_frame_idx = -(-(frame_idx + 2) // r) * r
-                        # then seek further among every r-th frames
-                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+                    # first find the nearest frame among every r-th frames after this frame
+                    # for r=1, this would be (frame_idx + 2)
+                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
                 out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                 if out is None:
                     # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
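The arithmetic that survives the de-nesting is worth a gloss: `((frame_idx - 2) // r) * r` rounds down to the nearest multiple of `r`, and `-(-(frame_idx + 2) // r) * r` is the standard ceiling-division idiom that rounds up. A numeric check with illustrative values:

frame_idx, r = 13, 4
print(((frame_idx - 2) // r) * r)    # 8: nearest multiple of r at or below 11
print(-(-(frame_idx + 2) // r) * r)  # 16: nearest multiple of r at or above 15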
@@ -739,7 +732,7 @@ class SAM2Model(torch.nn.Module):
                     if out is not None:
                         pos_and_ptrs.append((t_diff, out["obj_ptr"]))
             # If we have at least one object pointer, add them to the across attention
-            if len(pos_and_ptrs) > 0:
+            if pos_and_ptrs:
                 pos_list, ptrs_list = zip(*pos_and_ptrs)
                 # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                 obj_ptrs = torch.stack(ptrs_list, dim=0)
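The truthiness check `if pos_and_ptrs:` now guards the `zip(*...)` transpose directly; for reference, the unpacking turns a list of pairs into two parallel tuples (toy values):

pairs = [(0, "ptr0"), (1, "ptr1"), (2, "ptr2")]
pos_list, ptrs_list = zip(*pairs)
print(pos_list)   # (0, 1, 2)
print(ptrs_list)  # ('ptr0', 'ptr1', 'ptr2')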
@@ -930,12 +923,11 @@ class SAM2Model(torch.nn.Module):
     def _use_multimask(self, is_init_cond_frame, point_inputs):
         """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
         num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
-        multimask_output = (
+        return (
             self.multimask_output_in_sam
             and (is_init_cond_frame or self.multimask_output_for_tracking)
             and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
         )
-        return multimask_output

     def _apply_non_overlapping_constraints(self, pred_masks):
         """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
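Returning the boolean expression directly also highlights the chained comparison: `a <= b <= c` evaluates `b` once and is equivalent to `(a <= b) and (b <= c)`. A small check with made-up bounds:

min_pts, max_pts = 1, 3
for num_pts in (0, 2, 5):
    print(num_pts, min_pts <= num_pts <= max_pts)  # 0 False, 2 True, 5 False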