diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 4d845ae9..5a5052fd 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -637,8 +637,8 @@ class WorldModel(DetectionModel): (torch.Tensor): Model's output tensor. """ txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype) - if len(txt_feats) != len(x): - txt_feats = txt_feats.repeat(len(x), 1, 1) + if len(txt_feats) != len(x) or self.model[-1].export: + txt_feats = txt_feats.expand(x.shape[0], -1, -1) ori_txt_feats = txt_feats.clone() y, dt, embeddings = [], [], [] # outputs for m in self.model: # except the head part