ultralytics 8.0.143 add Model base class (#3934)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2023-07-27 06:54:02 +08:00 committed by GitHub
parent 3c787eb080
commit 1a0eb3f099
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 182 additions and 407 deletions

View file

@ -28,6 +28,8 @@ class Predictor(BasePredictor):
# Args for set_image
self.im = None
self.features = None
# Args for set_prompts
self.prompts = {}
# Args for segment everything
self.segment_all = False
@ -92,6 +94,10 @@ class Predictor(BasePredictor):
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
# Get prompts from self.prompts first
bboxes = self.prompts.pop('bboxes', bboxes)
points = self.prompts.pop('points', points)
masks = self.prompts.pop('masks', masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
@ -348,6 +354,10 @@ class Predictor(BasePredictor):
self.im = im
break
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
def reset_image(self):
self.im = None
self.features = None