General trainer cleanup (#147)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Ayush Chaurasia 2023-01-07 19:25:48 +05:30 committed by GitHub
parent f8a13c49a0
commit 0e5a7ae623
8 changed files with 196 additions and 60 deletions

@@ -23,6 +23,7 @@ from tqdm import tqdm
 
 import ultralytics.yolo.utils as utils
 from ultralytics import __version__
+from ultralytics.nn.tasks import attempt_load_one_weight
 from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
@@ -380,21 +381,18 @@ class BaseTrainer:
         """
         load/create/download model for any task
         """
-        if isinstance(self.model, torch.nn.Module):  # if loaded model is passed
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return
 
-        # We should improve the code flow here. This function looks hacky
-        model = self.model
-        pretrained = not str(model).endswith(".yaml")
-        # config
-        if not pretrained:
-            model = check_file(model)
-        ckpt = self.load_ckpt(model) if pretrained else None
-        weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt  # torchvision weights are not dicts
-        self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights)
-        return ckpt
-
-    def load_ckpt(self, ckpt):
-        return torch.load(ckpt, map_location='cpu')
+        model, weights = self.model, None
+        ckpt = None
+        if str(model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(model)
+            cfg = ckpt["model"].yaml
+        else:
+            cfg = model
+        self.model = self.get_model(cfg=cfg, weights=weights)  # calls Model(cfg, weights)
+        return ckpt
 
     def optimizer_step(self):
         self.scaler.unscale_(self.optimizer)  # unscale gradients
@@ -433,7 +431,7 @@ class BaseTrainer:
         if rank in {-1, 0}:
             self.console.info(text)
 
-    def load_model(self, model_cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg=None, weights=None, verbose=True):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
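
After this cleanup, setup_model() resolves weights and cfg itself (attempt_load_one_weight() for *.pt checkpoints, the YAML path otherwise) and delegates construction to the renamed get_model() hook. Below is a minimal sketch of a task trainer implementing that hook; the ExampleDetectionTrainer name, the DetectionModel import, and the self.data["nc"] field are illustrative assumptions, not part of this diff.

# Illustrative sketch only; not part of this commit. Assumes DetectionModel is
# available from ultralytics.nn.tasks and the trainer's dataset dict exposes
# the class count as self.data["nc"].
from ultralytics.nn.tasks import DetectionModel
from ultralytics.yolo.engine.trainer import BaseTrainer


class ExampleDetectionTrainer(BaseTrainer):

    def get_model(self, cfg=None, weights=None, verbose=True):
        # cfg is a model YAML (training from scratch) or the YAML recovered from
        # a checkpoint; weights are those returned by attempt_load_one_weight().
        model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
        if weights:
            model.load(weights)  # transfer matching pretrained weights
        return model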