ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -3,6 +3,7 @@
|
|||
import math
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
|
@ -66,7 +67,7 @@ class Compose:
|
|||
|
||||
def __init__(self, transforms):
|
||||
"""Initializes the Compose object with a list of transforms."""
|
||||
self.transforms = transforms
|
||||
self.transforms = transforms if isinstance(transforms, list) else [transforms]
|
||||
|
||||
def __call__(self, data):
|
||||
"""Applies a series of transformations to input data."""
|
||||
|
|
@ -78,6 +79,29 @@ class Compose:
|
|||
"""Appends a new transform to the existing list of transforms."""
|
||||
self.transforms.append(transform)
|
||||
|
||||
def insert(self, index, transform):
    """
    Insert a transform into the existing list of transforms at the given position.

    Args:
        index: Position passed straight to ``list.insert``.
        transform: The transform to place at that position.
    """
    pipeline = self.transforms
    pipeline.insert(index, transform)
|
||||
|
||||
def __getitem__(self, index: Union[list, int]) -> "Compose":
    """
    Select one transform or several transforms by position and wrap them in a new Compose.

    Args:
        index: A single int or a list of positions into the transform list.

    Returns:
        A new Compose holding the selected transforms, in the order given by ``index``. A single int
        also yields a Compose (with one transform), not the bare transform.
    """
    assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
    positions = index if not isinstance(index, int) else [index]
    selected = []
    for position in positions:
        selected.append(self.transforms[position])
    return Compose(selected)
|
||||
|
||||
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
|
||||
"""Retrieve a specific transform or a set of transforms using indexing."""
|
||||
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
|
||||
if isinstance(index, list):
|
||||
assert isinstance(
|
||||
value, list
|
||||
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
|
||||
if isinstance(index, int):
|
||||
index, value = [index], [value]
|
||||
for i, v in zip(index, value):
|
||||
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
|
||||
self.transforms[i] = v
|
||||
|
||||
def tolist(self):
    """
    Return the transforms as a standard Python list.

    Returns:
        The list stored on this object itself (not a copy), so changes made to it are visible here.
    """
    transforms = self.transforms
    return transforms
|
||||
|
|
@ -118,6 +142,8 @@ class BaseMixTransform:
|
|||
mix_labels[i] = self.pre_transform(data)
|
||||
labels["mix_labels"] = mix_labels
|
||||
|
||||
# Update cls and texts
|
||||
labels = self._update_label_text(labels)
|
||||
# Mosaic or MixUp
|
||||
labels = self._mix_transform(labels)
|
||||
labels.pop("mix_labels", None)
|
||||
|
|
@ -131,6 +157,22 @@ class BaseMixTransform:
|
|||
"""Gets a list of shuffled indexes for mosaic augmentation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _update_label_text(self, labels):
|
||||
"""Update label text."""
|
||||
if "texts" not in labels:
|
||||
return labels
|
||||
|
||||
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
|
||||
mix_texts = list({tuple(x) for x in mix_texts})
|
||||
text2id = {text: i for i, text in enumerate(mix_texts)}
|
||||
|
||||
for label in [labels] + labels["mix_labels"]:
|
||||
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
|
||||
text = label["texts"][int(l)]
|
||||
label["cls"][i] = text2id[tuple(text)]
|
||||
label["texts"] = mix_texts
|
||||
return labels
|
||||
|
||||
|
||||
class Mosaic(BaseMixTransform):
|
||||
"""
|
||||
|
|
@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
|
|||
final_labels["instances"].clip(imgsz, imgsz)
|
||||
good = final_labels["instances"].remove_zero_area_boxes()
|
||||
final_labels["cls"] = final_labels["cls"][good]
|
||||
if "texts" in mosaic_labels[0]:
|
||||
final_labels["texts"] = mosaic_labels[0]["texts"]
|
||||
return final_labels
|
||||
|
||||
|
||||
|
|
@ -970,6 +1014,83 @@ class Format:
|
|||
return masks, instances, cls
|
||||
|
||||
|
||||
class RandomLoadText:
    """
    Randomly sample positive texts and negative texts and update the class indices accordingly to the number of samples.

    Attributes:
        prompt_format (str): Format for prompt. Default is '{}'.
        neg_samples (tuple[int]): A (low, high) range, both ends inclusive, from which the number of negative
            texts is randomly drawn. Default is (80, 80).
        max_samples (int): The max number of different text samples in one image. Default is 80.
        padding (bool): Whether to pad texts to max_samples. Default is False.
        padding_value (str): The padding text. Default is "".
    """

    def __init__(
        self,
        prompt_format: str = "{}",
        neg_samples: Tuple[int, int] = (80, 80),
        max_samples: int = 80,
        padding: bool = False,
        padding_value: str = "",
    ) -> None:
        """Initializes the RandomLoadText class with given parameters."""
        self.prompt_format = prompt_format
        self.neg_samples = neg_samples
        self.max_samples = max_samples
        self.padding = padding
        self.padding_value = padding_value

    def __call__(self, labels: dict) -> dict:
        """
        Return updated classes and texts.

        Args:
            labels (dict): Must contain "texts" (indexable by class id, each entry a non-empty sequence of
                prompts), "cls" (class ids, squeezed on the last axis) and "instances" (supports ``len`` and
                boolean-mask indexing).

        Returns:
            (dict): The same dict with "texts" replaced by the sampled prompt strings, "cls" re-indexed into
                that list, and "instances" filtered to the instances whose class was kept.

        Raises:
            AssertionError: If "texts" is missing from ``labels`` or a sampled class has no prompts.
        """
        assert "texts" in labels, "No texts found in labels."
        class_texts = labels["texts"]
        num_classes = len(class_texts)
        cls = np.asarray(labels.pop("cls"), dtype=int)
        pos_labels = np.unique(cls).tolist()

        if len(pos_labels) > self.max_samples:
            # Must stay a list: it is concatenated with neg_labels below, and set + list raises TypeError.
            pos_labels = random.sample(pos_labels, k=self.max_samples)

        neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
        pos_lookup = set(pos_labels)  # O(1) membership while collecting negative candidates
        neg_candidates = [i for i in range(num_classes) if i not in pos_lookup]
        neg_labels = random.sample(neg_candidates, k=neg_samples)

        sampled_labels = pos_labels + neg_labels
        random.shuffle(sampled_labels)

        # Map each kept original class id to its position in the shuffled text list
        label2ids = {label: i for i, label in enumerate(sampled_labels)}
        valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
        new_cls = []
        for i, label in enumerate(cls.squeeze(-1).tolist()):
            if label not in label2ids:
                # Class was dropped when positives were truncated to max_samples
                continue
            valid_idx[i] = True
            new_cls.append([label2ids[label]])
        labels["instances"] = labels["instances"][valid_idx]
        labels["cls"] = np.array(new_cls)

        # Randomly select one prompt when there is more than one prompt for a class
        texts = []
        for label in sampled_labels:
            prompts = class_texts[label]
            assert len(prompts) > 0
            prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
            texts.append(prompt)

        if self.padding:
            num_padding = self.max_samples - len(sampled_labels)
            if num_padding > 0:
                texts += [self.padding_value] * num_padding

        labels["texts"] = texts
        return labels
|
||||
|
||||
|
||||
def v8_transforms(dataset, imgsz, hyp, stretch=False):
|
||||
"""Convert images to a size suitable for YOLOv8 training."""
|
||||
pre_transform = Compose(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue