ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 18036908d4
commit e9187c1296
34 changed files with 2166 additions and 100 deletions
ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.38"
+__version__ = "8.1.39"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
ultralytics/cfg/datasets/lvis.yaml (new file, 1239 lines)
File diff suppressed because it is too large.
ultralytics/data/__init__.py
@@ -1,15 +1,31 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 from .base import BaseDataset
-from .build import build_dataloader, build_yolo_dataset, load_inference_source
-from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
+from .build import (
+    build_dataloader,
+    build_yolo_dataset,
+    build_grounding,
+    load_inference_source,
+)
+from .dataset import (
+    ClassificationDataset,
+    SemanticDataset,
+    YOLODataset,
+    YOLOMultiModalDataset,
+    GroundingDataset,
+    YOLOConcatDataset,
+)
 
 __all__ = (
     "BaseDataset",
     "ClassificationDataset",
     "SemanticDataset",
     "YOLODataset",
+    "YOLOMultiModalDataset",
+    "YOLOConcatDataset",
+    "GroundingDataset",
     "build_yolo_dataset",
+    "build_grounding",
     "build_dataloader",
     "load_inference_source",
 )
ultralytics/data/augment.py
@@ -3,6 +3,7 @@
 import math
 import random
 from copy import deepcopy
+from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -66,7 +67,7 @@ class Compose:
 
     def __init__(self, transforms):
         """Initializes the Compose object with a list of transforms."""
-        self.transforms = transforms
+        self.transforms = transforms if isinstance(transforms, list) else [transforms]
 
     def __call__(self, data):
         """Applies a series of transformations to input data."""
@@ -78,6 +79,29 @@ class Compose:
         """Appends a new transform to the existing list of transforms."""
         self.transforms.append(transform)
 
+    def insert(self, index, transform):
+        """Inserts a new transform into the existing list of transforms at the given index."""
+        self.transforms.insert(index, transform)
+
+    def __getitem__(self, index: Union[list, int]) -> "Compose":
+        """Retrieve a specific transform or a set of transforms using indexing."""
+        assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
+        index = [index] if isinstance(index, int) else index
+        return Compose([self.transforms[i] for i in index])
+
+    def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
+        """Assign one or more transforms to specific positions using indexing."""
+        assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
+        if isinstance(index, list):
+            assert isinstance(
+                value, list
+            ), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
+        if isinstance(index, int):
+            index, value = [index], [value]
+        for i, v in zip(index, value):
+            assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
+            self.transforms[i] = v
+
     def tolist(self):
         """Converts the list of transforms to a standard Python list."""
         return self.transforms
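The new `insert`, `__getitem__`, and `__setitem__` hooks make a `Compose` pipeline editable in place, which is what the multi-modal datasets below rely on to splice `RandomLoadText` into an existing transform chain. A minimal sketch of the behavior (the lambda transforms are illustrative stand-ins, not part of the diff):

```python
from ultralytics.data.augment import Compose

c = Compose([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
sub = c[[0, 2]]           # new __getitem__: Compose of transforms 0 and 2
c[1] = lambda x: x * 10   # new __setitem__: replace a single transform
c.insert(0, lambda x: x)  # new insert(): prepend an identity transform
print(sub(5), c(5))       # 3 57
```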
@@ -118,6 +142,8 @@ class BaseMixTransform:
             mix_labels[i] = self.pre_transform(data)
         labels["mix_labels"] = mix_labels
 
+        # Update cls and texts
+        labels = self._update_label_text(labels)
         # Mosaic or MixUp
         labels = self._mix_transform(labels)
         labels.pop("mix_labels", None)
@@ -131,6 +157,22 @@ class BaseMixTransform:
         """Gets a list of shuffled indexes for mosaic augmentation."""
         raise NotImplementedError
 
+    def _update_label_text(self, labels):
+        """Merge the text lists of all mixed samples and remap class indices accordingly."""
+        if "texts" not in labels:
+            return labels
+
+        mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
+        mix_texts = list({tuple(x) for x in mix_texts})
+        text2id = {text: i for i, text in enumerate(mix_texts)}
+
+        for label in [labels] + labels["mix_labels"]:
+            for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
+                text = label["texts"][int(l)]
+                label["cls"][i] = text2id[tuple(text)]
+            label["texts"] = mix_texts
+        return labels
+
 
 class Mosaic(BaseMixTransform):
     """
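Conceptually, `_update_label_text` builds one de-duplicated vocabulary across the samples being mixed, then rewrites each sample's `cls` indices to point into that merged list. A standalone sketch of that remapping logic with toy labels (not the real pipeline):

```python
import numpy as np

a = {"texts": [["person"], ["dog"]], "cls": np.array([[0], [1]], dtype=np.float32)}
b = {"texts": [["dog"], ["cat"]], "cls": np.array([[0], [1]], dtype=np.float32)}
a["mix_labels"] = [b]

mix_texts = list({tuple(x) for x in a["texts"] + b["texts"]})  # merged, de-duplicated
text2id = {t: i for i, t in enumerate(mix_texts)}
for label in [a] + a["mix_labels"]:
    for i, c in enumerate(label["cls"].squeeze(-1).tolist()):
        label["cls"][i] = text2id[tuple(label["texts"][int(c)])]  # remap into merged list
    label["texts"] = mix_texts
print(a["cls"].ravel(), b["cls"].ravel())  # both now index into mix_texts
```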
@@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
         final_labels["instances"].clip(imgsz, imgsz)
         good = final_labels["instances"].remove_zero_area_boxes()
         final_labels["cls"] = final_labels["cls"][good]
+        if "texts" in mosaic_labels[0]:
+            final_labels["texts"] = mosaic_labels[0]["texts"]
         return final_labels
 
 
@@ -970,6 +1014,83 @@ class Format:
         return masks, instances, cls
 
 
+class RandomLoadText:
+    """
+    Randomly samples positive and negative texts and updates the class indices according to the number of samples.
+
+    Attributes:
+        prompt_format (str): Format string for prompts. Default is '{}'.
+        neg_samples (Tuple[int, int]): Range from which to randomly sample the number of negative texts. Default is (80, 80).
+        max_samples (int): Maximum number of different text samples in one image. Default is 80.
+        padding (bool): Whether to pad texts to max_samples. Default is False.
+        padding_value (str): The padding text. Default is "".
+    """
+
+    def __init__(
+        self,
+        prompt_format: str = "{}",
+        neg_samples: Tuple[int, int] = (80, 80),
+        max_samples: int = 80,
+        padding: bool = False,
+        padding_value: str = "",
+    ) -> None:
+        """Initializes the RandomLoadText class with given parameters."""
+        self.prompt_format = prompt_format
+        self.neg_samples = neg_samples
+        self.max_samples = max_samples
+        self.padding = padding
+        self.padding_value = padding_value
+
+    def __call__(self, labels: dict) -> dict:
+        """Return updated classes and texts."""
+        assert "texts" in labels, "No texts found in labels."
+        class_texts = labels["texts"]
+        num_classes = len(class_texts)
+        cls = np.asarray(labels.pop("cls"), dtype=int)
+        pos_labels = np.unique(cls).tolist()
+
+        if len(pos_labels) > self.max_samples:
+            pos_labels = random.sample(pos_labels, k=self.max_samples)
+
+        neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
+        neg_labels = []
+        for i in range(num_classes):
+            if i not in pos_labels:
+                neg_labels.append(i)
+        neg_labels = random.sample(neg_labels, k=neg_samples)
+
+        sampled_labels = pos_labels + neg_labels
+        random.shuffle(sampled_labels)
+
+        label2ids = {label: i for i, label in enumerate(sampled_labels)}
+        valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
+        new_cls = []
+        for i, label in enumerate(cls.squeeze(-1).tolist()):
+            if label not in label2ids:
+                continue
+            valid_idx[i] = True
+            new_cls.append([label2ids[label]])
+        labels["instances"] = labels["instances"][valid_idx]
+        labels["cls"] = np.array(new_cls)
+
+        # Randomly select one prompt when a class has more than one prompt
+        texts = []
+        for label in sampled_labels:
+            prompts = class_texts[label]
+            assert len(prompts) > 0
+            prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
+            texts.append(prompt)
+
+        if self.padding:
+            valid_labels = len(pos_labels) + len(neg_labels)
+            num_padding = self.max_samples - valid_labels
+            if num_padding > 0:
+                texts += [self.padding_value] * num_padding
+
+        labels["texts"] = texts
+        return labels
+
+
 def v8_transforms(dataset, imgsz, hyp, stretch=False):
     """Convert images to a size suitable for YOLOv8 training."""
     pre_transform = Compose(
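For illustration, a minimal sketch of what `RandomLoadText` does to a label dict. The toy values are assumptions, and a plain array stands in for the usual `Instances` object since the transform only needs `len()` and boolean indexing here:

```python
import numpy as np
from ultralytics.data.augment import RandomLoadText

labels = {
    "texts": [["person"], ["bicycle", "bike"], ["car"]],  # per-class synonym lists
    "cls": np.array([[0], [2]]),                          # two boxes: classes 0 and 2
    "instances": np.zeros((2, 4), dtype=np.float32),      # placeholder box coords
}
out = RandomLoadText(neg_samples=(1, 1), max_samples=3)(labels)
print(out["texts"])  # e.g. ['car', 'person', 'bike'] — one prompt per sampled class
print(out["cls"])    # class ids remapped to positions in out["texts"]
```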
ultralytics/data/build.py
@@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
 from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file
-from .dataset import YOLODataset
+from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
 from .utils import PIN_MEMORY
 
 
@@ -82,9 +82,10 @@ def seed_worker(worker_id):  # noqa
     random.seed(worker_seed)
 
 
-def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
+def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
     """Build YOLO Dataset."""
-    return YOLODataset(
+    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
+    return dataset(
         img_path=img_path,
         imgsz=cfg.imgsz,
         batch_size=batch,
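With the new `multi_modal` flag, the same builder yields either dataset class. A usage sketch, assuming the bundled `coco8.yaml` demo dataset and default training args:

```python
from ultralytics.cfg import get_cfg
from ultralytics.data import build_yolo_dataset
from ultralytics.data.utils import check_det_dataset

cfg = get_cfg()  # default train args: imgsz, augmentation hyps, etc.
data = check_det_dataset("coco8.yaml")
ds = build_yolo_dataset(cfg, data["train"], batch=16, data=data, multi_modal=True)
print(type(ds).__name__)  # YOLOMultiModalDataset
```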
@@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
     )
 
 
+def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
+    """Build a YOLO grounding dataset from images and a COCO-style caption JSON file."""
+    return GroundingDataset(
+        img_path=img_path,
+        json_file=json_file,
+        imgsz=cfg.imgsz,
+        batch_size=batch,
+        augment=mode == "train",  # augmentation
+        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
+        rect=cfg.rect or rect,  # rectangular batches
+        cache=cfg.cache or None,
+        single_cls=cfg.single_cls or False,
+        stride=int(stride),
+        pad=0.0 if mode == "train" else 0.5,
+        prefix=colorstr(f"{mode}: "),
+        task=cfg.task,
+        classes=cfg.classes,
+        fraction=cfg.fraction if mode == "train" else 1.0,
+    )
+
+
 def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
     """Return an InfiniteDataLoader or DataLoader for training or validation set."""
     batch = min(batch, len(dataset))
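As a usage sketch (the dataset paths are the hypothetical ones from the `WorldTrainerFromScratch` docstring later in this diff), the new builder can be driven directly with a default config:

```python
from ultralytics.cfg import get_cfg
from ultralytics.data.build import build_grounding, build_dataloader

cfg = get_cfg()  # supplies imgsz, rect, cache, augmentation hyps, etc.
dataset = build_grounding(
    cfg,
    img_path="../datasets/flickr30k/images",
    json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
    batch=16,
)
loader = build_dataloader(dataset, batch=16, workers=8, shuffle=True)
```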
ultralytics/data/converter.py
@@ -219,6 +219,7 @@ def convert_coco(
     use_segments=False,
     use_keypoints=False,
     cls91to80=True,
+    lvis=False,
 ):
     """
     Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@@ -229,12 +230,14 @@ def convert_coco(
         use_segments (bool, optional): Whether to include segmentation masks in the output.
         use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
         cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
+        lvis (bool, optional): Whether to convert annotations in LVIS format.
 
     Example:
         ```python
         from ultralytics.data.converter import convert_coco
 
         convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
+        convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
         ```
 
     Output:
@@ -251,8 +254,14 @@ def convert_coco(
 
     # Import json
     for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
-        fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")  # folder name
+        lname = "" if lvis else json_file.stem.replace("instances_", "")
+        fn = Path(save_dir) / "labels" / lname  # folder name
         fn.mkdir(parents=True, exist_ok=True)
+        if lvis:
+            # NOTE: create folders for both train and val in advance,
+            # since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
+            (fn / "train2017").mkdir(parents=True, exist_ok=True)
+            (fn / "val2017").mkdir(parents=True, exist_ok=True)
         with open(json_file) as f:
             data = json.load(f)
 
@@ -263,16 +272,20 @@ def convert_coco(
         for ann in data["annotations"]:
             imgToAnns[ann["image_id"]].append(ann)
 
+        image_txt = []
         # Write labels file
        for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
             img = images[f"{img_id:d}"]
-            h, w, f = img["height"], img["width"], img["file_name"]
+            h, w = img["height"], img["width"]
+            f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
+            if lvis:
+                image_txt.append(str(Path("./images") / f))
 
             bboxes = []
             segments = []
             keypoints = []
             for ann in anns:
-                if ann["iscrowd"]:
+                if ann.get("iscrowd", False):
                     continue
                 # The COCO box format is [top left x, top left y, width, height]
                 box = np.array(ann["bbox"], dtype=np.float64)
@@ -314,7 +327,12 @@ def convert_coco(
                 )  # cls, box or segments
                 file.write(("%g " * len(line)).rstrip() % line + "\n")
 
+        if lvis:
+            with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
+                for l in image_txt:
+                    f.write(f"{l}\n")
+
-    LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
+    LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
 
 
 def convert_dota_to_yolo_obb(dota_root_path: str):
ultralytics/data/dataset.py
@@ -1,20 +1,41 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 import contextlib
 from itertools import repeat
+from collections import defaultdict
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
 
 import cv2
+import json
 import numpy as np
 import torch
 import torchvision
 from PIL import Image
 
-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
+from torch.utils.data import ConcatDataset
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
 from ultralytics.utils.ops import resample_segments
-from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
+from .augment import (
+    Compose,
+    Format,
+    Instances,
+    LetterBox,
+    RandomLoadText,
+    classify_augmentations,
+    classify_transforms,
+    v8_transforms,
+)
 from .base import BaseDataset
-from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
+from .utils import (
+    HELP_URL,
+    LOGGER,
+    get_hash,
+    img2label_paths,
+    verify_image,
+    verify_image_label,
+    load_dataset_cache_file,
+    save_dataset_cache_file,
+)
 
 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
 DATASET_CACHE_VERSION = "1.0.3"
@@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
         x["hash"] = get_hash(self.label_files + self.im_files)
         x["results"] = nf, nm, ne, nc, len(self.im_files)
         x["msgs"] = msgs  # warnings
-        save_dataset_cache_file(self.prefix, path, x)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return x
 
     def get_labels(self):
@@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         x["hash"] = get_hash([x[0] for x in self.samples])
         x["results"] = nf, nc, len(samples), samples
         x["msgs"] = msgs  # warnings
-        save_dataset_cache_file(self.prefix, path, x)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return samples
 
 
-def load_dataset_cache_file(path):
-    """Load an Ultralytics *.cache dictionary from path."""
-    import gc
-
-    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-    cache = np.load(str(path), allow_pickle=True).item()  # load dict
-    gc.enable()
-    return cache
-
-
-def save_dataset_cache_file(prefix, path, x):
-    """Save an Ultralytics dataset *.cache dictionary x to path."""
-    x["version"] = DATASET_CACHE_VERSION  # add cache version
-    if is_dir_writeable(path.parent):
-        if path.exists():
-            path.unlink()  # remove *.cache file if exists
-        np.save(str(path), x)  # save cache for next time
-        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
-        LOGGER.info(f"{prefix}New cache created: {path}")
-    else:
-        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
+class YOLOMultiModalDataset(YOLODataset):
+    """
+    Dataset class for loading object detection and/or segmentation labels in YOLO format, with multi-modal texts.
+
+    Args:
+        data (dict, optional): A dataset YAML dictionary. Defaults to None.
+        task (str): An explicit arg to specify the current task. Defaults to 'detect'.
+
+    Returns:
+        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
+    """
+
+    def __init__(self, *args, data=None, task="detect", **kwargs):
+        """Initializes a dataset object for object detection tasks with optional specifications."""
+        super().__init__(*args, data=data, task=task, **kwargs)
+
+    def update_labels_info(self, label):
+        """Add texts information for multi-modal model training."""
+        labels = super().update_labels_info(label)
+        # NOTE: some categories are concatenated with their synonyms by `/`.
+        labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+        return labels
+
+    def build_transforms(self, hyp=None):
+        """Enhances data transformations with optional text augmentation for multi-modal training."""
+        transforms = super().build_transforms(hyp)
+        if self.augment:
+            # NOTE: hard-coded the args for now.
+            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+        return transforms
+
+
+class GroundingDataset(YOLODataset):
+    """Dataset class for object detection using annotations loaded from a grounding (caption) JSON file."""
+
+    def __init__(self, *args, task="detect", json_file, **kwargs):
+        """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
+        assert task == "detect", "`GroundingDataset` only supports the `detect` task for now!"
+        self.json_file = json_file
+        super().__init__(*args, task=task, data={}, **kwargs)
+
+    def get_img_files(self, img_path):
+        """The image files are read in the `get_labels` function; return an empty list here."""
+        return []
+
+    def get_labels(self):
+        """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
+        labels = []
+        LOGGER.info("Loading annotation file...")
+        with open(self.json_file, "r") as f:
+            annotations = json.load(f)
+        images = {f'{x["id"]:d}': x for x in annotations["images"]}
+        imgToAnns = defaultdict(list)
+        for ann in annotations["annotations"]:
+            imgToAnns[ann["image_id"]].append(ann)
+        for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
+            img = images[f"{img_id:d}"]
+            h, w, f = img["height"], img["width"], img["file_name"]
+            im_file = Path(self.img_path) / f
+            if not im_file.exists():
+                continue
+            self.im_files.append(str(im_file))
+            bboxes = []
+            cat2id = {}
+            texts = []
+            for ann in anns:
+                if ann["iscrowd"]:
+                    continue
+                box = np.array(ann["bbox"], dtype=np.float32)
+                box[:2] += box[2:] / 2  # xywh top-left to center
+                box[[0, 2]] /= float(w)  # normalize x
+                box[[1, 3]] /= float(h)  # normalize y
+                if box[2] <= 0 or box[3] <= 0:
+                    continue
+
+                cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
+                if cat_name not in cat2id:
+                    cat2id[cat_name] = len(cat2id)
+                    texts.append([cat_name])
+                cls = cat2id[cat_name]  # class
+                box = [cls] + box.tolist()
+                if box not in bboxes:
+                    bboxes.append(box)
+            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
+            labels.append(
+                dict(
+                    im_file=im_file,
+                    shape=(h, w),
+                    cls=lb[:, 0:1],  # n, 1
+                    bboxes=lb[:, 1:],  # n, 4
+                    normalized=True,
+                    bbox_format="xywh",
+                    texts=texts,
+                )
+            )
+        return labels
+
+    def build_transforms(self, hyp=None):
+        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
+        transforms = super().build_transforms(hyp)
+        if self.augment:
+            # NOTE: hard-coded the args for now.
+            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+        return transforms
+
+
+class YOLOConcatDataset(ConcatDataset):
+    """
+    Dataset as a concatenation of multiple datasets.
+
+    This class is useful to assemble different existing datasets.
+    """
+
+    @staticmethod
+    def collate_fn(batch):
+        """Collates data samples into batches."""
+        return YOLODataset.collate_fn(batch)
+
+
 # TODO: support semantic segmentation
ultralytics/data/utils.py
@@ -29,6 +29,7 @@ from ultralytics.utils import (
     emojis,
     yaml_load,
     yaml_save,
+    is_dir_writeable,
 )
 from ultralytics.utils.checks import check_file, check_font, is_ascii
 from ultralytics.utils.downloads import download, safe_download, unzip_file
@@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True):
 
     # Set paths
     data["path"] = path  # download scripts
-    for k in "train", "val", "test":
+    for k in "train", "val", "test", "minival":
         if data.get(k):  # prepend path
             if isinstance(data[k], str):
                 x = (path / data[k]).resolve()
@@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
             with open(path.parent / txt[i], "a") as f:
                 f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
+
+
+def load_dataset_cache_file(path):
+    """Load an Ultralytics *.cache dictionary from path."""
+    import gc
+
+    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+    cache = np.load(str(path), allow_pickle=True).item()  # load dict
+    gc.enable()
+    return cache
+
+
+def save_dataset_cache_file(prefix, path, x, version):
+    """Save an Ultralytics dataset *.cache dictionary x to path."""
+    x["version"] = version  # add cache version
+    if is_dir_writeable(path.parent):
+        if path.exists():
+            path.unlink()  # remove *.cache file if exists
+        np.save(str(path), x)  # save cache for next time
+        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
+        LOGGER.info(f"{prefix}New cache created: {path}")
+    else:
+        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
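The cache helpers are moved here from `dataset.py` and now take an explicit `version`. A quick round-trip sketch (the path and payload are illustrative):

```python
from pathlib import Path
from ultralytics.data.utils import load_dataset_cache_file, save_dataset_cache_file

path = Path("labels.cache")  # illustrative cache location
save_dataset_cache_file("demo: ", path, {"hash": "abc123"}, version="1.0.3")
cache = load_dataset_cache_file(path)  # numpy pickle load with gc disabled
assert cache["version"] == "1.0.3"
```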
ultralytics/engine/trainer.py
@@ -126,22 +126,7 @@ class BaseTrainer:
 
         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        try:
-            if self.args.task == "classify":
-                self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
-                "detect",
-                "segment",
-                "pose",
-                "obb",
-            ):
-                self.data = check_det_dataset(self.args.data)
-                if "yaml_file" in self.data:
-                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
-        except Exception as e:
-            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
-
-        self.trainset, self.testset = self.get_dataset(self.data)
+        self.trainset, self.testset = self.get_dataset()
         self.ema = None
 
         # Optimization utils init
@@ -509,13 +494,27 @@ class BaseTrainer:
         if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
-    @staticmethod
-    def get_dataset(data):
+    def get_dataset(self):
         """
         Get train, val path from data dict if it exists.
 
         Returns None if data format is not recognized.
         """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            ):
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
         return data["train"], data.get("val") or data.get("test")
 
     def setup_model(self):
@@ -666,8 +665,8 @@ class BaseTrainer:
         if ckpt is None:
             return
         best_fitness = 0.0
-        start_epoch = ckpt["epoch"] + 1
-        if ckpt["optimizer"] is not None:
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
             best_fitness = ckpt["best_fitness"]
         if self.ema and ckpt.get("ema"):
ultralytics/models/fastsam/prompt.py
@@ -35,7 +35,7 @@ class FastSAMPrompt:
         except ImportError:
             from ultralytics.utils.checks import check_requirements
 
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
         self.clip = clip
 
ultralytics/models/yolo/__init__.py
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from ultralytics.models.yolo import classify, detect, obb, pose, segment
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
 
 from .model import YOLO, YOLOWorld
 
-__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
ultralytics/models/yolo/detect/val.py
@@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.is_coco = False
+        self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
         """Initialize evaluation metrics for YOLO."""
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
-        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
-        self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
+        self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training  # run on final val if training COCO
         self.names = model.names
         self.nc = len(model.names)
         self.metrics.names = self.names
@@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(p[5])]
+                    + (1 if self.is_lvis else 0),  # index starts from 1 if it's lvis
                     "bbox": [round(x, 3) for x in b],
                     "score": round(p[4], 5),
                 }
@@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
 
     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
+                for x in pred_json, anno_json:
                     assert x.is_file(), f"{x} file not found"
-                anno = COCO(str(anno_json))  # init annotations api
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = COCOeval(anno, pred, "bbox")
                 if self.is_coco:
-                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = LVISEval(anno, pred, "bbox")
+                eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                 eval.evaluate()
                 eval.accumulate()
                 eval.summarize()
-                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
+                if self.is_lvis:
+                    eval.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
+                )
             except Exception as e:
-                LOGGER.warning(f"pycocotools unable to run: {e}")
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
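For reference, the LVIS branch above can be reproduced standalone with the `lvis` package, mirroring the validator's code path (the annotation path is illustrative; `predictions.json` is what the validator saves):

```python
from lvis import LVIS, LVISEval

anno = LVIS("lvis_v1_minival.json")          # ground-truth annotations
pred = anno._load_json("predictions.json")   # list of prediction dicts
lvis_eval = LVISEval(anno, pred, "bbox")
lvis_eval.evaluate()
lvis_eval.accumulate()
lvis_eval.summarize()
lvis_eval.print_results()
print(lvis_eval.results["AP"], lvis_eval.results["AP50"])  # mAP50-95, mAP50
```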
ultralytics/models/yolo/model.py
@@ -83,6 +83,7 @@ class YOLOWorld(Model):
                 "model": WorldModel,
                 "validator": yolo.detect.DetectionValidator,
                 "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.world.WorldTrainer,
             }
         }
 
ultralytics/models/yolo/world/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from .train import WorldTrainer
+
+__all__ = ["WorldTrainer"]
ultralytics/models/yolo/world/train.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import itertools
+
+from ultralytics.data import build_yolo_dataset
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.torch_utils import de_parallel
+
+try:
+    import clip
+except ImportError:
+    check_requirements("git+https://github.com/ultralytics/CLIP.git")
+    import clip
+
+
+def on_pretrain_routine_end(trainer):
+    """Callback that sets evaluation class names and attaches a frozen CLIP text encoder to the trainer."""
+    if RANK in (-1, 0):
+        # NOTE: for evaluation
+        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+    device = next(trainer.model.parameters()).device
+    text_model, _ = clip.load("ViT-B/32", device=device)
+    for p in text_model.parameters():
+        p.requires_grad_(False)
+    trainer.text_model = text_model
+
+
+class WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a closed-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world import WorldTrainer
+
+        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
+        trainer = WorldTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return WorldModel initialized with specified config and weights."""
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, `nc` is hard-coded to 80 for now.
+        model = WorldModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+
+        return model
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode; users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images for YOLO-World training, adding normalized CLIP text features."""
+        batch = super().preprocess_batch(batch)
+
+        # NOTE: add text features
+        texts = list(itertools.chain(*batch["texts"]))
+        text_token = clip.tokenize(texts).to(batch["img"].device)
+        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
+        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+        return batch
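With the `trainer` entry added to `YOLOWorld.task_map` above, the usual high-level API now routes training to `WorldTrainer`; a minimal sketch (dataset and epoch count are illustrative):

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")     # load a pretrained YOLO-World model
model.train(data="coco8.yaml", epochs=3)  # fine-tune via yolo.world.WorldTrainer
```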
ultralytics/models/yolo/world/train_world.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.world import WorldTrainer
+from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils import DEFAULT_CFG
+
+
+class WorldTrainerFromScratch(WorldTrainer):
+    """
+    A class extending the WorldTrainer class for training a world model from scratch on open-set datasets.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+        from ultralytics import YOLOWorld
+
+        data = dict(
+            train=dict(
+                yolo_data=["Objects365.yaml"],
+                grounding_data=[
+                    dict(
+                        img_path="../datasets/flickr30k/images",
+                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+                    ),
+                    dict(
+                        img_path="../datasets/GQA/images",
+                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+                    ),
+                ],
+            ),
+            val=dict(yolo_data=["lvis.yaml"]),
+        )
+
+        model = YOLOWorld("yolov8s-worldv2.yaml")
+        model.train(data=data, trainer=WorldTrainerFromScratch)
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainerFromScratch object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode; users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        if mode == "train":
+            dataset = [
+                build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
+                if isinstance(im_path, str)
+                else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+                for im_path in img_path
+            ]
+            return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
+        else:
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataset(self):
+        """
+        Get train and val paths from the data dict if it exists.
+
+        Returns None if the data format is not recognized.
+        """
+        final_data = dict()
+        data_yaml = self.args.data
+        assert data_yaml.get("train", False)  # object365.yaml
+        assert data_yaml.get("val", False)  # lvis.yaml
+        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+        for d in data["val"]:
+            if d.get("minival") is None:  # for lvis dataset
+                continue
+            d["minival"] = str(d["path"] / d["minival"])
+        for s in ["train", "val"]:
+            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+            # save grounding data if there's one
+            grounding_data = data_yaml[s].get("grounding_data")
+            if grounding_data is None:
+                continue
+            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
+            for g in grounding_data:
+                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+            final_data[s] += grounding_data
+        # NOTE: to make training work properly, set `nc` and `names`
+        final_data["nc"] = data["val"][0]["nc"]
+        final_data["names"] = data["val"][0]["names"]
+        self.data = final_data
+        return final_data["train"], final_data["val"][0]
+
+    def plot_training_labels(self):
+        """Do not plot labels for open-set training."""
+        pass
+
+    def final_eval(self):
+        """Performs final evaluation and validation for the object detection YOLO-World model."""
+        val = self.args.data["val"]["yolo_data"][0]
+        self.validator.args.data = val
+        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+        return super().final_eval()
ultralytics/nn/modules/block.py
@@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
         super().__init__()
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistent with the other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
 
     def forward(self, x, w):
@@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
         """Initialize BNContrastiveHead with region-text similarity parameters."""
         super().__init__()
         self.norm = nn.BatchNorm2d(embed_dims)
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistent with the other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
        # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
 
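A quick sanity check on the new bias initialization (the value is from the diff; the arithmetic is just illustrative): with logits starting near the bias, initial class probabilities are close to zero, matching the low-confidence prior used by the other detection heads.

```python
import torch

bias = torch.tensor([-10.0])
print(torch.sigmoid(bias))  # tensor([4.5398e-05]): near-zero initial class confidence
```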
ultralytics/nn/modules/head.py
@@ -250,6 +250,15 @@ class WorldDetect(Detect):
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)
 
+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+
 
 class RTDETRDecoder(nn.Module):
     """
ultralytics/nn/tasks.py
@@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
         self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
-    def set_classes(self, text):
-        """Perform a forward pass with optional profiling, visualization, and embedding extraction."""
+    def set_classes(self, text, batch=80, cache_clip_model=True):
+        """Set classes in advance so that the model can do offline inference without a CLIP model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
 
-        if not getattr(self, "clip_model", None):  # for backwards compatibility of models lacking clip_model attribute
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
             self.clip_model = clip.load("ViT-B/32")[0]
-        device = next(self.clip_model.parameters()).device
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
+        device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)
 
     def init_criterion(self):
         """Initialize the loss criterion for the model."""
         raise NotImplementedError
 
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
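As a usage sketch of the updated `set_classes` (the weights file and vocabulary are illustrative): prompts are now encoded in batches of `batch` tokens, and `cache_clip_model=False` lets callers such as the `on_pretrain_routine_end` callback skip holding a CLIP copy on the model.

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")                  # pretrained open-vocabulary model
model.set_classes(["person", "bus", "fire hydrant"])   # CLIP-embeds the prompts once
results = model.predict("bus.jpg")                     # later inference reuses cached text features
```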
@@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor, optional): Precomputed text features, used instead of the cached ones if given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
             embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
@@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
             return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
ultralytics/utils/loss.py
@@ -157,7 +157,7 @@ class v8DetectionLoss:
         self.hyp = h
         self.stride = m.stride  # model strides
         self.nc = m.nc  # number of classes
-        self.no = m.no
+        self.no = m.nc + m.reg_max * 4
         self.reg_max = m.reg_max
         self.device = device
 
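The switch from `m.no` to `m.nc + m.reg_max * 4` matters for YOLO-World because `WorldDetect` sizes its outputs around the text-embedding width, so the head's `no` attribute is not guaranteed to equal the per-anchor channel count the loss expects; deriving it from `nc` and `reg_max` keeps the two in sync (this reading of the change is an inference, not stated in the diff). The arithmetic for the default detection head:

```python
nc, reg_max = 80, 16       # default YOLOv8 head: 80 classes, 16 DFL bins per box side
no = nc + reg_max * 4      # per-anchor channels: 80 cls logits + 64 box-distribution bins
print(no)                  # 144
```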