General trainer cleanup (#147)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
f8a13c49a0
commit
0e5a7ae623
8 changed files with 196 additions and 60 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||
|
|
@ -45,8 +45,8 @@ class YOLO:
|
|||
self.trainer = None # trainer object
|
||||
self.task = None # task type
|
||||
self.ckpt = None # if loaded from *.pt
|
||||
self.ckpt_path = None
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
|
||||
# Load or create new YOLO model
|
||||
|
|
@ -78,7 +78,7 @@ class YOLO:
|
|||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
"""
|
||||
self.model = attempt_load_weights(weights)
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.ckpt_path = weights
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args
|
||||
|
|
@ -188,14 +188,14 @@ class YOLO:
|
|||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||
|
||||
if overrides.get("resume"):
|
||||
overrides["resume"] = self.ckpt_path
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
if not overrides.get("resume"):
|
||||
self.trainer.model = self.trainer.load_model(weights=self.model,
|
||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model # override here to save memory
|
||||
if not overrides.get("resume"): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
|
||||
cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue