General trainer cleanup (#147)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
f8a13c49a0
commit
0e5a7ae623
8 changed files with 196 additions and 60 deletions
|
|
@ -23,6 +23,7 @@ from tqdm import tqdm
|
|||
|
||||
import ultralytics.yolo.utils as utils
|
||||
from ultralytics import __version__
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||
|
|
@ -380,21 +381,18 @@ class BaseTrainer:
|
|||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if loaded model is passed
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
# We should improve the code flow here. This function looks hacky
|
||||
model = self.model
|
||||
pretrained = not str(model).endswith(".yaml")
|
||||
# config
|
||||
if not pretrained:
|
||||
model = check_file(model)
|
||||
ckpt = self.load_ckpt(model) if pretrained else None
|
||||
weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt # torchvision weights are not dicts
|
||||
self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights)
|
||||
return ckpt
|
||||
|
||||
def load_ckpt(self, ckpt):
|
||||
return torch.load(ckpt, map_location='cpu')
|
||||
model, weights = self.model, None
|
||||
ckpt = None
|
||||
if str(model).endswith(".pt"):
|
||||
weights, ckpt = attempt_load_one_weight(model)
|
||||
cfg = ckpt["model"].yaml
|
||||
else:
|
||||
cfg = model
|
||||
self.model = self.get_model(cfg=cfg, weights=weights) # calls Model(cfg, weights)
|
||||
return ckpt
|
||||
|
||||
def optimizer_step(self):
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
|
|
@ -433,7 +431,7 @@ class BaseTrainer:
|
|||
if rank in {-1, 0}:
|
||||
self.console.info(text)
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue