ultralytics 8.1.31 NCNN and CLIP updates (#9235)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Author: Glenn Jocher, 2024-03-22 15:51:50 -07:00 (committed by GitHub)
parent 41c2d8d99f
commit 3c179f87cb
GPG key ID: B5690EEEBB952194
8 changed files with 77 additions and 60 deletions

@@ -560,7 +560,8 @@ class WorldModel(DetectionModel):
     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 world model with given config and parameters."""
-        self.txt_feats = torch.randn(1, nc or 80, 512)  # placeholder
+        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
+        self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def set_classes(self, text):
@@ -571,10 +572,11 @@ class WorldModel(DetectionModel):
         check_requirements("git+https://github.com/openai/CLIP.git")
         import clip
 
-        model, _ = clip.load("ViT-B/32")
-        device = next(model.parameters()).device
+        if not self.clip_model:
+            self.clip_model = clip.load("ViT-B/32")[0]
+        device = next(self.clip_model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
+        txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
         self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
         self.model[-1].nc = len(text)
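
Note on the two WorldModel hunks above: together they add a lazy cache. __init__ now sets self.clip_model = None, and set_classes loads the CLIP ViT-B/32 weights only on the first call, then reuses them on every later call instead of reloading per call. The embedding post-processing (L2-normalize, reshape to (1, num_classes, embed_dim), detach) is unchanged. Below is a minimal standalone sketch of the pattern; FakeCLIP and Holder are illustrative stand-ins, not part of ultralytics or CLIP.

import torch


class FakeCLIP(torch.nn.Module):
    """Stub standing in for the CLIP text encoder; loading the real one is the expensive step."""

    def __init__(self, dim=512):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)  # gives the module parameters (and thus a device)

    def encode_text(self, tokens):
        return self.proj(tokens)


class Holder:
    def __init__(self):
        self.clip_model = None  # CLIP model placeholder, as in WorldModel.__init__

    def set_classes(self, text):
        if not self.clip_model:  # load once, reuse on later calls
            self.clip_model = FakeCLIP()
        device = next(self.clip_model.parameters()).device
        tokens = torch.randn(len(text), 512, device=device)  # stand-in for clip.tokenize(text)
        txt_feats = self.clip_model.encode_text(tokens).to(dtype=torch.float32)
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)  # unit-length embeddings
        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()


h = Holder()
h.set_classes(["person", "bus", "dog"])
first = h.clip_model
h.set_classes(["person"])
assert h.clip_model is first  # second call reused the cached model
print(h.txt_feats.shape)  # torch.Size([1, 1, 512])
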
@@ -841,7 +843,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                     args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
 
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in (
+        if m in {
             Classify,
             Conv,
             ConvTranspose,
@@ -867,7 +869,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             DWConvTranspose2d,
             C3x,
             RepC3,
-        ):
+        }:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -883,7 +885,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 n = 1
         elif m is AIFI:
             args = [ch[f], *args]
-        elif m in (HGStem, HGBlock):
+        elif m in {HGStem, HGBlock}:
             c1, cm, c2 = ch[f], args[0], args[1]
             args = [c1, cm, c2, *args[2:]]
             if m is HGBlock:
@@ -895,7 +897,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -978,7 +980,7 @@ def guess_model_task(model):
     def cfg2task(cfg):
         """Guess from YAML dictionary."""
         m = cfg["head"][-1][-2].lower()  # output module name
-        if m in ("classify", "classifier", "cls", "fc"):
+        if m in {"classify", "classifier", "cls", "fc"}:
             return "classify"
         if m == "detect":
             return "detect"