ultralytics 8.0.229 add model.embed() method (#7098)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
38eaf5e29f
commit
5b3e20379f
11 changed files with 65 additions and 14 deletions
|
|
@ -94,7 +94,7 @@ class Model(nn.Module):
|
|||
self._load(model, task)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
"""Calls the predict() method with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -201,6 +201,24 @@ class Model(nn.Module):
|
|||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def embed(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Calls the predict() method and returns image embeddings.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of image embeddings.
|
||||
"""
|
||||
if not kwargs.get('embed'):
|
||||
kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class BasePredictor:
|
|||
"""Runs inference on a given image using the specified model and arguments."""
|
||||
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
|
|
@ -263,6 +263,9 @@ class BasePredictor:
|
|||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
if self.args.embed:
|
||||
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
|
||||
continue
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue