ultralytics 8.0.81 single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5bce1c3021
commit
a38f227672
64 changed files with 620 additions and 58 deletions
|
|
@ -107,9 +107,11 @@ class BasePredictor:
|
|||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Prepares input image before inference."""
|
||||
pass
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
"""Write inference results to a file or directory."""
|
||||
p, im, _ = batch
|
||||
log_string = ''
|
||||
if len(im.shape) == 3:
|
||||
|
|
@ -143,9 +145,11 @@ class BasePredictor:
|
|||
return log_string
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
"""Post-processes predictions for an image and returns them."""
|
||||
return preds
|
||||
|
||||
def __call__(self, source=None, model=None, stream=False):
|
||||
"""Performs inference on an image or stream."""
|
||||
self.stream = stream
|
||||
if stream:
|
||||
return self.stream_inference(source, model)
|
||||
|
|
@ -159,6 +163,7 @@ class BasePredictor:
|
|||
pass
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
if self.args.task == 'classify':
|
||||
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
|
||||
|
|
@ -179,6 +184,7 @@ class BasePredictor:
|
|||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None):
|
||||
"""Streams real-time inference on camera feed and saves results to file."""
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
|
||||
|
|
@ -264,6 +270,7 @@ class BasePredictor:
|
|||
self.run_callbacks('on_predict_end')
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
model = model or self.args.model
|
||||
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
|
||||
|
|
@ -278,6 +285,7 @@ class BasePredictor:
|
|||
self.model.eval()
|
||||
|
||||
def show(self, p):
|
||||
"""Display an image in a window using OpenCV imshow()."""
|
||||
im0 = self.plotted_img
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
|
|
@ -287,6 +295,7 @@ class BasePredictor:
|
|||
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
"""Save video predictions as mp4 at specified path."""
|
||||
im0 = self.plotted_img
|
||||
# Save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
|
|
@ -307,6 +316,7 @@ class BasePredictor:
|
|||
self.vid_writer[idx].write(im0)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Runs all registered callbacks for a specific event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue