ultralytics 8.1.39 add YOLO-World training (#9268)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-03-31 22:30:17 +08:00 committed by GitHub
parent 18036908d4
commit e9187c1296
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 2166 additions and 100 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.38"
__version__ = "8.1.39"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

File diff suppressed because it is too large Load diff

View file

@ -1,15 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .build import (
build_dataloader,
build_yolo_dataset,
build_grounding,
load_inference_source,
)
from .dataset import (
ClassificationDataset,
SemanticDataset,
YOLODataset,
YOLOMultiModalDataset,
GroundingDataset,
YOLOConcatDataset,
)
__all__ = (
"BaseDataset",
"ClassificationDataset",
"SemanticDataset",
"YOLODataset",
"YOLOMultiModalDataset",
"YOLOConcatDataset",
"GroundingDataset",
"build_yolo_dataset",
"build_grounding",
"build_dataloader",
"load_inference_source",
)

View file

@ -3,6 +3,7 @@
import math
import random
from copy import deepcopy
from typing import Tuple, Union
import cv2
import numpy as np
@ -66,7 +67,7 @@ class Compose:
def __init__(self, transforms):
"""Initializes the Compose object with a list of transforms."""
self.transforms = transforms
self.transforms = transforms if isinstance(transforms, list) else [transforms]
def __call__(self, data):
"""Applies a series of transformations to input data."""
@ -78,6 +79,29 @@ class Compose:
"""Appends a new transform to the existing list of transforms."""
self.transforms.append(transform)
def insert(self, index, transform):
"""Inserts a new transform to the existing list of transforms."""
self.transforms.insert(index, transform)
def __getitem__(self, index: Union[list, int]) -> "Compose":
"""Retrieve a specific transform or a set of transforms using indexing."""
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
index = [index] if isinstance(index, int) else index
return Compose([self.transforms[i] for i in index])
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
"""Retrieve a specific transform or a set of transforms using indexing."""
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
if isinstance(index, list):
assert isinstance(
value, list
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
if isinstance(index, int):
index, value = [index], [value]
for i, v in zip(index, value):
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
self.transforms[i] = v
def tolist(self):
"""Converts the list of transforms to a standard Python list."""
return self.transforms
@ -118,6 +142,8 @@ class BaseMixTransform:
mix_labels[i] = self.pre_transform(data)
labels["mix_labels"] = mix_labels
# Update cls and texts
labels = self._update_label_text(labels)
# Mosaic or MixUp
labels = self._mix_transform(labels)
labels.pop("mix_labels", None)
@ -131,6 +157,22 @@ class BaseMixTransform:
"""Gets a list of shuffled indexes for mosaic augmentation."""
raise NotImplementedError
def _update_label_text(self, labels):
"""Update label text."""
if "texts" not in labels:
return labels
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
mix_texts = list({tuple(x) for x in mix_texts})
text2id = {text: i for i, text in enumerate(mix_texts)}
for label in [labels] + labels["mix_labels"]:
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(l)]
label["cls"][i] = text2id[tuple(text)]
label["texts"] = mix_texts
return labels
class Mosaic(BaseMixTransform):
"""
@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
final_labels["instances"].clip(imgsz, imgsz)
good = final_labels["instances"].remove_zero_area_boxes()
final_labels["cls"] = final_labels["cls"][good]
if "texts" in mosaic_labels[0]:
final_labels["texts"] = mosaic_labels[0]["texts"]
return final_labels
@ -970,6 +1014,83 @@ class Format:
return masks, instances, cls
class RandomLoadText:
"""
Randomly sample positive texts and negative texts and update the class indices accordingly to the number of samples.
Attributes:
prompt_format (str): Format for prompt. Default is '{}'.
neg_samples (tuple[int]): A ranger to randomly sample negative texts, Default is (80, 80).
max_samples (int): The max number of different text samples in one image, Default is 80.
padding (bool): Whether to pad texts to max_samples. Default is False.
padding_value (str): The padding text. Default is "".
"""
def __init__(
self,
prompt_format: str = "{}",
neg_samples: Tuple[int, int] = (80, 80),
max_samples: int = 80,
padding: bool = False,
padding_value: str = "",
) -> None:
"""Initializes the RandomLoadText class with given parameters."""
self.prompt_format = prompt_format
self.neg_samples = neg_samples
self.max_samples = max_samples
self.padding = padding
self.padding_value = padding_value
def __call__(self, labels: dict) -> dict:
"""Return updated classes and texts."""
assert "texts" in labels, "No texts found in labels."
class_texts = labels["texts"]
num_classes = len(class_texts)
cls = np.asarray(labels.pop("cls"), dtype=int)
pos_labels = np.unique(cls).tolist()
if len(pos_labels) > self.max_samples:
pos_labels = set(random.sample(pos_labels, k=self.max_samples))
neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
neg_labels = []
for i in range(num_classes):
if i not in pos_labels:
neg_labels.append(i)
neg_labels = random.sample(neg_labels, k=neg_samples)
sampled_labels = pos_labels + neg_labels
random.shuffle(sampled_labels)
label2ids = {label: i for i, label in enumerate(sampled_labels)}
valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
new_cls = []
for i, label in enumerate(cls.squeeze(-1).tolist()):
if label not in label2ids:
continue
valid_idx[i] = True
new_cls.append([label2ids[label]])
labels["instances"] = labels["instances"][valid_idx]
labels["cls"] = np.array(new_cls)
# Randomly select one prompt when there's more than one prompts
texts = []
for label in sampled_labels:
prompts = class_texts[label]
assert len(prompts) > 0
prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
texts.append(prompt)
if self.padding:
valid_labels = len(pos_labels) + len(neg_labels)
num_padding = self.max_samples - valid_labels
if num_padding > 0:
texts += [self.padding_value] * num_padding
labels["texts"] = texts
return labels
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose(

View file

@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
from .dataset import YOLODataset
from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
from .utils import PIN_MEMORY
@ -82,9 +82,10 @@ def seed_worker(worker_id): # noqa
random.seed(worker_seed)
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
"""Build YOLO Dataset."""
return YOLODataset(
dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
return dataset(
img_path=img_path,
imgsz=cfg.imgsz,
batch_size=batch,
@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
)
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
"""Build YOLO Dataset."""
return GroundingDataset(
img_path=img_path,
json_file=json_file,
imgsz=cfg.imgsz,
batch_size=batch,
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect or rect, # rectangular batches
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
task=cfg.task,
classes=cfg.classes,
fraction=cfg.fraction if mode == "train" else 1.0,
)
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
batch = min(batch, len(dataset))

View file

@ -219,6 +219,7 @@ def convert_coco(
use_segments=False,
use_keypoints=False,
cls91to80=True,
lvis=False,
):
"""
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@ -229,12 +230,14 @@ def convert_coco(
use_segments (bool, optional): Whether to include segmentation masks in the output.
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
lvis (bool, optional): Whether to convert data in lvis dataset way.
Example:
```python
from ultralytics.data.converter import convert_coco
convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
```
Output:
@ -251,8 +254,14 @@ def convert_coco(
# Import json
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name
lname = "" if lvis else json_file.stem.replace("instances_", "")
fn = Path(save_dir) / "labels" / lname # folder name
fn.mkdir(parents=True, exist_ok=True)
if lvis:
# NOTE: create folders for both train and val in advance,
# since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
(fn / "train2017").mkdir(parents=True, exist_ok=True)
(fn / "val2017").mkdir(parents=True, exist_ok=True)
with open(json_file) as f:
data = json.load(f)
@ -263,16 +272,20 @@ def convert_coco(
for ann in data["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
image_txt = []
# Write labels file
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
h, w = img["height"], img["width"]
f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
if lvis:
image_txt.append(str(Path("./images") / f))
bboxes = []
segments = []
keypoints = []
for ann in anns:
if ann["iscrowd"]:
if ann.get("iscrowd", False):
continue
# The COCO box format is [top left x, top left y, width, height]
box = np.array(ann["bbox"], dtype=np.float64)
@ -314,7 +327,12 @@ def convert_coco(
) # cls, box or segments
file.write(("%g " * len(line)).rstrip() % line + "\n")
LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
if lvis:
with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
for l in image_txt:
f.write(f"{l}\n")
LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
def convert_dota_to_yolo_obb(dota_root_path: str):

View file

@ -1,20 +1,41 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from itertools import repeat
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import json
import numpy as np
import torch
import torchvision
from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
verify_image,
verify_image_label,
load_dataset_cache_file,
save_dataset_cache_file,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return x
def get_labels(self):
@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
class YOLOMultiModalDataset(YOLODataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
return cache
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes a dataset object for object detection tasks with optional specifications."""
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label):
"""Add texts information for multi modal model training."""
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
return labels
def build_transforms(self, hyp=None):
"""Enhances data transformations with optional text augmentation for multi-modal training."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
return transforms
def save_dataset_cache_file(prefix, path, x):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = DATASET_CACHE_VERSION # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
class GroundingDataset(YOLODataset):
def __init__(self, *args, task="detect", json_file, **kwargs):
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
self.json_file = json_file
super().__init__(*args, task=task, data={}, **kwargs)
def get_img_files(self, img_path):
"""The image files would be read in `get_labels` function, return empty list here."""
return []
def get_labels(self):
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
labels = []
LOGGER.info("Loading annotation file...")
with open(self.json_file, "r") as f:
annotations = json.load(f)
images = {f'{x["id"]:d}': x for x in annotations["images"]}
imgToAnns = defaultdict(list)
for ann in annotations["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
if not im_file.exists():
continue
self.im_files.append(str(im_file))
bboxes = []
cat2id = {}
texts = []
for ann in anns:
if ann["iscrowd"]:
continue
box = np.array(ann["bbox"], dtype=np.float32)
box[:2] += box[2:] / 2
box[[0, 2]] /= float(w)
box[[1, 3]] /= float(h)
if box[2] <= 0 or box[3] <= 0:
continue
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
cls = cat2id[cat_name] # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
labels.append(
dict(
im_file=im_file,
shape=(h, w),
cls=lb[:, 0:1], # n, 1
bboxes=lb[:, 1:], # n, 4
normalized=True,
bbox_format="xywh",
texts=texts,
)
)
return labels
def build_transforms(self, hyp=None):
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
return transforms
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
"""
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
return YOLODataset.collate_fn(batch)
# TODO: support semantic segmentation

View file

@ -29,6 +29,7 @@ from ultralytics.utils import (
emojis,
yaml_load,
yaml_save,
is_dir_writeable,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True):
# Set paths
data["path"] = path # download scripts
for k in "train", "val", "test":
for k in "train", "val", "test", "minival":
if data.get(k): # prepend path
if isinstance(data[k], str):
x = (path / data[k]).resolve()
@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path.parent / txt[i], "a") as f:
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
return cache
def save_dataset_cache_file(prefix, path, x, version):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = version # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")

View file

@ -126,22 +126,7 @@ class BaseTrainer:
# Model and Dataset
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try:
if self.args.task == "classify":
self.data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
"detect",
"segment",
"pose",
"obb",
):
self.data = check_det_dataset(self.args.data)
if "yaml_file" in self.data:
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.trainset, self.testset = self.get_dataset(self.data)
self.trainset, self.testset = self.get_dataset()
self.ema = None
# Optimization utils init
@ -509,13 +494,27 @@ class BaseTrainer:
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
@staticmethod
def get_dataset(data):
def get_dataset(self):
"""
Get train, val path from data dict if it exists.
Returns None if data format is not recognized.
"""
try:
if self.args.task == "classify":
data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
"detect",
"segment",
"pose",
"obb",
):
data = check_det_dataset(self.args.data)
if "yaml_file" in data:
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.data = data
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
@ -666,8 +665,8 @@ class BaseTrainer:
if ckpt is None:
return
best_fitness = 0.0
start_epoch = ckpt["epoch"] + 1
if ckpt["optimizer"] is not None:
start_epoch = ckpt.get("epoch", -1) + 1
if ckpt.get("optimizer", None) is not None:
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
best_fitness = ckpt["best_fitness"]
if self.ema and ckpt.get("ema"):

View file

@ -35,7 +35,7 @@ class FastSAMPrompt:
except ImportError:
from ultralytics.utils.checks import check_requirements
check_requirements("git+https://github.com/openai/CLIP.git")
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
self.clip = clip

View file

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo import classify, detect, obb, pose, segment
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
from .model import YOLO, YOLOWorld
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"

View file

@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.nt_per_class = None
self.is_coco = False
self.is_lvis = False
self.class_map = None
self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, "") # validation path
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training # run on final val if training COCO
self.names = model.names
self.nc = len(model.names)
self.metrics.names = self.names
@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"category_id": self.class_map[int(p[5])]
+ (1 if self.is_lvis else 0), # index starts from 1 if it's lvis
"bbox": [round(x, 3) for x in b],
"score": round(p[4], 5),
}
@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
pkg = "pycocotools" if self.is_coco else "lvis"
LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
for x in pred_json, anno_json:
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, "bbox")
check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, "bbox")
else:
from lvis import LVIS, LVISEval
anno = LVIS(str(anno_json)) # init annotations api
pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
eval = LVISEval(anno, pred, "bbox")
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
if self.is_lvis:
eval.print_results() # explicitly call print_results
# update mAP50-95 and mAP50
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
)
except Exception as e:
LOGGER.warning(f"pycocotools unable to run: {e}")
LOGGER.warning(f"{pkg} unable to run: {e}")
return stats

View file

@ -83,6 +83,7 @@ class YOLOWorld(Model):
"model": WorldModel,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.world.WorldTrainer,
}
}

View file

@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .train import WorldTrainer
__all__ = ["WorldTrainer"]

View file

@ -0,0 +1,91 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models import yolo
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.data import build_yolo_dataset
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils.checks import check_requirements
import itertools
try:
import clip
except ImportError:
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
def on_pretrain_routine_end(trainer):
"""Callback."""
if RANK in (-1, 0):
# NOTE: for evaluation
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
device = next(trainer.model.parameters()).device
text_model, _ = clip.load("ViT-B/32", device=device)
for p in text_model.parameters():
p.requires_grad_(False)
trainer.text_model = text_model
class WorldTrainer(yolo.detect.DetectionTrainer):
"""
A class to fine-tune a world model on a close-set dataset.
Example:
```python
from ultralytics.models.yolo.world import WorldModel
args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a WorldTrainer object with given arguments."""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return WorldModel initialized with specified config and weights."""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = WorldModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=3,
nc=min(self.data["nc"], 80),
verbose=verbose and RANK == -1,
)
if weights:
model.load(weights)
self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
return model
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
)
def preprocess_batch(self, batch):
"""Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
batch = super().preprocess_batch(batch)
# NOTE: add text features
texts = list(itertools.chain(*batch["texts"]))
text_token = clip.tokenize(texts).to(batch["img"].device)
txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
return batch

View file

@ -0,0 +1,108 @@
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils import DEFAULT_CFG
class WorldTrainerFromScratch(WorldTrainer):
"""
A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
Example:
```python
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
from ultralytics import YOLOWorld
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
grounding_data=[
dict(
img_path="../datasets/flickr30k/images",
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
),
dict(
img_path="../datasets/GQA/images",
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
),
],
),
val=dict(yolo_data=["lvis.yaml"]),
)
model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, trainer=WorldTrainerFromScratch)
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a WorldTrainer object with given arguments."""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (List[str] | str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
if mode == "train":
dataset = [
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
if isinstance(im_path, str)
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
for im_path in img_path
]
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
else:
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataset(self):
"""
Get train, val path from data dict if it exists.
Returns None if data format is not recognized.
"""
final_data = dict()
data_yaml = self.args.data
assert data_yaml.get("train", False) # object365.yaml
assert data_yaml.get("val", False) # lvis.yaml
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
for d in data["val"]:
if d.get("minival") is None: # for lvis dataset
continue
d["minival"] = str(d["path"] / d["minival"])
for s in ["train", "val"]:
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
# save grounding data if there's one
grounding_data = data_yaml[s].get("grounding_data")
if grounding_data is None:
continue
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
for g in grounding_data:
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
final_data[s] += grounding_data
# NOTE: to make training work properly, set `nc` and `names`
final_data["nc"] = data["val"][0]["nc"]
final_data["names"] = data["val"][0]["names"]
self.data = final_data
return final_data["train"], final_data["val"][0]
def plot_training_labels(self):
"""DO NOT plot labels."""
pass
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO-World model."""
val = self.args.data["val"]["yolo_data"][0]
self.validator.args.data = val
self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
return super().final_eval()

View file

@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
def __init__(self):
"""Initializes ContrastiveHead with specified region-text similarity parameters."""
super().__init__()
self.bias = nn.Parameter(torch.zeros([]))
# NOTE: use -10.0 to keep the init cls loss consistency with other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
def forward(self, x, w):
@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
"""Initialize ContrastiveHead with region-text similarity parameters."""
super().__init__()
self.norm = nn.BatchNorm2d(embed_dims)
self.bias = nn.Parameter(torch.zeros([]))
# NOTE: use -10.0 to keep the init cls loss consistency with other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
# use -1.0 is more stable
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

View file

@ -250,6 +250,15 @@ class WorldDetect(Detect):
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
m = self # self.model[-1] # Detect() module
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
# b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
class RTDETRDecoder(nn.Module):
"""

View file

@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
self.clip_model = None # CLIP model placeholder
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def set_classes(self, text):
"""Perform a forward pass with optional profiling, visualization, and embedding extraction."""
def set_classes(self, text, batch=80, cache_clip_model=True):
"""Set classes in advance so that model could do offline-inference without clip model."""
try:
import clip
except ImportError:
check_requirements("git+https://github.com/openai/CLIP.git")
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute
if (
not getattr(self, "clip_model", None) and cache_clip_model
): # for backwards compatibility of models lacking clip_model attribute
self.clip_model = clip.load("ViT-B/32")[0]
device = next(self.clip_model.parameters()).device
model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
device = next(model.parameters()).device
text_token = clip.tokenize(text).to(device)
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
self.model[-1].nc = len(text)
def init_criterion(self):
"""Initialize the loss criterion for the model."""
raise NotImplementedError
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
"""
Perform a forward pass through the model.
@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
x (torch.Tensor): The input tensor.
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): Model's output tensor.
"""
txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
if len(txt_feats) != len(x):
txt_feats = txt_feats.repeat(len(x), 1, 1)
ori_txt_feats = txt_feats.clone()
@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x
def loss(self, batch, preds=None):
"""
Compute loss.
Args:
batch (dict): Batch to compute loss on.
preds (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
if preds is None:
preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
return self.criterion(preds, batch)
class Ensemble(nn.ModuleList):
"""Ensemble of models."""

View file

@ -157,7 +157,7 @@ class v8DetectionLoss:
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.no
self.no = m.nc + m.reg_max * 4
self.reg_max = m.reg_max
self.device = device