ultralytics 8.0.229 add model.embed() method (#7098)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2023-12-22 15:32:06 +01:00 committed by GitHub
parent 38eaf5e29f
commit 5b3e20379f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 65 additions and 14 deletions

View file

@ -134,7 +134,7 @@ class BasePredictor:
"""Runs inference on a given image using the specified model and arguments."""
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
return self.model(im, augment=self.args.augment, visualize=visualize)
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
def pre_transform(self, im):
"""
@ -263,6 +263,9 @@ class BasePredictor:
# Inference
with profilers[1]:
preds = self.inference(im, *args, **kwargs)
if self.args.embed:
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
continue
# Postprocess
with profilers[2]: