Fix broken examples in SAM Predictor docstrings (#18665)

This commit is contained in:
Mohammed Yasin 2025-01-14 01:04:29 +08:00 committed by GitHub
parent a6303020e6
commit ffd8df3751
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -73,8 +73,8 @@ class Predictor(BasePredictor):
>>> predictor = Predictor()
>>> predictor.setup_model(model_path="sam_model.pt")
>>> predictor.set_image("image.jpg")
>>> masks, scores, boxes = predictor.generate()
>>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
>>> bboxes = [[100, 100, 200, 200]]
>>> results = predictor(bboxes=bboxes)
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@ -191,7 +191,7 @@ class Predictor(BasePredictor):
>>> predictor = Predictor()
>>> predictor.setup_model(model_path="sam_model.pt")
>>> predictor.set_image("image.jpg")
>>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
>>> results = predictor(bboxes=[[0, 0, 100, 100]])
"""
# Override prompts if any stored in self.prompts
bboxes = self.prompts.pop("bboxes", bboxes)
@ -646,8 +646,8 @@ class SAM2Predictor(Predictor):
>>> predictor = SAM2Predictor(cfg)
>>> predictor.set_image("path/to/image.jpg")
>>> bboxes = [[100, 100, 200, 200]]
>>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
>>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
>>> result = predictor(bboxes=bboxes)[0]
>>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
"""
_bb_feat_sizes = [
@ -694,8 +694,8 @@ class SAM2Predictor(Predictor):
>>> predictor = SAM2Predictor(cfg)
>>> image = torch.rand(1, 3, 640, 640)
>>> bboxes = [[100, 100, 200, 200]]
>>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
>>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
>>> result = predictor(image, bboxes=bboxes)[0]
>>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
Notes:
- The method supports batched inference for multiple objects when points or bboxes are provided.