ultralytics 8.0.229 add model.embed() method (#7098)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
38eaf5e29f
commit
5b3e20379f
11 changed files with 65 additions and 14 deletions
|
|
@ -134,7 +134,7 @@ class BasePredictor:
|
|||
"""Runs inference on a given image using the specified model and arguments."""
|
||||
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
|
|
@ -263,6 +263,9 @@ class BasePredictor:
|
|||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
if self.args.embed:
|
||||
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
|
||||
continue
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue