ultralytics 8.1.31 NCNN and CLIP updates (#9235)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
41c2d8d99f
commit
3c179f87cb
8 changed files with 77 additions and 60 deletions
|
|
@ -560,7 +560,8 @@ class WorldModel(DetectionModel):
|
|||
|
||||
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
|
||||
"""Initialize YOLOv8 world model with given config and parameters."""
|
||||
self.txt_feats = torch.randn(1, nc or 80, 512) # placeholder
|
||||
self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder
|
||||
self.clip_model = None # CLIP model placeholder
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def set_classes(self, text):
|
||||
|
|
@ -571,10 +572,11 @@ class WorldModel(DetectionModel):
|
|||
check_requirements("git+https://github.com/openai/CLIP.git")
|
||||
import clip
|
||||
|
||||
model, _ = clip.load("ViT-B/32")
|
||||
device = next(model.parameters()).device
|
||||
if not self.clip_model:
|
||||
self.clip_model = clip.load("ViT-B/32")[0]
|
||||
device = next(self.clip_model.parameters()).device
|
||||
text_token = clip.tokenize(text).to(device)
|
||||
txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
|
||||
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
|
||||
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
||||
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
|
||||
self.model[-1].nc = len(text)
|
||||
|
|
@ -841,7 +843,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in (
|
||||
if m in {
|
||||
Classify,
|
||||
Conv,
|
||||
ConvTranspose,
|
||||
|
|
@ -867,7 +869,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
DWConvTranspose2d,
|
||||
C3x,
|
||||
RepC3,
|
||||
):
|
||||
}:
|
||||
c1, c2 = ch[f], args[0]
|
||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
||||
|
|
@ -883,7 +885,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
n = 1
|
||||
elif m is AIFI:
|
||||
args = [ch[f], *args]
|
||||
elif m in (HGStem, HGBlock):
|
||||
elif m in {HGStem, HGBlock}:
|
||||
c1, cm, c2 = ch[f], args[0], args[1]
|
||||
args = [c1, cm, c2, *args[2:]]
|
||||
if m is HGBlock:
|
||||
|
|
@ -895,7 +897,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
|
||||
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
|
|
@ -978,7 +980,7 @@ def guess_model_task(model):
|
|||
def cfg2task(cfg):
|
||||
"""Guess from YAML dictionary."""
|
||||
m = cfg["head"][-1][-2].lower() # output module name
|
||||
if m in ("classify", "classifier", "cls", "fc"):
|
||||
if m in {"classify", "classifier", "cls", "fc"}:
|
||||
return "classify"
|
||||
if m == "detect":
|
||||
return "detect"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue