ultralytics 8.0.14 Hydra removal fixes and cleanup (#542)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kamlesh Kumar <patelkamleshpatel364@gmail.com>
This commit is contained in:
Glenn Jocher 2023-01-21 21:22:40 +01:00 committed by GitHub
parent cc3be0e223
commit d9a0fba251
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 339 additions and 301 deletions

View file

@ -33,7 +33,7 @@ from pathlib import Path
import cv2
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
@ -70,7 +70,7 @@ class BasePredictor:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_config(config, overrides)
self.args = get_cfg(config, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
@ -84,6 +84,7 @@ class BasePredictor:
self.bs = None
self.imgsz = None
self.device = None
self.classes = self.args.classes
self.dataset = None
self.vid_path, self.vid_writer = None, None
self.annotator = None
@ -100,7 +101,7 @@ class BasePredictor:
def write_results(self, results, batch, print_string):
raise NotImplementedError("print_results function needs to be implemented")
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
return preds
def setup_source(self, source=None):
@ -195,7 +196,7 @@ class BasePredictor:
# postprocess
with self.dt[2]:
results = self.postprocess(preds, im, im0s)
results = self.postprocess(preds, im, im0s, self.classes)
for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
p = Path(p)