ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
80f699ae21
commit
8648572809
36 changed files with 3276 additions and 77 deletions
|
|
@ -211,6 +211,8 @@ class PromptEncoder(nn.Module):
|
|||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
point_embedding[labels == 2] += self.point_embeddings[2].weight
|
||||
point_embedding[labels == 3] += self.point_embeddings[3].weight
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -226,8 +228,8 @@ class PromptEncoder(nn.Module):
|
|||
"""Embeds mask inputs."""
|
||||
return self.mask_downscaling(masks)
|
||||
|
||||
@staticmethod
|
||||
def _get_batch_size(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue