ultralytics 8.1.14 new YOLOv8-World models (#8054)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
f9e9cdf2c3
commit
850ca8587f
19 changed files with 683 additions and 32 deletions
|
|
@ -19,6 +19,8 @@ from ultralytics.nn.modules import (
|
|||
Bottleneck,
|
||||
BottleneckCSP,
|
||||
C2f,
|
||||
C2fAttn,
|
||||
ImagePoolingAttn,
|
||||
C3Ghost,
|
||||
C3x,
|
||||
Classify,
|
||||
|
|
@ -40,6 +42,7 @@ from ultralytics.nn.modules import (
|
|||
ResNetLayer,
|
||||
RTDETRDecoder,
|
||||
Segment,
|
||||
WorldDetect,
|
||||
)
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
|
|
@ -222,7 +225,7 @@ class BaseModel(nn.Module):
|
|||
"""
|
||||
self = super()._apply(fn)
|
||||
m = self.model[-1] # Detect()
|
||||
if isinstance(m, (Detect, Segment)):
|
||||
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
|
||||
m.stride = fn(m.stride)
|
||||
m.anchors = fn(m.anchors)
|
||||
m.strides = fn(m.strides)
|
||||
|
|
@ -281,7 +284,7 @@ class DetectionModel(BaseModel):
|
|||
|
||||
# Build strides
|
||||
m = self.model[-1] # Detect()
|
||||
if isinstance(m, (Detect, Segment, Pose, OBB)):
|
||||
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
||||
|
|
@ -546,6 +549,77 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
return x
|
||||
|
||||
|
||||
class WorldModel(DetectionModel):
    """YOLOv8 World Model: a detection model whose layers are conditioned on text (class-name) embeddings."""

    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
        """
        Initialize YOLOv8 world model with given config and parameters.

        Args:
            cfg (str | dict): Model configuration, forwarded unchanged to the parent constructor.
            ch (int): Number of input channels, forwarded to the parent constructor.
            nc (int, optional): Number of classes; also sizes the placeholder text features (80 when falsy).
            verbose (bool): Forwarded to the parent constructor.
        """
        # Assigned before super().__init__() on purpose: predict() reads self.txt_feats, and the parent
        # constructor appears to run a forward pass while building strides (NOTE(review): confirm).
        self.txt_feats = torch.randn(1, nc or 80, 512)  # placeholder
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def set_classes(self, text):
        """
        Encode class names with CLIP and store the normalized embeddings as this model's text features.

        Args:
            text (list[str]): Class names/prompts. One embedding is produced per entry and ``len(text)`` becomes
                the class count of the final layer.
        """
        try:
            import clip
        except ImportError:
            check_requirements("git+https://github.com/openai/CLIP.git")
            import clip

        model, _ = clip.load("ViT-B/32")
        device = next(model.parameters()).device
        text_token = clip.tokenize(text).to(device)
        # The embeddings are cached constants: encode without autograd so self.txt_feats does not keep the
        # CLIP computation graph alive or carry requires_grad=True into later forward passes.
        with torch.no_grad():
            txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
            txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)  # L2-normalize each embedding
        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
        self.model[-1].nc = len(text)

    def init_criterion(self):
        """Initialize the loss criterion for the model (not implemented for WorldModel)."""
        raise NotImplementedError

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            augment (bool, optional): Accepted for interface compatibility; not used by this method.
            embed (list, optional): Layer indices whose pooled feature maps should be returned as embeddings.

        Returns:
            (torch.Tensor | tuple): Model's output, or, when `embed` is given, the per-sample embeddings
                returned as soon as the highest requested layer index has been processed.
        """
        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
        if len(txt_feats) != len(x):
            # Tile the stored text features along dim 0 to line up with the input batch
            txt_feats = txt_feats.repeat(len(x), 1, 1)
        ori_txt_feats = txt_feats.clone()  # untouched copy: the head uses these, not the attention-updated ones
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:  # every layer, including the head
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if isinstance(m, C2fAttn):
                x = m(x, txt_feats)
            elif isinstance(m, WorldDetect):
                x = m(x, ori_txt_feats)
            elif isinstance(m, ImagePoolingAttn):
                txt_feats = m(x, txt_feats)  # updates the text features only; x passes through unchanged
            else:
                x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
"""Ensemble of models."""
|
||||
|
||||
|
|
@ -685,11 +759,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
|
||||
# Module updates
|
||||
for m in ensemble.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||
if hasattr(m, "inplace"):
|
||||
m.inplace = inplace
|
||||
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model
|
||||
if len(ensemble) == 1:
|
||||
|
|
@ -699,7 +770,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
LOGGER.info(f"Ensemble created with {weights}\n")
|
||||
for k in "names", "nc", "yaml":
|
||||
setattr(ensemble, k, getattr(ensemble[0], k))
|
||||
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
||||
ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
|
||||
assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
|
||||
return ensemble
|
||||
|
||||
|
|
@ -721,11 +792,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||
|
||||
# Module updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||
if hasattr(m, "inplace"):
|
||||
m.inplace = inplace
|
||||
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model and ckpt
|
||||
return model, ckpt
|
||||
|
|
@ -778,6 +846,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
C1,
|
||||
C2,
|
||||
C2f,
|
||||
C2fAttn,
|
||||
C3,
|
||||
C3TR,
|
||||
C3Ghost,
|
||||
|
|
@ -789,9 +858,14 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
c1, c2 = ch[f], args[0]
|
||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
||||
if m is C2fAttn:
|
||||
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
|
||||
args[2] = int(
|
||||
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
|
||||
) # num heads
|
||||
|
||||
args = [c1, c2, *args[1:]]
|
||||
if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
|
||||
if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
|
||||
args.insert(2, n) # number of repeats
|
||||
n = 1
|
||||
elif m is AIFI:
|
||||
|
|
@ -808,7 +882,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in (Detect, Segment, Pose, OBB):
|
||||
elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
|
|
@ -911,9 +985,7 @@ def guess_model_task(model):
|
|||
return cfg2task(eval(x))
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, Detect):
|
||||
return "detect"
|
||||
elif isinstance(m, Segment):
|
||||
if isinstance(m, Segment):
|
||||
return "segment"
|
||||
elif isinstance(m, Classify):
|
||||
return "classify"
|
||||
|
|
@ -921,6 +993,8 @@ def guess_model_task(model):
|
|||
return "pose"
|
||||
elif isinstance(m, OBB):
|
||||
return "obb"
|
||||
elif isinstance(m, (Detect, WorldDetect)):
|
||||
return "detect"
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue