Predictor support (#65)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
479992093c
commit
e6737f1207
22 changed files with 916 additions and 48 deletions
|
|
@ -15,7 +15,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from omegaconf import OmegaConf
|
||||
from torch.cuda import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import lr_scheduler
|
||||
|
|
@ -26,7 +26,9 @@ import ultralytics.yolo.utils.callbacks as callbacks
|
|||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||
from ultralytics.yolo.utils.configs import get_config
|
||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
||||
|
||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||
|
|
@ -36,7 +38,7 @@ RANK = int(os.getenv('RANK', -1))
|
|||
class BaseTrainer:
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
self.args = self._get_config(config, overrides)
|
||||
self.args = get_config(config, overrides)
|
||||
self.check_resume()
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
||||
|
|
@ -84,25 +86,6 @@ class BaseTrainer:
|
|||
self.add_callback(callback, func)
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
|
||||
"""
|
||||
Accepts yaml file name or DictConfig containing experiment configuration.
|
||||
Returns training args namespace
|
||||
:param config: Optional file name or DictConfig object
|
||||
"""
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
elif isinstance(config, Dict):
|
||||
config = OmegaConf.create(config)
|
||||
|
||||
# override
|
||||
if isinstance(overrides, str):
|
||||
overrides = OmegaConf.load(overrides)
|
||||
elif isinstance(overrides, Dict):
|
||||
overrides = OmegaConf.create(overrides)
|
||||
|
||||
return OmegaConf.merge(config, overrides)
|
||||
|
||||
def add_callback(self, onevent: str, callback):
|
||||
"""
|
||||
appends the given callback
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue