Unified model loading with backwards compatibility (#132)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8996c5c6cf
commit
c3d961fb03
10 changed files with 65 additions and 50 deletions
|
|
@ -77,13 +77,12 @@ class YOLO:
|
|||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
"""
|
||||
self.ckpt = torch.load(weights, map_location="cpu")
|
||||
self.task = self.ckpt["train_args"]["task"]
|
||||
self.overrides = dict(self.ckpt["train_args"])
|
||||
self.model = attempt_load_weights(weights)
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args
|
||||
self.overrides["device"] = '' # reset device
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._guess_ops_from_task(self.task)
|
||||
self.model = attempt_load_weights(weights, fuse=False)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
|
@ -189,7 +188,7 @@ class YOLO:
|
|||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
||||
self.trainer.model = self.trainer.load_model(weights=self.model,
|
||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model # override here to save memory
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue