diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 9738252e..89912132 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -335,12 +335,12 @@ class FastSAMPrompt:
         self.results[0].masks.data = torch.tensor(np.array([onemask]))
         return self.results
 
-    def text_prompt(self, text):
+    def text_prompt(self, text, clip_download_root=None):
         """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
             cropped_images, filter_id, annotations = self._crop_image(format_results)
-            clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
+            clip_model, preprocess = self.clip.load("ViT-B/32", download_root=clip_download_root, device=self.device)
             scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
             max_idx = torch.argmax(scores)
             max_idx += sum(np.array(filter_id) <= int(max_idx))