ultralytics 8.2.84 new SAM flexible imgsz inference (#15882)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
5d66140ce1
commit
7053169fd0
6 changed files with 70 additions and 7 deletions
|
|
@ -90,6 +90,19 @@ class SAMModel(nn.Module):
|
|||
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
|
||||
def set_imgsz(self, imgsz):
    """
    Reconfigure the image and prompt encoders for a new input image size.

    If the image encoder defines its own `set_imgsz`, the size is forwarded to it first. The prompt encoder's
    input size and embedding grid are then updated, and the first element of `imgsz` is stored as the image
    encoder's `img_size`.

    Args:
        imgsz (Tuple[int, int]): The size of the input image.
    """
    patch_size = 16  # fixed patch size of the ViT model

    encoder = self.image_encoder
    if hasattr(encoder, "set_imgsz"):
        encoder.set_imgsz(imgsz)

    prompt_encoder = self.prompt_encoder
    prompt_encoder.input_image_size = imgsz
    # One embedding cell per patch along each image dimension.
    prompt_encoder.image_embedding_size = [side // patch_size for side in imgsz]
    encoder.img_size = imgsz[0]
|
||||
|
||||
|
||||
class SAM2Model(torch.nn.Module):
|
||||
"""
|
||||
|
|
@ -940,3 +953,14 @@ class SAM2Model(torch.nn.Module):
|
|||
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
|
||||
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
|
||||
return pred_masks
|
||||
|
||||
def set_imgsz(self, imgsz):
    """
    Update the model and its SAM prompt encoder for a new input image size.

    The first element of `imgsz` is stored as the model's `image_size`; the prompt encoder receives the full
    size together with the matching embedding grid (each dimension divided by 16).

    Args:
        imgsz (Tuple[int, int]): The size of the input image.
    """
    patch_size = 16  # fixed ViT patch size

    self.image_size = imgsz[0]

    prompt_encoder = self.sam_prompt_encoder
    prompt_encoder.input_image_size = imgsz
    # One embedding cell per patch along each image dimension.
    prompt_encoder.image_embedding_size = [dim // patch_size for dim in imgsz]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue