Add Custom CLIP Model Download Path (#14517)

Co-authored-by: wangsrGit119 <1215618342@email.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
suke 2024-07-21 02:09:45 +08:00 committed by GitHub
parent 8ce8e5ecc4
commit 8094df3c47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -335,12 +335,12 @@ class FastSAMPrompt:
self.results[0].masks.data = torch.tensor(np.array([onemask]))
return self.results
def text_prompt(self, text):
def text_prompt(self, text, clip_download_root=None):
"""Processes a text prompt, applies it to existing results and returns the updated results."""
if self.results[0].masks is not None:
format_results = self._format_results(self.results[0], 0)
cropped_images, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
clip_model, preprocess = self.clip.load("ViT-B/32", download_root=clip_download_root, device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
max_idx = torch.argmax(scores)
max_idx += sum(np.array(filter_id) <= int(max_idx))