ultralytics 8.2.84 new SAM flexible imgsz inference (#15882)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

parent 5d66140ce1
commit 7053169fd0

6 changed files with 70 additions and 7 deletions
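For context, a minimal usage sketch of what this commit enables; the weights file and image path below are placeholders, not part of the change itself:

```python
from ultralytics import SAM

# Placeholder checkpoint and image; any SAM checkpoint and source behave the same way.
model = SAM("sam_b.pt")

# Previously the predictor forced imgsz=1024; after this commit a caller-supplied
# square imgsz is respected through the image and prompt encoders.
results = model.predict("bus.jpg", imgsz=512, bboxes=[100, 100, 400, 400])
for r in results:
    print(f"Detected {len(r.masks)} masks")
```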
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.83"
+__version__ = "8.2.84"

 import os
@@ -106,7 +106,7 @@ class SAM(Model):
             ... print(f"Detected {len(r.masks)} masks")
         """
         overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
-        kwargs.update(overrides)
+        kwargs = {**overrides, **kwargs}
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)
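The one-line swap above changes which side of the merge wins. A standalone dict-merge sketch (not library code) of why the caller's imgsz now survives:

```python
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs = dict(imgsz=512)  # hypothetical user-supplied argument

# Old behavior: update() copies the defaults over the user's values.
old = dict(kwargs)
old.update(overrides)
assert old["imgsz"] == 1024  # the default always won

# New behavior: in a dict literal the right-most keys win, so the user's value survives.
new = {**overrides, **kwargs}
assert new["imgsz"] == 512
```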
@@ -151,7 +151,12 @@ class ImageEncoderViT(nn.Module):
         """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
         x = self.patch_embed(x)
         if self.pos_embed is not None:
-            x = x + self.pos_embed
+            pos_embed = (
+                F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
+                if self.img_size != 1024
+                else self.pos_embed
+            )
+            x = x + pos_embed
         for blk in self.blocks:
             x = blk(x)
         return self.neck(x.permute(0, 3, 1, 2))
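A standalone shape check of what the interpolation above does to the positional-embedding grid; the 64x64 grid corresponds to the pretrained 1024-pixel input with a ViT patch size of 16, and the channel count here is arbitrary:

```python
import torch
import torch.nn.functional as F

pos_embed = torch.zeros(1, 64, 64, 256)  # (B, H, W, C) layout, as stored by the encoder
img_size = 512  # hypothetical non-default input size

scaled = (
    F.interpolate(pos_embed.permute(0, 3, 1, 2), scale_factor=img_size / 1024).permute(0, 2, 3, 1)
    if img_size != 1024
    else pos_embed
)
print(scaled.shape)  # torch.Size([1, 32, 32, 256]) -- matches the 512 / 16 patch grid
```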
@@ -90,6 +90,19 @@ class SAMModel(nn.Module):
         self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
         self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

+    def set_imgsz(self, imgsz):
+        """
+        Set image size to make model compatible with different image sizes.
+
+        Args:
+            imgsz (Tuple[int, int]): The size of the input image.
+        """
+        if hasattr(self.image_encoder, "set_imgsz"):
+            self.image_encoder.set_imgsz(imgsz)
+        self.prompt_encoder.input_image_size = imgsz
+        self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # 16 is fixed as patch size of ViT model
+        self.image_encoder.img_size = imgsz[0]
+

 class SAM2Model(torch.nn.Module):
     """
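The new method is bookkeeping around a single fixed constant; a values-only example of the resulting embedding grid (no model needed):

```python
imgsz = (512, 512)  # hypothetical non-default input size
image_embedding_size = [x // 16 for x in imgsz]  # ViT patch size is fixed at 16
print(image_embedding_size)  # [32, 32] -- the prompt encoder's dense embedding grid
# At the default 1024x1024 input this would be [64, 64].
```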
@@ -940,3 +953,14 @@ class SAM2Model(torch.nn.Module):
             # don't overlap (here sigmoid(-10.0)=4.5398e-05)
             pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
         return pred_masks
+
+    def set_imgsz(self, imgsz):
+        """
+        Set image size to make model compatible with different image sizes.
+
+        Args:
+            imgsz (Tuple[int, int]): The size of the input image.
+        """
+        self.image_size = imgsz[0]
+        self.sam_prompt_encoder.input_image_size = imgsz
+        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
@@ -982,10 +982,31 @@ class TinyViT(nn.Module):
             layer = self.layers[i]
             x = layer(x)
         batch, _, channel = x.shape
-        x = x.view(batch, 64, 64, channel)
+        x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)
         x = x.permute(0, 3, 1, 2)
         return self.neck(x)

     def forward(self, x):
         """Performs the forward pass through the TinyViT model, extracting features from the input image."""
         return self.forward_features(x)
+
+    def set_imgsz(self, imgsz=[1024, 1024]):
+        """
+        Set image size to make model compatible with different image sizes.
+
+        Args:
+            imgsz (Tuple[int, int]): The size of the input image.
+        """
+        imgsz = [s // 4 for s in imgsz]
+        self.patches_resolution = imgsz
+        for i, layer in enumerate(self.layers):
+            input_resolution = (
+                imgsz[0] // (2 ** (i - 1 if i == 3 else i)),
+                imgsz[1] // (2 ** (i - 1 if i == 3 else i)),
+            )
+            layer.input_resolution = input_resolution
+            if layer.downsample is not None:
+                layer.downsample.input_resolution = input_resolution
+            if isinstance(layer, BasicLayer):
+                for b in layer.blocks:
+                    b.input_resolution = input_resolution
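A values-only walk-through of the per-stage resolutions this method assigns; the stem downsamples by 4, and the last stage reuses the previous stage's resolution (hence the `i - 1 if i == 3` term):

```python
imgsz = [512, 512]  # hypothetical non-default input size
patches = [s // 4 for s in imgsz]  # [128, 128], matching patches_resolution above
for i in range(4):
    scale = 2 ** (i - 1 if i == 3 else i)
    print(i, (patches[0] // scale, patches[1] // scale))
# 0 (128, 128)
# 1 (64, 64)
# 2 (32, 32)
# 3 (32, 32)  -> the final 1/16 feature map, consistent with the new x.view(...) above
```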
@@ -95,7 +95,7 @@ class Predictor(BasePredictor):
         """
         if overrides is None:
             overrides = {}
-        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
+        overrides.update(dict(task="segment", mode="predict"))
         super().__init__(cfg, overrides, _callbacks)
         self.args.retina_masks = True
         self.im = None
@@ -455,8 +455,11 @@ class Predictor(BasePredictor):
             cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
             pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)

-            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
-            masks = masks > self.model.mask_threshold  # to bool
+            if len(masks) == 0:
+                masks = None
+            else:
+                masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
+                masks = masks > self.model.mask_threshold  # to bool
             results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
         # Reset segment-all mode.
         self.segment_all = False
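The guard above covers the case where prompting yields no masks at all; a standalone sketch of the pattern (the tensor shape is hypothetical):

```python
import torch

masks = torch.zeros(0, 256, 256)  # no predicted masks survived filtering
if len(masks) == 0:
    masks = None  # Results is then built with masks=None instead of an empty tensor
else:
    pass  # scale to the original image size and threshold to bool, as in the diff
print(masks)  # None
```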
@@ -522,6 +525,10 @@ class Predictor(BasePredictor):

     def get_im_features(self, im):
         """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
+        assert (
+            isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
+        ), f"SAM models only support square image size, but got {self.imgsz}."
+        self.model.set_imgsz(self.imgsz)
         return self.model.image_encoder(im)

     def set_prompts(self, prompts):
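The assertion makes the square-size requirement explicit at feature-extraction time; a standalone illustration mirroring the check above:

```python
imgsz = (640, 480)  # hypothetical rectangular request
try:
    assert (
        isinstance(imgsz, (tuple, list)) and imgsz[0] == imgsz[1]
    ), f"SAM models only support square image size, but got {imgsz}."
except AssertionError as e:
    print(e)  # SAM models only support square image size, but got (640, 480).
```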
@@ -761,6 +768,12 @@ class SAM2Predictor(Predictor):

     def get_im_features(self, im):
         """Extracts image features from the SAM image encoder for subsequent processing."""
+        assert (
+            isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
+        ), f"SAM 2 models only support square image size, but got {self.imgsz}."
+        self.model.set_imgsz(self.imgsz)
+        self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
+
         backbone_out = self.model.forward_image(im)
         _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
         if self.model.directly_add_no_mem_embed:
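A values-only example of the backbone feature sizes computed above for SAM 2 (strides 4, 8, and 16 relative to the input):

```python
imgsz = (1024, 1024)  # default square input
bb_feat_sizes = [[x // (4 * i) for x in imgsz] for i in [1, 2, 4]]
print(bb_feat_sizes)  # [[256, 256], [128, 128], [64, 64]]

# With a hypothetical imgsz of (512, 512) this becomes [[128, 128], [64, 64], [32, 32]].
```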