ultralytics 8.0.81 single-line docstring updates (#2061)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-17 00:45:36 +02:00 committed by GitHub
parent 5bce1c3021
commit a38f227672
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 620 additions and 58 deletions

View file

@ -107,9 +107,11 @@ class BasePredictor:
callbacks.add_integration_callbacks(self)
def preprocess(self, img):
"""Prepares input image before inference."""
pass
def write_results(self, idx, results, batch):
"""Write inference results to a file or directory."""
p, im, _ = batch
log_string = ''
if len(im.shape) == 3:
@ -143,9 +145,11 @@ class BasePredictor:
return log_string
def postprocess(self, preds, img, orig_img):
"""Post-processes predictions for an image and returns them."""
return preds
def __call__(self, source=None, model=None, stream=False):
"""Performs inference on an image or stream."""
self.stream = stream
if stream:
return self.stream_inference(source, model)
@ -159,6 +163,7 @@ class BasePredictor:
pass
def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
if self.args.task == 'classify':
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
@ -179,6 +184,7 @@ class BasePredictor:
@smart_inference_mode()
def stream_inference(self, source=None, model=None):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info('')
@ -264,6 +270,7 @@ class BasePredictor:
self.run_callbacks('on_predict_end')
def setup_model(self, model, verbose=True):
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
device = select_device(self.args.device, verbose=verbose)
model = model or self.args.model
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
@ -278,6 +285,7 @@ class BasePredictor:
self.model.eval()
def show(self, p):
"""Display an image in a window using OpenCV imshow()."""
im0 = self.plotted_img
if platform.system() == 'Linux' and p not in self.windows:
self.windows.append(p)
@ -287,6 +295,7 @@ class BasePredictor:
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
def save_preds(self, vid_cap, idx, save_path):
"""Save video predictions as mp4 at specified path."""
im0 = self.plotted_img
# Save imgs
if self.dataset.mode == 'image':
@ -307,6 +316,7 @@ class BasePredictor:
self.vid_writer[idx].write(im0)
def run_callbacks(self, event: str):
"""Runs all registered callbacks for a specific event."""
for callback in self.callbacks.get(event, []):
callback(self)