ultralytics 8.1.39 add YOLO-World training (#9268)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-03-31 22:30:17 +08:00 committed by GitHub
parent 18036908d4
commit e9187c1296
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 2166 additions and 100 deletions

View file

@ -1,20 +1,41 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from itertools import repeat
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import json
import numpy as np
import torch
import torchvision
from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
verify_image,
verify_image_label,
load_dataset_cache_file,
save_dataset_cache_file,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return x
def get_labels(self):
@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
class YOLOMultiModalDataset(YOLODataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
return cache
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes a dataset object for object detection tasks with optional specifications."""
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label):
"""Add texts information for multi modal model training."""
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
return labels
def build_transforms(self, hyp=None):
"""Enhances data transformations with optional text augmentation for multi-modal training."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
return transforms
def save_dataset_cache_file(prefix, path, x):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = DATASET_CACHE_VERSION # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
class GroundingDataset(YOLODataset):
def __init__(self, *args, task="detect", json_file, **kwargs):
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
self.json_file = json_file
super().__init__(*args, task=task, data={}, **kwargs)
def get_img_files(self, img_path):
"""The image files would be read in `get_labels` function, return empty list here."""
return []
def get_labels(self):
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
labels = []
LOGGER.info("Loading annotation file...")
with open(self.json_file, "r") as f:
annotations = json.load(f)
images = {f'{x["id"]:d}': x for x in annotations["images"]}
imgToAnns = defaultdict(list)
for ann in annotations["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
if not im_file.exists():
continue
self.im_files.append(str(im_file))
bboxes = []
cat2id = {}
texts = []
for ann in anns:
if ann["iscrowd"]:
continue
box = np.array(ann["bbox"], dtype=np.float32)
box[:2] += box[2:] / 2
box[[0, 2]] /= float(w)
box[[1, 3]] /= float(h)
if box[2] <= 0 or box[3] <= 0:
continue
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
cls = cat2id[cat_name] # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
labels.append(
dict(
im_file=im_file,
shape=(h, w),
cls=lb[:, 0:1], # n, 1
bboxes=lb[:, 1:], # n, 4
normalized=True,
bbox_format="xywh",
texts=texts,
)
)
return labels
def build_transforms(self, hyp=None):
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
return transforms
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
"""
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
return YOLODataset.collate_fn(batch)
# TODO: support semantic segmentation