ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-07-30 22:06:49 +08:00 committed by GitHub
parent 80f699ae21
commit 8648572809
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 3276 additions and 77 deletions

View file

@ -168,7 +168,7 @@ class Predictor(BasePredictor):
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
"""
features = self.model.image_encoder(im) if self.features is None else self.features
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
@ -334,7 +334,7 @@ class Predictor(BasePredictor):
"""
device = select_device(self.args.device, verbose=verbose)
if model is None:
model = build_sam(self.args.model)
model = self.get_model()
model.eval()
self.model = model.to(device)
self.device = device
@ -348,6 +348,10 @@ class Predictor(BasePredictor):
self.model.fp16 = False
self.done_warmup = True
def get_model(self):
    """
    Build the Segment Anything Model (SAM) selected by `self.args.model`.

    Returns:
        The object produced by `build_sam` for the configured model identifier.
    """
    model_spec = self.args.model
    return build_sam(model_spec)
def postprocess(self, preds, img, orig_imgs):
"""
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
@ -412,16 +416,18 @@ class Predictor(BasePredictor):
AssertionError: If more than one image is set.
"""
if self.model is None:
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_model(model=None)
self.setup_source(image)
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
self.im = im
self.features = self.get_im_features(im)
break
def get_im_features(self, im):
    """
    Run the loaded model's image encoder on an image and return its output.

    Args:
        im: Input forwarded unchanged to `self.model.image_encoder`.

    Returns:
        Whatever `self.model.image_encoder(im)` returns.
    """
    encoder = self.model.image_encoder
    return encoder(im)
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts