diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index b74704c5..0a2776f3 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -73,8 +73,8 @@ class Predictor(BasePredictor):
         >>> predictor = Predictor()
         >>> predictor.setup_model(model_path="sam_model.pt")
         >>> predictor.set_image("image.jpg")
-        >>> masks, scores, boxes = predictor.generate()
-        >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
+        >>> bboxes = [[100, 100, 200, 200]]
+        >>> results = predictor(bboxes=bboxes)
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -191,7 +191,7 @@ class Predictor(BasePredictor):
             >>> predictor = Predictor()
             >>> predictor.setup_model(model_path="sam_model.pt")
             >>> predictor.set_image("image.jpg")
-            >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
+            >>> results = predictor(bboxes=[[0, 0, 100, 100]])
         """
         # Override prompts if any stored in self.prompts
         bboxes = self.prompts.pop("bboxes", bboxes)
@@ -646,8 +646,8 @@ class SAM2Predictor(Predictor):
         >>> predictor = SAM2Predictor(cfg)
         >>> predictor.set_image("path/to/image.jpg")
         >>> bboxes = [[100, 100, 200, 200]]
-        >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
-        >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
+        >>> result = predictor(bboxes=bboxes)[0]
+        >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
     """

     _bb_feat_sizes = [
@@ -694,8 +694,8 @@ class SAM2Predictor(Predictor):
             >>> predictor = SAM2Predictor(cfg)
             >>> image = torch.rand(1, 3, 640, 640)
             >>> bboxes = [[100, 100, 200, 200]]
-            >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
-            >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
+            >>> result = predictor(image, bboxes=bboxes)[0]
+            >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")

         Notes:
             - The method supports batched inference for multiple objects when points or bboxes are provided.
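For reference, a minimal sketch of the updated calling convention shown in the new docstrings, stitched together into one script. The checkpoint path `sam_model.pt` and image path `image.jpg` are placeholders taken from the docstring examples, not files shipped with the library, and the final print mirrors the `SAM2Predictor` example.

```python
from ultralytics.models.sam import Predictor

# Follow the updated docstring flow: load a SAM checkpoint, cache the image,
# then call the predictor directly with box prompts instead of generate()/inference().
predictor = Predictor()
predictor.setup_model(model_path="sam_model.pt")  # placeholder checkpoint path
predictor.set_image("image.jpg")  # placeholder image path

bboxes = [[100, 100, 200, 200]]  # xyxy box prompt
results = predictor(bboxes=bboxes)  # returns a list of Results objects

# Each Results object carries masks and per-mask confidences.
result = results[0]
print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
```

The point of the change is that all four examples now go through the unified `predictor(...)` call, which runs preprocessing, inference, and postprocessing and returns `Results`, rather than exposing the lower-level `generate`/`inference`/`prompt_inference` return tuples.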