ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
80f699ae21
commit
8648572809
36 changed files with 3276 additions and 77 deletions
|
|
@ -168,7 +168,7 @@ class Predictor(BasePredictor):
|
|||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
"""
|
||||
features = self.model.image_encoder(im) if self.features is None else self.features
|
||||
features = self.get_im_features(im) if self.features is None else self.features
|
||||
|
||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
|
||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
||||
|
|
@ -334,7 +334,7 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
if model is None:
|
||||
model = build_sam(self.args.model)
|
||||
model = self.get_model()
|
||||
model.eval()
|
||||
self.model = model.to(device)
|
||||
self.device = device
|
||||
|
|
@ -348,6 +348,10 @@ class Predictor(BasePredictor):
|
|||
self.model.fp16 = False
|
||||
self.done_warmup = True
|
||||
|
||||
def get_model(self):
|
||||
"""Built Segment Anything Model (SAM) model."""
|
||||
return build_sam(self.args.model)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
|
||||
|
|
@ -412,16 +416,18 @@ class Predictor(BasePredictor):
|
|||
AssertionError: If more than one image is set.
|
||||
"""
|
||||
if self.model is None:
|
||||
model = build_sam(self.args.model)
|
||||
self.setup_model(model)
|
||||
self.setup_model(model=None)
|
||||
self.setup_source(image)
|
||||
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
|
||||
for batch in self.dataset:
|
||||
im = self.preprocess(batch[1])
|
||||
self.features = self.model.image_encoder(im)
|
||||
self.im = im
|
||||
self.features = self.get_im_features(im)
|
||||
break
|
||||
|
||||
def get_im_features(self, im):
|
||||
"""Get image features from the SAM image encoder."""
|
||||
return self.model.image_encoder(im)
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
"""Set prompts in advance."""
|
||||
self.prompts = prompts
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue