ultralytics 8.0.134 add MobileSAM support (#3474)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Chaoning Zhang 2023-07-13 20:25:56 +08:00 committed by GitHub
parent c55a98ab8e
commit 201e69e4e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 1472 additions and 841 deletions

View file

@ -131,6 +131,11 @@ class BasePredictor:
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
def inference(self, im, *args, **kwargs):
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
return self.model(im, augment=self.args.augment, visualize=visualize)
def pre_transform(self, im):
"""Pre-transform input image before inference.
@ -181,13 +186,13 @@ class BasePredictor:
"""Post-processes predictions for an image and returns them."""
return preds
def __call__(self, source=None, model=None, stream=False):
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
"""Performs inference on an image or stream."""
self.stream = stream
if stream:
return self.stream_inference(source, model)
return self.stream_inference(source, model, *args, **kwargs)
else:
return list(self.stream_inference(source, model)) # merge list of Result into one
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None):
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
@ -209,7 +214,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()
def stream_inference(self, source=None, model=None):
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info('')
@ -236,8 +241,6 @@ class BasePredictor:
self.run_callbacks('on_predict_batch_start')
self.batch = batch
path, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path[0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
# Preprocess
with profilers[0]:
@ -245,7 +248,7 @@ class BasePredictor:
# Inference
with profilers[1]:
preds = self.model(im, augment=self.args.augment, visualize=visualize)
preds = self.inference(im, *args, **kwargs)
# Postprocess
with profilers[2]: