Predictor support (#65)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Ayush Chaurasia authored 2022-12-07 10:33:10 +05:30; committed by GitHub
parent 479992093c
commit e6737f1207
22 changed files with 916 additions and 48 deletions

@@ -15,7 +15,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import OmegaConf
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
@@ -26,7 +26,9 @@ import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
+from ultralytics.yolo.utils.configs import get_config
 from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -36,7 +38,7 @@ RANK = int(os.getenv('RANK', -1))
 class BaseTrainer:
 
     def __init__(self, config=DEFAULT_CONFIG, overrides={}):
-        self.args = self._get_config(config, overrides)
+        self.args = get_config(config, overrides)
         self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
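
After this change the trainer resolves its arguments through the shared get_config helper instead of a private method. A minimal usage sketch of that call path, assuming get_config behaves like the removed _get_config shown in the next hunk (the override keys are illustrative, not values taken from this commit):

from ultralytics.yolo.utils import ROOT                   # shown in the import hunk above
from ultralytics.yolo.utils.configs import get_config     # new shared helper

DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"  # as defined in the diff

# Merge user overrides on top of the default YAML; the keys here are assumptions for the sketch.
args = get_config(DEFAULT_CONFIG, {"epochs": 3, "imgsz": 320})
print(args.epochs)  # -> 3, since overrides take precedence in the OmegaConf merge
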
@@ -84,25 +86,6 @@ class BaseTrainer:
             self.add_callback(callback, func)
         callbacks.add_integration_callbacks(self)
 
-    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
-        """
-        Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns training args namespace
-        :param config: Optional file name or DictConfig object
-        """
-        if isinstance(config, (str, Path)):
-            config = OmegaConf.load(config)
-        elif isinstance(config, Dict):
-            config = OmegaConf.create(config)
-
-        # override
-        if isinstance(overrides, str):
-            overrides = OmegaConf.load(overrides)
-        elif isinstance(overrides, Dict):
-            overrides = OmegaConf.create(overrides)
-
-        return OmegaConf.merge(config, overrides)
-
     def add_callback(self, onevent: str, callback):
         """
         appends the given callback
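
The removed _get_config body above shows what the new ultralytics.yolo.utils.configs.get_config presumably does: load or wrap both the base config and the overrides as OmegaConf containers, then merge them with the overrides winning. A standalone sketch reconstructed from the deleted method, offered as an assumption about the helper's behavior rather than the committed contents of the configs module:

from pathlib import Path
from typing import Dict, Union

from omegaconf import DictConfig, OmegaConf


def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
    # Reconstructed from the removed BaseTrainer._get_config above; not the actual committed source.
    # Accepts a YAML file name or DictConfig plus optional overrides (YAML path or plain dict).
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)      # YAML path -> DictConfig
    elif isinstance(config, Dict):
        config = OmegaConf.create(config)    # plain dict -> DictConfig

    if isinstance(overrides, str):
        overrides = OmegaConf.load(overrides)
    elif isinstance(overrides, Dict):
        overrides = OmegaConf.create(overrides)

    # Later arguments win in OmegaConf.merge, so overrides take precedence over the base config.
    return OmegaConf.merge(config, overrides)

Centralizing this merge in utils.configs presumably lets the new predictor reuse the same configuration handling as the trainer, in line with the commit's "Predictor support" title.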