Add Custom CLIP Model Download Path (#14517)
Co-authored-by: wangsrGit119 <1215618342@email.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
8ce8e5ecc4
commit
8094df3c47
1 changed files with 2 additions and 2 deletions
|
|
@ -335,12 +335,12 @@ class FastSAMPrompt:
|
|||
self.results[0].masks.data = torch.tensor(np.array([onemask]))
|
||||
return self.results
|
||||
|
||||
def text_prompt(self, text):
|
||||
def text_prompt(self, text, clip_download_root=None):
|
||||
"""Processes a text prompt, applies it to existing results and returns the updated results."""
|
||||
if self.results[0].masks is not None:
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_images, filter_id, annotations = self._crop_image(format_results)
|
||||
clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
|
||||
clip_model, preprocess = self.clip.load("ViT-B/32", download_root=clip_download_root, device=self.device)
|
||||
scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
|
||||
max_idx = torch.argmax(scores)
|
||||
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue