ultralytics 8.0.229 add model.embed() method (#7098)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2023-12-22 15:32:06 +01:00 committed by GitHub
parent 38eaf5e29f
commit 5b3e20379f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 65 additions and 14 deletions

View file

@ -94,7 +94,7 @@ class Model(nn.Module):
self._load(model, task)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
"""Calls the predict() method with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
@staticmethod
@ -201,6 +201,24 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self.model.fuse()
def embed(self, source=None, stream=False, **kwargs):
"""
Calls the predict() method and returns image embeddings.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[torch.Tensor]): A list of image embeddings.
"""
if not kwargs.get('embed'):
kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
return self.predict(source, stream, **kwargs)
def predict(self, source=None, stream=False, predictor=None, **kwargs):
"""
Perform prediction using the YOLO model.