Fix broken examples in SAM Predictor docstrings (#18665)

Author: Mohammed Yasin
Date:   2025-01-14 01:04:29 +08:00 (committed by GitHub)
parent  a6303020e6
commit  ffd8df3751


@@ -73,8 +73,8 @@ class Predictor(BasePredictor):
         >>> predictor = Predictor()
         >>> predictor.setup_model(model_path="sam_model.pt")
         >>> predictor.set_image("image.jpg")
-        >>> masks, scores, boxes = predictor.generate()
-        >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
+        >>> bboxes = [[100, 100, 200, 200]]
+        >>> results = predictor(bboxes=bboxes)
         """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
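
For context, the corrected Predictor example above expands to roughly the following runnable sketch. The "ultralytics.models.sam" import path is an assumption; the weight and image file names are the placeholders used in the docstring, not files shipped with this change.

# Sketch of the corrected base Predictor usage (assumption: Predictor is importable
# from ultralytics.models.sam; "sam_model.pt" and "image.jpg" are the docstring's
# placeholder paths).
from ultralytics.models.sam import Predictor

predictor = Predictor()
predictor.setup_model(model_path="sam_model.pt")
predictor.set_image("image.jpg")

# Prompt with a box in xyxy pixel coordinates; calling the predictor returns a list
# of Results objects rather than the raw (masks, scores, boxes) tuple the old example implied.
bboxes = [[100, 100, 200, 200]]
results = predictor(bboxes=bboxes)
print(len(results[0].masks))  # number of predicted masks for this prompt
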
@@ -191,7 +191,7 @@ class Predictor(BasePredictor):
             >>> predictor = Predictor()
             >>> predictor.setup_model(model_path="sam_model.pt")
             >>> predictor.set_image("image.jpg")
-            >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
+            >>> results = predictor(bboxes=[[0, 0, 100, 100]])
         """
         # Override prompts if any stored in self.prompts
         bboxes = self.prompts.pop("bboxes", bboxes)
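
The context line above ("bboxes = self.prompts.pop(...)") indicates that prompts stored on the predictor take precedence over per-call arguments. A hedged sketch of that mechanism, reusing the predictor set up in the previous example:

# Sketch only: per the context line above, prompts stashed in predictor.prompts are
# popped and used in place of the keyword arguments on the next call.
# Assumes "predictor" is the already set-up instance from the previous example.
predictor.prompts = {"bboxes": [[50, 50, 150, 150]]}
results = predictor()  # the stored bboxes prompt is consumed by this call
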
@@ -646,8 +646,8 @@ class SAM2Predictor(Predictor):
         >>> predictor = SAM2Predictor(cfg)
         >>> predictor.set_image("path/to/image.jpg")
         >>> bboxes = [[100, 100, 200, 200]]
-        >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
-        >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
+        >>> result = predictor(bboxes=bboxes)[0]
+        >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
         """

     _bb_feat_sizes = [
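
Likewise, the corrected SAM2Predictor example corresponds roughly to the sketch below. Assumptions: SAM2Predictor is exported from ultralytics.models.sam, DEFAULT_CFG comes from ultralytics.utils, and "sam2_model.pt" stands in for a real SAM 2 checkpoint; the docstring itself omits the setup_model step shown for the base Predictor.

# Sketch of the corrected SAM2Predictor usage under the assumptions noted above.
from ultralytics.models.sam import SAM2Predictor
from ultralytics.utils import DEFAULT_CFG

predictor = SAM2Predictor(cfg=DEFAULT_CFG)
predictor.setup_model(model_path="sam2_model.pt")  # assumed checkpoint name
predictor.set_image("path/to/image.jpg")

bboxes = [[100, 100, 200, 200]]
result = predictor(bboxes=bboxes)[0]  # first (and only) Results object
print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
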
@@ -694,8 +694,8 @@ class SAM2Predictor(Predictor):
             >>> predictor = SAM2Predictor(cfg)
             >>> image = torch.rand(1, 3, 640, 640)
             >>> bboxes = [[100, 100, 200, 200]]
-            >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
-            >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
+            >>> result = predictor(image, bboxes=bboxes)[0]
+            >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")

         Notes:
             - The method supports batched inference for multiple objects when points or bboxes are provided.