ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
|
|||
self.clip_model = None # CLIP model placeholder
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def set_classes(self, text):
|
||||
"""Perform a forward pass with optional profiling, visualization, and embedding extraction."""
|
||||
def set_classes(self, text, batch=80, cache_clip_model=True):
|
||||
"""Set classes in advance so that model could do offline-inference without clip model."""
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
check_requirements("git+https://github.com/openai/CLIP.git")
|
||||
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||
import clip
|
||||
|
||||
if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute
|
||||
if (
|
||||
not getattr(self, "clip_model", None) and cache_clip_model
|
||||
): # for backwards compatibility of models lacking clip_model attribute
|
||||
self.clip_model = clip.load("ViT-B/32")[0]
|
||||
device = next(self.clip_model.parameters()).device
|
||||
model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
|
||||
device = next(model.parameters()).device
|
||||
text_token = clip.tokenize(text).to(device)
|
||||
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
|
||||
txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
|
||||
txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
|
||||
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
||||
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
|
||||
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
|
||||
self.model[-1].nc = len(text)
|
||||
|
||||
def init_criterion(self):
|
||||
"""Initialize the loss criterion for the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
||||
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the model.
|
||||
|
||||
|
|
@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
|
|||
x (torch.Tensor): The input tensor.
|
||||
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
|
||||
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
|
||||
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
|
||||
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
|
||||
embed (list, optional): A list of feature vectors/embeddings to return.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Model's output tensor.
|
||||
"""
|
||||
txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
|
||||
txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
|
||||
if len(txt_feats) != len(x):
|
||||
txt_feats = txt_feats.repeat(len(x), 1, 1)
|
||||
ori_txt_feats = txt_feats.clone()
|
||||
|
|
@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
|
|||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||
return x
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute loss.
|
||||
|
||||
Args:
|
||||
batch (dict): Batch to compute loss on.
|
||||
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
||||
"""
|
||||
if not hasattr(self, "criterion"):
|
||||
self.criterion = self.init_criterion()
|
||||
|
||||
if preds is None:
|
||||
preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
|
||||
return self.criterion(preds, batch)
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
"""Ensemble of models."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue