ultralytics 8.2.84 new SAM flexible imgsz inference (#15882)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
5d66140ce1
commit
7053169fd0
6 changed files with 70 additions and 7 deletions
|
|
@@ -95,7 +95,7 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task="segment", mode="predict", imgsz=1024))
|
||||
overrides.update(dict(task="segment", mode="predict"))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.retina_masks = True
|
||||
self.im = None
|
||||
|
|
@@ -455,8 +455,11 @@ class Predictor(BasePredictor):
|
|||
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
|
||||
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
|
||||
|
||||
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
|
||||
masks = masks > self.model.mask_threshold # to bool
|
||||
if len(masks) == 0:
|
||||
masks = None
|
||||
else:
|
||||
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
|
||||
masks = masks > self.model.mask_threshold # to bool
|
||||
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
|
||||
# Reset segment-all mode.
|
||||
self.segment_all = False
|
||||
|
|
@@ -522,6 +525,10 @@ class Predictor(BasePredictor):
|
|||
|
||||
def get_im_features(self, im):
|
||||
"""Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
|
||||
assert (
|
||||
isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
|
||||
), f"SAM models only support square image size, but got {self.imgsz}."
|
||||
self.model.set_imgsz(self.imgsz)
|
||||
return self.model.image_encoder(im)
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
|
|
@@ -761,6 +768,12 @@ class SAM2Predictor(Predictor):
|
|||
|
||||
def get_im_features(self, im):
|
||||
"""Extracts image features from the SAM image encoder for subsequent processing."""
|
||||
assert (
|
||||
isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
|
||||
), f"SAM 2 models only support square image size, but got {self.imgsz}."
|
||||
self.model.set_imgsz(self.imgsz)
|
||||
self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
|
||||
|
||||
backbone_out = self.model.forward_image(im)
|
||||
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
||||
if self.model.directly_add_no_mem_embed:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue