From cd172e9d126bd0d059b630973e310c648c719be7 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 31 Mar 2024 19:07:59 +0200
Subject: [PATCH] Scope `ultralytics/CLIP` imports (#9449)

Signed-off-by: Glenn Jocher
---
 ultralytics/models/fastsam/prompt.py   |  9 ++++----
 ultralytics/models/yolo/world/train.py | 31 +++++++++++++-------------
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 544938a5..e754f08e 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -9,7 +9,7 @@ import numpy as np
 import torch
 from PIL import Image
 
-from ultralytics.utils import TQDM
+from ultralytics.utils import TQDM, checks
 
 
 class FastSAMPrompt:
@@ -33,9 +33,7 @@ class FastSAMPrompt:
         try:
             import clip
         except ImportError:
-            from ultralytics.utils.checks import check_requirements
-
-            check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
 
         self.clip = clip
@@ -115,7 +113,8 @@ class FastSAMPrompt:
             points (list, optional): Points to be plotted. Defaults to None.
             point_label (list, optional): Labels for the points. Defaults to None.
             mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
-            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True.
+            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
+                Defaults to True.
             retina (bool, optional): Whether to use retina mask. Defaults to False.
             with_contours (bool, optional): Whether to plot contours. Defaults to True.
         """
diff --git a/ultralytics/models/yolo/world/train.py b/ultralytics/models/yolo/world/train.py
index 38cd4cf6..6f51d443 100644
--- a/ultralytics/models/yolo/world/train.py
+++ b/ultralytics/models/yolo/world/train.py
@@ -1,18 +1,12 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from ultralytics.models import yolo
-from ultralytics.nn.tasks import WorldModel
-from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.data import build_yolo_dataset
-from ultralytics.utils.torch_utils import de_parallel
-from ultralytics.utils.checks import check_requirements
 import itertools
 
-try:
-    import clip
-except ImportError:
-    check_requirements("git+https://github.com/ultralytics/CLIP.git")
-    import clip
+from ultralytics.data import build_yolo_dataset
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK, checks
+from ultralytics.utils.torch_utils import de_parallel
 
 
 def on_pretrain_routine_end(trainer):
@@ -22,10 +16,9 @@
         names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
         de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
     device = next(trainer.model.parameters()).device
-    text_model, _ = clip.load("ViT-B/32", device=device)
-    for p in text_model.parameters():
+    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
+    for p in trainer.text_model.parameters():
         p.requires_grad_(False)
-    trainer.text_model = text_model
 
 
 class WorldTrainer(yolo.detect.DetectionTrainer):
@@ -48,6 +41,14 @@
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
 
+        # Import and assign clip
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        self.clip = clip
+
     def get_model(self, cfg=None, weights=None, verbose=True):
         """Return WorldModel initialized with specified config and weights."""
         # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
@@ -84,7 +85,7 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
 
         # NOTE: add text features
         texts = list(itertools.chain(*batch["texts"]))
-        text_token = clip.tokenize(texts).to(batch["img"].device)
+        text_token = self.clip.tokenize(texts).to(batch["img"].device)
         txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
         batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])