ultralytics 8.0.229 add model.embed() method (#7098)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
38eaf5e29f
commit
5b3e20379f
11 changed files with 65 additions and 14 deletions
|
|
@ -94,7 +94,7 @@ class Model(nn.Module):
|
|||
self._load(model, task)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
"""Calls the predict() method with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -201,6 +201,24 @@ class Model(nn.Module):
|
|||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def embed(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Calls the predict() method and returns image embeddings.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of image embeddings.
|
||||
"""
|
||||
if not kwargs.get('embed'):
|
||||
kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue