ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -1,20 +1,41 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import contextlib
|
||||
from itertools import repeat
|
||||
from collections import defaultdict
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
|
||||
from torch.utils.data import ConcatDataset
|
||||
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
||||
from ultralytics.utils.ops import resample_segments
|
||||
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
|
||||
from .augment import (
|
||||
Compose,
|
||||
Format,
|
||||
Instances,
|
||||
LetterBox,
|
||||
RandomLoadText,
|
||||
classify_augmentations,
|
||||
classify_transforms,
|
||||
v8_transforms,
|
||||
)
|
||||
from .base import BaseDataset
|
||||
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
|
||||
from .utils import (
|
||||
HELP_URL,
|
||||
LOGGER,
|
||||
get_hash,
|
||||
img2label_paths,
|
||||
verify_image,
|
||||
verify_image_label,
|
||||
load_dataset_cache_file,
|
||||
save_dataset_cache_file,
|
||||
)
|
||||
|
||||
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
|
||||
DATASET_CACHE_VERSION = "1.0.3"
|
||||
|
|
@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
|
|||
x["hash"] = get_hash(self.label_files + self.im_files)
|
||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||
x["msgs"] = msgs # warnings
|
||||
save_dataset_cache_file(self.prefix, path, x)
|
||||
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
|
|
@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
x["hash"] = get_hash([x[0] for x in self.samples])
|
||||
x["results"] = nf, nc, len(samples), samples
|
||||
x["msgs"] = msgs # warnings
|
||||
save_dataset_cache_file(self.prefix, path, x)
|
||||
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
||||
return samples
|
||||
|
||||
|
||||
class YOLOMultiModalDataset(YOLODataset):
    """
    YOLO dataset variant that attaches class-name text prompts to every label for multi-modal training.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes a dataset object for object detection tasks with optional specifications."""
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label):
        """Add texts information for multi modal model training."""
        updated = super().update_labels_info(label)
        # NOTE: some categories are concatenated with its synonyms by `/`,
        # so each class expands to a list of alternative text prompts.
        updated["texts"] = [name.split("/") for name in self.data["names"].values()]
        return updated

    def build_transforms(self, hyp=None):
        """Enhances data transformations with optional text augmentation for multi-modal training."""
        tfms = super().build_transforms(hyp)
        if not self.augment:
            return tfms
        # NOTE: hard-coded the args for now. Cap the sampled classes at 80 and
        # insert the text loader just before the final Format transform.
        max_texts = min(self.data["nc"], 80)
        tfms.insert(-1, RandomLoadText(max_samples=max_texts, padding=True))
        return tfms
|
||||
|
||||
|
||||
def save_dataset_cache_file(prefix, path, x):
    """
    Save an Ultralytics dataset *.cache dictionary x to path.

    Args:
        prefix (str): Log-message prefix (e.g. dataset name with colorstr formatting).
        path (Path): Destination *.cache file path.
        x (dict): Cache dictionary to persist; a "version" key is stamped in before saving.
    """
    x["version"] = DATASET_CACHE_VERSION  # add cache version
    if not is_dir_writeable(path.parent):
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
        return
    if path.exists():
        path.unlink()  # remove *.cache file if exists
    # np.save appends a .npy suffix; rename back so the file keeps the .cache name.
    np.save(str(path), x)  # save cache for next time
    path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
    LOGGER.info(f"{prefix}New cache created: {path}")
|
||||
class GroundingDataset(YOLODataset):
    """
    Dataset for visual-grounding training: boxes are labeled by caption phrases from a COCO-style JSON file.

    Args:
        json_file (str): Path to the grounding annotation JSON (COCO-like, with `tokens_positive` spans).
        task (str): Must be "detect"; other tasks are not supported yet.
    """

    def __init__(self, *args, task="detect", json_file, **kwargs):
        """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
        self.json_file = json_file
        super().__init__(*args, task=task, data={}, **kwargs)

    def get_img_files(self, img_path):
        """The image files would be read in `get_labels` function, return empty list here."""
        return []

    def get_labels(self):
        """
        Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.

        Returns:
            (list[dict]): One dict per image with keys im_file, shape, cls, bboxes,
                normalized, bbox_format and texts (caption phrase per class).
        """
        labels = []
        LOGGER.info("Loading annotation file...")
        # JSON is UTF-8 by spec; be explicit so decoding doesn't depend on the platform locale.
        with open(self.json_file, encoding="utf-8") as f:
            annotations = json.load(f)
        images = {f'{x["id"]:d}': x for x in annotations["images"]}
        imgToAnns = defaultdict(list)
        for ann in annotations["annotations"]:
            imgToAnns[ann["image_id"]].append(ann)
        for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, f = img["height"], img["width"], img["file_name"]
            im_file = Path(self.img_path) / f
            if not im_file.exists():
                continue  # skip annotations whose image is missing on disk
            self.im_files.append(str(im_file))
            bboxes = []
            cat2id = {}
            texts = []
            for ann in anns:
                if ann["iscrowd"]:
                    continue
                box = np.array(ann["bbox"], dtype=np.float32)
                # Convert COCO xywh (top-left) to center-based xywh, normalized to [0, 1].
                box[:2] += box[2:] / 2
                box[[0, 2]] /= float(w)
                box[[1, 3]] /= float(h)
                if box[2] <= 0 or box[3] <= 0:
                    continue  # drop degenerate boxes

                # Class name is the caption substring(s) this box grounds to.
                cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)
                    texts.append([cat_name])
                cls = cat2id[cat_name]  # class
                box = [cls] + box.tolist()
                if box not in bboxes:  # de-duplicate identical boxes
                    bboxes.append(box)
            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
            labels.append(
                dict(
                    im_file=im_file,
                    shape=(h, w),
                    cls=lb[:, 0:1],  # n, 1
                    bboxes=lb[:, 1:],  # n, 4
                    normalized=True,
                    bbox_format="xywh",
                    texts=texts,
                )
            )
        return labels

    def build_transforms(self, hyp=None):
        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
        return transforms
|
||||
|
||||
|
||||
class YOLOConcatDataset(ConcatDataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets, while reusing
    the YOLO-specific batch collation so mixed sources train seamlessly.
    """

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches by delegating to YOLODataset.collate_fn."""
        return YOLODataset.collate_fn(batch)
|
||||
|
||||
|
||||
# TODO: support semantic segmentation
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue