ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -126,22 +126,7 @@ class BaseTrainer:
|
|||
|
||||
# Model and Dataset
|
||||
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
try:
|
||||
if self.args.task == "classify":
|
||||
self.data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
if "yaml_file" in self.data:
|
||||
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
self.trainset, self.testset = self.get_dataset()
|
||||
self.ema = None
|
||||
|
||||
# Optimization utils init
|
||||
|
|
@ -509,13 +494,27 @@ class BaseTrainer:
|
|||
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(data):
|
||||
def get_dataset(self):
|
||||
"""
|
||||
Get train, val path from data dict if it exists.
|
||||
|
||||
Returns None if data format is not recognized.
|
||||
"""
|
||||
try:
|
||||
if self.args.task == "classify":
|
||||
data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
):
|
||||
data = check_det_dataset(self.args.data)
|
||||
if "yaml_file" in data:
|
||||
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||
self.data = data
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
|
|
@ -666,8 +665,8 @@ class BaseTrainer:
|
|||
if ckpt is None:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
if ckpt["optimizer"] is not None:
|
||||
start_epoch = ckpt.get("epoch", -1) + 1
|
||||
if ckpt.get("optimizer", None) is not None:
|
||||
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
||||
best_fitness = ckpt["best_fitness"]
|
||||
if self.ema and ckpt.get("ema"):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue