Add pred, export and val callbacks (#126)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
63c7a74691
commit
c6eb6720de
8 changed files with 176 additions and 57 deletions
|
|
@ -26,6 +26,7 @@ Usage - formats:
|
|||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
|
@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
|
|||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
|
||||
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
|
@ -89,6 +91,11 @@ class BasePredictor:
|
|||
self.annotator = None
|
||||
self.data_path = None
|
||||
|
||||
# callbacks
|
||||
self.callbacks = defaultdict([])
|
||||
for callback, func in default_callbacks.items():
|
||||
self.add_callback(callback, func)
|
||||
|
||||
def preprocess(self, img):
|
||||
pass
|
||||
|
||||
|
|
@ -143,9 +150,11 @@ class BasePredictor:
|
|||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, source=None, model=None):
|
||||
self.run_callbacks("on_predict_start")
|
||||
model = self.model if self.done_setup else self.setup(source, model)
|
||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
path, im, im0s, vid_cap, s = batch
|
||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||
with self.dt[0]:
|
||||
|
|
@ -176,6 +185,8 @@ class BasePredictor:
|
|||
# Print time (inference-only)
|
||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||
|
||||
self.run_callbacks("on_predict_batch_end")
|
||||
|
||||
# Print results
|
||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||
LOGGER.info(
|
||||
|
|
@ -185,6 +196,8 @@ class BasePredictor:
|
|||
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
||||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def show(self, p):
|
||||
im0 = self.annotator.result()
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
|
|
@ -213,3 +226,19 @@ class BasePredictor:
|
|||
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
|
||||
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
self.vid_writer[idx].write(im0)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
appends the given callback
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
overrides the existing callbacks with the given callback
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue