ultralytics 8.0.143 add Model base class (#3934)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
3c787eb080
commit
1a0eb3f099
15 changed files with 182 additions and 407 deletions
|
|
@ -28,6 +28,8 @@ class Predictor(BasePredictor):
|
|||
# Args for set_image
|
||||
self.im = None
|
||||
self.features = None
|
||||
# Args for set_prompts
|
||||
self.prompts = {}
|
||||
# Args for segment everything
|
||||
self.segment_all = False
|
||||
|
||||
|
|
@ -92,6 +94,10 @@ class Predictor(BasePredictor):
|
|||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
"""
|
||||
# Get prompts from self.prompts first
|
||||
bboxes = self.prompts.pop('bboxes', bboxes)
|
||||
points = self.prompts.pop('points', points)
|
||||
masks = self.prompts.pop('masks', masks)
|
||||
if all(i is None for i in [bboxes, points, masks]):
|
||||
return self.generate(im, *args, **kwargs)
|
||||
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
||||
|
|
@ -348,6 +354,10 @@ class Predictor(BasePredictor):
|
|||
self.im = im
|
||||
break
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
"""Set prompts in advance."""
|
||||
self.prompts = prompts
|
||||
|
||||
def reset_image(self):
|
||||
self.im = None
|
||||
self.features = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue