ultralytics 8.0.229 add model.embed() method (#7098)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 38eaf5e29f
commit 5b3e20379f
11 changed files with 65 additions and 14 deletions
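For orientation, a minimal usage sketch of the embedding API this release adds. Only the internal `embed` plumbing appears in the hunks below, so the exact `model.embed()` signature and return type are assumptions inferred from the commit title rather than confirmed by this diff.

from ultralytics import YOLO  # assumed public entry point

model = YOLO('yolov8n.pt')
# Hypothetical call: request pooled feature embeddings instead of detections.
# 'bus.jpg' is a placeholder source path.
embeddings = model.embed('bus.jpg')  # assumed to return per-image embedding tensors
print(len(embeddings), embeddings[0].shape)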
@@ -333,7 +333,7 @@ class AutoBackend(nn.Module):
         self.__dict__.update(locals())  # assign all variables to self
 
-    def forward(self, im, augment=False, visualize=False):
+    def forward(self, im, augment=False, visualize=False, embed=None):
         """
         Runs inference on the YOLOv8 MultiBackend model.
 
@@ -341,6 +341,7 @@ class AutoBackend(nn.Module):
             im (torch.Tensor): The image tensor to perform inference on.
             augment (bool): whether to perform data augmentation during inference, defaults to False
             visualize (bool): whether to visualize the output predictions, defaults to False
+            embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
@@ -352,7 +353,7 @@ class AutoBackend(nn.Module):
             im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)
 
         if self.pt or self.nn_module:  # PyTorch
-            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+            y = self.model(im, augment=augment, visualize=visualize, embed=embed)
         elif self.jit:  # TorchScript
             y = self.model(im)
         elif self.dnn:  # ONNX OpenCV DNN
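The PyTorch branch above now forwards augment, visualize, and embed unconditionally instead of only when augment or visualize is set; exported backends (TorchScript, ONNX, ...) never receive embed. A minimal illustrative sketch of that dispatch contract, using hypothetical names (DummyTorchModel, TinyBackend) that are not Ultralytics code:

import torch
import torch.nn as nn

class DummyTorchModel(nn.Module):
    def forward(self, im, augment=False, visualize=False, embed=None):
        feats = torch.flatten(im, 1)  # stand-in for real backbone features
        return feats if embed else im

class TinyBackend(nn.Module):
    def __init__(self, model, is_pytorch=True):
        super().__init__()
        self.model, self.is_pytorch = model, is_pytorch

    def forward(self, im, augment=False, visualize=False, embed=None):
        if self.is_pytorch:  # in-memory PyTorch model: pass every flag straight through
            return self.model(im, augment=augment, visualize=visualize, embed=embed)
        return self.model(im)  # exported backends ignore augment/visualize/embed

backend = TinyBackend(DummyTorchModel())
out = backend(torch.zeros(1, 3, 32, 32), embed=[1])
print(out.shape)  # torch.Size([1, 3072])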
@@ -41,7 +41,7 @@ class BaseModel(nn.Module):
             return self.loss(x, *args, **kwargs)
         return self.predict(x, *args, **kwargs)
 
-    def predict(self, x, profile=False, visualize=False, augment=False):
+    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
         """
         Perform a forward pass through the network.
 
@@ -50,15 +50,16 @@ class BaseModel(nn.Module):
             profile (bool): Print the computation time of each layer if True, defaults to False.
             visualize (bool): Save the feature maps of the model if True, defaults to False.
             augment (bool): Augment image during prediction, defaults to False.
+            embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): The last output of the model.
         """
         if augment:
             return self._predict_augment(x)
-        return self._predict_once(x, profile, visualize)
+        return self._predict_once(x, profile, visualize, embed)
 
-    def _predict_once(self, x, profile=False, visualize=False):
+    def _predict_once(self, x, profile=False, visualize=False, embed=None):
         """
         Perform a forward pass through the network.
 
@@ -66,11 +67,12 @@ class BaseModel(nn.Module):
             x (torch.Tensor): The input tensor to the model.
             profile (bool): Print the computation time of each layer if True, defaults to False.
             visualize (bool): Save the feature maps of the model if True, defaults to False.
+            embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): The last output of the model.
         """
-        y, dt = [], []  # outputs
+        y, dt, embeddings = [], [], []  # outputs
         for m in self.model:
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -80,6 +82,10 @@ class BaseModel(nn.Module):
             y.append(x if m.i in self.save else None)  # save output
             if visualize:
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
+            if embed and m.i in embed:
+                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                if m.i == max(embed):
+                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
     def _predict_augment(self, x):
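The four added lines above are the heart of the feature: each requested intermediate feature map is globally average-pooled to one vector per image, the vectors are concatenated across the requested layers, and the pass returns early once the highest requested module index is reached. A self-contained sketch of just that tensor math, with made-up layer indices and shapes:

import torch
import torch.nn as nn

batch = 2
feature_maps = {  # pretend outputs of three requested layers (module index -> BCHW tensor)
    4: torch.randn(batch, 64, 40, 40),
    6: torch.randn(batch, 128, 20, 20),
    9: torch.randn(batch, 256, 10, 10),
}

embed = [4, 6, 9]
embeddings = []
for i, x in feature_maps.items():
    if i in embed:
        # Global average pool each BxCxHxW map down to BxC, as in the diff above.
        embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))
    if i == max(embed):
        result = torch.unbind(torch.cat(embeddings, 1), dim=0)  # tuple of per-image vectors
        break

print(len(result), result[0].shape)  # 2 torch.Size([448])  (64 + 128 + 256 channels)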
@@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
         return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
                                                    device=img.device)
 
-    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
+    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
@@ -464,11 +470,12 @@ class RTDETRDetectionModel(DetectionModel):
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
             batch (dict, optional): Ground truth data for evaluation. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
+            embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        y, dt = [], []  # outputs
+        y, dt, embeddings = [], [], []  # outputs
         for m in self.model[:-1]:  # except the head part
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -478,6 +485,10 @@ class RTDETRDetectionModel(DetectionModel):
             y.append(x if m.i in self.save else None)  # save output
             if visualize:
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
+            if embed and m.i in embed:
+                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                if m.i == max(embed):
+                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
         head = self.model[-1]
         x = head([y[j] for j in head.f], batch)  # head inference
         return x
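Design note: because both _predict_once and the RT-DETR predict return at max(embed), an embedding request short-circuits the remaining layers and never runs the detection head, so callers get feature vectors instead of detections. The sketch below shows one way to inspect which module indices can be passed in embed; m.type comes from the code above, while the model.model.model attribute path is an assumption about the Ultralytics wrapper, not shown in this diff:

from ultralytics import YOLO  # assumed public entry point

model = YOLO('yolov8n.pt')
for i, m in enumerate(model.model.model):  # parsed layer sequence (assumed layout)
    print(i, m.type)  # candidate module indices to pass as embed=[...]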