Add pred, export and val callbacks (#126)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia 2023-01-01 22:46:10 +05:30 committed by GitHub
parent 63c7a74691
commit c6eb6720de
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 176 additions and 57 deletions

View file

@ -26,6 +26,7 @@ Usage - formats:
yolov8n_paddle_model # PaddlePaddle
"""
import platform
from collections import defaultdict
from pathlib import Path
import cv2
@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -89,6 +91,11 @@ class BasePredictor:
self.annotator = None
self.data_path = None
# callbacks
self.callbacks = defaultdict([])
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
def preprocess(self, img):
pass
@ -143,9 +150,11 @@ class BasePredictor:
@smart_inference_mode()
def __call__(self, source=None, model=None):
self.run_callbacks("on_predict_start")
model = self.model if self.done_setup else self.setup(source, model)
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
path, im, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
with self.dt[0]:
@ -176,6 +185,8 @@ class BasePredictor:
# Print time (inference-only)
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
self.run_callbacks("on_predict_batch_end")
# Print results
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
LOGGER.info(
@ -185,6 +196,8 @@ class BasePredictor:
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
def show(self, p):
im0 = self.annotator.result()
if platform.system() == 'Linux' and p not in self.windows:
@ -213,3 +226,19 @@ class BasePredictor:
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
self.vid_writer[idx].write(im0)
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)