ultralytics 8.1.39 add YOLO-World training (#9268)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-03-31 22:30:17 +08:00 committed by GitHub
parent 18036908d4
commit e9187c1296
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 2166 additions and 100 deletions

View file

@ -74,6 +74,7 @@ Here is a list of the supported datasets and a brief description for each:
- [**Argoverse**](argoverse.md): A collection of sensor data collected from autonomous vehicles. It contains 3D tracking annotations for car objects. - [**Argoverse**](argoverse.md): A collection of sensor data collected from autonomous vehicles. It contains 3D tracking annotations for car objects.
- [**COCO**](coco.md): Common Objects in Context (COCO) is a large-scale object detection, segmentation, and captioning dataset with 80 object categories. - [**COCO**](coco.md): Common Objects in Context (COCO) is a large-scale object detection, segmentation, and captioning dataset with 80 object categories.
- [**LVIS**](lvis.md): LVIS is a large-scale object detection, segmentation, and captioning dataset with 1203 object categories.
- [**COCO8**](coco8.md): A smaller subset of the COCO dataset, COCO8 is more lightweight and faster to train. - [**COCO8**](coco8.md): A smaller subset of the COCO dataset, COCO8 is more lightweight and faster to train.
- [**GlobalWheat2020**](globalwheat2020.md): A dataset containing images of wheat heads for the Global Wheat Challenge 2020. - [**GlobalWheat2020**](globalwheat2020.md): A dataset containing images of wheat heads for the Global Wheat Challenge 2020.
- [**Objects365**](objects365.md): A large-scale object detection dataset with 365 object categories and 600k images, aimed at advancing object detection research. - [**Objects365**](objects365.md): A large-scale object detection dataset with 365 object categories and 600k images, aimed at advancing object detection research.

View file

@ -0,0 +1,96 @@
---
comments: true
description: Learn how LVIS, a leading dataset for object detection and segmentation, integrates with Ultralytics. Discover ways to use it for training YOLO models.
keywords: Ultralytics, LVIS dataset, object detection, YOLO, YOLO model training, image segmentation, computer vision, deep learning models
---
# LVIS Dataset
The [LVIS](https://www.lvisdataset.org/dataset) dataset is a large-scale, fine-grained vocabulary-level annotation dataset developed and released by Facebook AI Research (FAIR). It is primarily used as a research benchmark for object detection and instance segmentation with a large vocabulary of categories, aiming to drive further advancements in computer vision field.
## Key Features
- LVIS contains 160k images and 2M instance annotations for object detection, segmentation, and captioning tasks.
- The dataset comprises 1203 object categories, including common objects like cars, bicycles, and animals, as well as more specific categories such as umbrellas, handbags, and sports equipment.
- Annotations include object bounding boxes, segmentation masks, and captions for each image.
- LVIS provides standardized evaluation metrics like mean Average Precision (mAP) for object detection, and mean Average Recall (mAR) for segmentation tasks, making it suitable for comparing model performance.
- LVIS uses the exactly the same images as [COCO](./coco.md) dataset, but with different splits and different annotations.
## Dataset Structure
The LVIS dataset is split into three subsets:
1. **Train**: This subset contains 100k images for training object detection, segmentation, and captioning models.
2. **Val**: This subset has 20k images used for validation purposes during model training.
3. **Minival**: This subset is exactly the same as COCO val2017 set which has 5k images used for validation purposes during model training.
4. **Test**: This subset consists of 20k images used for testing and benchmarking the trained models. Ground truth annotations for this subset are not publicly available, and the results are submitted to the [LVIS evaluation server](https://eval.ai/web/challenges/challenge-page/675/overview) for performance evaluation.
## Applications
The LVIS dataset is widely used for training and evaluating deep learning models in object detection (such as YOLO, Faster R-CNN, and SSD), instance segmentation (such as Mask R-CNN). The dataset's diverse set of object categories, large number of annotated images, and standardized evaluation metrics make it an essential resource for computer vision researchers and practitioners.
## Dataset YAML
A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the LVIS dataset, the `lvis.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml).
!!! Example "ultralytics/cfg/datasets/lvis.yaml"
```yaml
--8<-- "ultralytics/cfg/datasets/lvis.yaml"
```
## Usage
To train a YOLOv8n model on the LVIS dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page.
!!! Example "Train Example"
=== "Python"
```python
from ultralytics import YOLO
# Load a model
model = YOLO('yolov8n.pt') # load a pretrained model (recommended for training)
# Train the model
results = model.train(data='lvis.yaml', epochs=100, imgsz=640)
```
=== "CLI"
```bash
# Start training from a pretrained *.pt model
yolo detect train data=lvis.yaml model=yolov8n.pt epochs=100 imgsz=640
```
## Sample Images and Annotations
The LVIS dataset contains a diverse set of images with various object categories and complex scenes. Here are some examples of images from the dataset, along with their corresponding annotations:
![Dataset sample image](https://private-user-images.githubusercontent.com/61612323/316485965-a88c2e62-58d0-4f67-bc69-1418e42175e9.jpg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTEzNjcyNjYsIm5iZiI6MTcxMTM2Njk2NiwicGF0aCI6Ii82MTYxMjMyMy8zMTY0ODU5NjUtYTg4YzJlNjItNThkMC00ZjY3LWJjNjktMTQxOGU0MjE3NWU5LmpwZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDAzMjUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwMzI1VDExNDI0NlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWZmMTVlNzE5MTBkOTZmNDQwNzJjNWQzYzM2NmEyMGMxODQ4ZDEyMjYwYmMyY2JjZDU5YzBmMDIyZGEwMGEwZDAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.7thukPdnJKYuBmTk1ROUyqxxV3Ix5GeNLqyi4wSDYvA)
- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts.
The example showcases the variety and complexity of the images in the LVIS dataset and the benefits of using mosaicing during the training process.
## Citations and Acknowledgments
If you use the LVIS dataset in your research or development work, please cite the following paper:
!!! Quote ""
=== "BibTeX"
```bibtex
@inproceedings{gupta2019lvis,
title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation},
author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross},
booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```
We would like to acknowledge the LVIS Consortium for creating and maintaining this valuable resource for the computer vision community. For more information about the LVIS dataset and its creators, visit the [LVIS dataset website](https://www.lvisdataset.org/dataset).

View file

@ -36,6 +36,7 @@ Bounding box object detection is a computer vision technique that involves detec
- [Argoverse](detect/argoverse.md): A dataset containing 3D tracking and motion forecasting data from urban environments with rich annotations. - [Argoverse](detect/argoverse.md): A dataset containing 3D tracking and motion forecasting data from urban environments with rich annotations.
- [COCO](detect/coco.md): A large-scale dataset designed for object detection, segmentation, and captioning with over 200K labeled images. - [COCO](detect/coco.md): A large-scale dataset designed for object detection, segmentation, and captioning with over 200K labeled images.
- [LVIS](lvis.md): A large-scale object detection, segmentation, and captioning dataset with 1203 object categories.
- [COCO8](detect/coco8.md): Contains the first 4 images from COCO train and COCO val, suitable for quick tests. - [COCO8](detect/coco8.md): Contains the first 4 images from COCO train and COCO val, suitable for quick tests.
- [Global Wheat 2020](detect/globalwheat2020.md): A dataset of wheat head images collected from around the world for object detection and localization tasks. - [Global Wheat 2020](detect/globalwheat2020.md): A dataset of wheat head images collected from around the world for object detection and localization tasks.
- [Objects365](detect/objects365.md): A high-quality, large-scale dataset for object detection with 365 object categories and over 600K annotated images. - [Objects365](detect/objects365.md): A high-quality, large-scale dataset for object detection with 365 object categories and over 600K annotated images.

View file

@ -147,7 +147,7 @@ FastSAM is also available directly from the [https://github.com/CASIA-IVA-Lab/Fa
4. Install the CLIP model: 4. Install the CLIP model:
```shell ```shell
pip install git+https://github.com/openai/CLIP.git pip install git+https://github.com/ultralytics/CLIP.git
``` ```
### Example Usage ### Example Usage

View file

@ -64,6 +64,39 @@ This section details the models available with their specific pre-trained weight
The YOLO-World models are easy to integrate into your Python applications. Ultralytics provides user-friendly Python API and CLI commands to streamline development. The YOLO-World models are easy to integrate into your Python applications. Ultralytics provides user-friendly Python API and CLI commands to streamline development.
### Train Usage
!!! Tip "Tip"
We strongly recommend to use `yolov8-worldv2` model for custom training, because it supports deterministic training and also easy to export other formats i.e onnx/tensorrt.
Object detection is straightforward with the `train` method, as illustrated below:
!!! Example
=== "Python"
PyTorch pretrained `*.pt` models as well as configuration `*.yaml` files can be passed to the `YOLOWorld()` class to create a model instance in python:
```python
from ultralytics import YOLOWorld
# Load a pretrained YOLOv8s-worldv2 model
model = YOLOWorld('yolov8s-worldv2.pt')
# Train the model on the COCO8 example dataset for 100 epochs
results = model.train(data='coco8.yaml', epochs=100, imgsz=640)
# Run inference with the YOLOv8n model on the 'bus.jpg' image
results = model('path/to/bus.jpg')
```
=== "CLI"
```bash
# Load a pretrained YOLOv8s-worldv2 model and train it on the COCO8 example dataset for 100 epochs
yolo train model=yolov8s-worldv2.yaml data=coco8.yaml epochs=100 imgsz=640
```
### Predict Usage ### Predict Usage
Object detection is straightforward with the `predict` method, as illustrated below: Object detection is straightforward with the `predict` method, as illustrated below:
@ -196,6 +229,59 @@ You can also save a model after setting custom classes. By doing this you create
This approach provides a powerful means of customizing state-of-the-art object detection models for specific tasks, making advanced AI more accessible and applicable to a broader range of practical applications. This approach provides a powerful means of customizing state-of-the-art object detection models for specific tasks, making advanced AI more accessible and applicable to a broader range of practical applications.
## Reproduce official results from scratch(Experimental)
### Prepare datasets
- Train data
| Dataset | Type | Samples | Boxes | Annotation Files |
|-------------------------------------------------------------------|-----------|---------|-------|--------------------------------------------------------------------------------------------------------------------------------------------|
| [Objects365v1](https://opendatalab.com/OpenDataLab/Objects365_v1) | Detection | 609k | 9621k | [objects365_train.json](https://opendatalab.com/OpenDataLab/Objects365_v1) |
| [GQA](https://nlp.stanford.edu/data/gqa/images.zip) | Grounding | 621k | 3681k | [final_mixed_train_no_coco.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_mixed_train_no_coco.json) |
| [Flickr30k](https://shannon.cs.illinois.edu/DenotationGraph/) | Grounding | 149k | 641k | [final_flickr_separateGT_train.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_flickr_separateGT_train.json) |
- Val data
| Dataset | Type | Annotation Files |
|---------------------------------------------------------------------------------------------------------|-----------|--------------------------------------------------------------------------------------------------------|
| [LVIS minival](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) | Detection | [minival.txt](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) |
### Launch training from scratch
!!! Note
`WorldTrainerFromScratch` is highly customized to allow training yolo-world models on both detection datasets and grounding datasets simultaneously. More details please checkout [ultralytics.model.yolo.world.train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py).
!!! Example
=== "Python"
```python
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
from ultralytics import YOLOWorld
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
grounding_data=[
dict(
img_path="../datasets/flickr30k/images",
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
),
dict(
img_path="../datasets/GQA/images",
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
),
],
),
val=dict(yolo_data=["lvis.yaml"]),
)
model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch)
```
## Citations and Acknowledgements ## Citations and Acknowledgements
We extend our gratitude to the [Tencent AILab Computer Vision Center](https://ai.tencent.com/) for their pioneering work in real-time open-vocabulary object detection with YOLO-World: We extend our gratitude to the [Tencent AILab Computer Vision Center](https://ai.tencent.com/) for their pioneering work in real-time open-vocabulary object detection with YOLO-World:

View file

@ -59,6 +59,10 @@ keywords: Ultralytics, Data Augmentation, BaseTransform, MixUp, RandomHSV, Lette
<br><br> <br><br>
## ::: ultralytics.data.augment.RandomLoadText
<br><br>
## ::: ultralytics.data.augment.ClassifyLetterBox ## ::: ultralytics.data.augment.ClassifyLetterBox
<br><br> <br><br>

View file

@ -27,6 +27,10 @@ keywords: Ultralytics, YOLO v3, Data build, DataLoader, InfiniteDataLoader, seed
<br><br> <br><br>
## ::: ultralytics.data.build.build_grounding
<br><br>
## ::: ultralytics.data.build.build_dataloader ## ::: ultralytics.data.build.build_dataloader
<br><br> <br><br>

View file

@ -19,14 +19,18 @@ keywords: Ultralytics, YOLO, YOLODataset, SemanticDataset, data handling, data m
<br><br> <br><br>
## ::: ultralytics.data.dataset.YOLOMultiModalDataset
<br><br>
## ::: ultralytics.data.dataset.GroundingDataset
<br><br>
## ::: ultralytics.data.dataset.YOLOConcatDataset
<br><br>
## ::: ultralytics.data.dataset.SemanticDataset ## ::: ultralytics.data.dataset.SemanticDataset
<br><br> <br><br>
## ::: ultralytics.data.dataset.load_dataset_cache_file
<br><br>
## ::: ultralytics.data.dataset.save_dataset_cache_file
<br><br>

View file

@ -66,3 +66,11 @@ keywords: Ultralytics, data utils, YOLO, img2label_paths, exif_size, polygon2mas
## ::: ultralytics.data.utils.autosplit ## ::: ultralytics.data.utils.autosplit
<br><br> <br><br>
## ::: ultralytics.data.utils.load_dataset_cache_file
<br><br>
## ::: ultralytics.data.utils.save_dataset_cache_file
<br><br>

View file

@ -0,0 +1,15 @@
# Reference for `ultralytics/models/yolo/world/train.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train.py) 🛠️. Thank you 🙏!
<br><br>
## ::: ultralytics.models.yolo.world.train.WorldTrainer
<br><br>
## ::: ultralytics.models.yolo.world.train.on_pretrain_routine_end
<br><br>

View file

@ -0,0 +1,11 @@
# Reference for `ultralytics/models/yolo/world/train_world.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train_world.py) 🛠️. Thank you 🙏!
<br><br>
## ::: ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch
<br><br>

View file

@ -18,6 +18,7 @@ chr043416@gmail.com: RizwanMunawar
glenn.jocher@ultralytics.com: glenn-jocher glenn.jocher@ultralytics.com: glenn-jocher
muhammadrizwanmunawar123@gmail.com: RizwanMunawar muhammadrizwanmunawar123@gmail.com: RizwanMunawar
not.committed.yet: null not.committed.yet: null
plashchynski@gmail.com: plashchynski
priytosh.revolution@live.com: priytosh-tripathi priytosh.revolution@live.com: priytosh-tripathi
shuizhuyuanluo@126.com: null shuizhuyuanluo@126.com: null
xinwang614@gmail.com: GreatV xinwang614@gmail.com: GreatV

View file

@ -240,6 +240,7 @@ nav:
- datasets/detect/index.md - datasets/detect/index.md
- Argoverse: datasets/detect/argoverse.md - Argoverse: datasets/detect/argoverse.md
- COCO: datasets/detect/coco.md - COCO: datasets/detect/coco.md
- LVIS: datasets/detect/lvis.md
- COCO8: datasets/detect/coco8.md - COCO8: datasets/detect/coco8.md
- GlobalWheat2020: datasets/detect/globalwheat2020.md - GlobalWheat2020: datasets/detect/globalwheat2020.md
- Objects365: datasets/detect/objects365.md - Objects365: datasets/detect/objects365.md
@ -492,6 +493,9 @@ nav:
- predict: reference/models/yolo/segment/predict.md - predict: reference/models/yolo/segment/predict.md
- train: reference/models/yolo/segment/train.md - train: reference/models/yolo/segment/train.md
- val: reference/models/yolo/segment/val.md - val: reference/models/yolo/segment/val.md
- world:
- train: reference/models/yolo/world/train.md
- train_world: reference/models/yolo/world/train_world.md
- nn: - nn:
- autobackend: reference/nn/autobackend.md - autobackend: reference/nn/autobackend.md
- modules: - modules:

View file

@ -643,3 +643,29 @@ def test_yolo_world():
model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet
model.set_classes(["tree", "window"]) model.set_classes(["tree", "window"])
model(ASSETS / "bus.jpg", conf=0.01) model(ASSETS / "bus.jpg", conf=0.01)
# Training from yaml
model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet
model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world")
model = YOLO("yolov8s-worldv2.pt") # no YOLOv8n-world model yet
# val
model.val(data="coco8.yaml", imgsz=32, save_txt=True, save_json=True)
# Training from pretrain
model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world")
# test WorWorldTrainerFromScratch
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet
data = dict(train=dict(yolo_data=["coco8.yaml"]), val=dict(yolo_data=["coco8.yaml"]))
model.train(
data=data,
epochs=2,
imgsz=32,
cache="disk",
batch=-1,
close_mosaic=1,
name="yolo-world",
trainer=WorldTrainerFromScratch,
)

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.38" __version__ = "8.1.39"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

File diff suppressed because it is too large Load diff

View file

@ -1,15 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from .base import BaseDataset from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source from .build import (
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset build_dataloader,
build_yolo_dataset,
build_grounding,
load_inference_source,
)
from .dataset import (
ClassificationDataset,
SemanticDataset,
YOLODataset,
YOLOMultiModalDataset,
GroundingDataset,
YOLOConcatDataset,
)
__all__ = ( __all__ = (
"BaseDataset", "BaseDataset",
"ClassificationDataset", "ClassificationDataset",
"SemanticDataset", "SemanticDataset",
"YOLODataset", "YOLODataset",
"YOLOMultiModalDataset",
"YOLOConcatDataset",
"GroundingDataset",
"build_yolo_dataset", "build_yolo_dataset",
"build_grounding",
"build_dataloader", "build_dataloader",
"load_inference_source", "load_inference_source",
) )

View file

@ -3,6 +3,7 @@
import math import math
import random import random
from copy import deepcopy from copy import deepcopy
from typing import Tuple, Union
import cv2 import cv2
import numpy as np import numpy as np
@ -66,7 +67,7 @@ class Compose:
def __init__(self, transforms): def __init__(self, transforms):
"""Initializes the Compose object with a list of transforms.""" """Initializes the Compose object with a list of transforms."""
self.transforms = transforms self.transforms = transforms if isinstance(transforms, list) else [transforms]
def __call__(self, data): def __call__(self, data):
"""Applies a series of transformations to input data.""" """Applies a series of transformations to input data."""
@ -78,6 +79,29 @@ class Compose:
"""Appends a new transform to the existing list of transforms.""" """Appends a new transform to the existing list of transforms."""
self.transforms.append(transform) self.transforms.append(transform)
def insert(self, index, transform):
"""Inserts a new transform to the existing list of transforms."""
self.transforms.insert(index, transform)
def __getitem__(self, index: Union[list, int]) -> "Compose":
"""Retrieve a specific transform or a set of transforms using indexing."""
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
index = [index] if isinstance(index, int) else index
return Compose([self.transforms[i] for i in index])
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
"""Retrieve a specific transform or a set of transforms using indexing."""
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
if isinstance(index, list):
assert isinstance(
value, list
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
if isinstance(index, int):
index, value = [index], [value]
for i, v in zip(index, value):
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
self.transforms[i] = v
def tolist(self): def tolist(self):
"""Converts the list of transforms to a standard Python list.""" """Converts the list of transforms to a standard Python list."""
return self.transforms return self.transforms
@ -118,6 +142,8 @@ class BaseMixTransform:
mix_labels[i] = self.pre_transform(data) mix_labels[i] = self.pre_transform(data)
labels["mix_labels"] = mix_labels labels["mix_labels"] = mix_labels
# Update cls and texts
labels = self._update_label_text(labels)
# Mosaic or MixUp # Mosaic or MixUp
labels = self._mix_transform(labels) labels = self._mix_transform(labels)
labels.pop("mix_labels", None) labels.pop("mix_labels", None)
@ -131,6 +157,22 @@ class BaseMixTransform:
"""Gets a list of shuffled indexes for mosaic augmentation.""" """Gets a list of shuffled indexes for mosaic augmentation."""
raise NotImplementedError raise NotImplementedError
def _update_label_text(self, labels):
"""Update label text."""
if "texts" not in labels:
return labels
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
mix_texts = list({tuple(x) for x in mix_texts})
text2id = {text: i for i, text in enumerate(mix_texts)}
for label in [labels] + labels["mix_labels"]:
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(l)]
label["cls"][i] = text2id[tuple(text)]
label["texts"] = mix_texts
return labels
class Mosaic(BaseMixTransform): class Mosaic(BaseMixTransform):
""" """
@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
final_labels["instances"].clip(imgsz, imgsz) final_labels["instances"].clip(imgsz, imgsz)
good = final_labels["instances"].remove_zero_area_boxes() good = final_labels["instances"].remove_zero_area_boxes()
final_labels["cls"] = final_labels["cls"][good] final_labels["cls"] = final_labels["cls"][good]
if "texts" in mosaic_labels[0]:
final_labels["texts"] = mosaic_labels[0]["texts"]
return final_labels return final_labels
@ -970,6 +1014,83 @@ class Format:
return masks, instances, cls return masks, instances, cls
class RandomLoadText:
"""
Randomly sample positive texts and negative texts and update the class indices accordingly to the number of samples.
Attributes:
prompt_format (str): Format for prompt. Default is '{}'.
neg_samples (tuple[int]): A ranger to randomly sample negative texts, Default is (80, 80).
max_samples (int): The max number of different text samples in one image, Default is 80.
padding (bool): Whether to pad texts to max_samples. Default is False.
padding_value (str): The padding text. Default is "".
"""
def __init__(
self,
prompt_format: str = "{}",
neg_samples: Tuple[int, int] = (80, 80),
max_samples: int = 80,
padding: bool = False,
padding_value: str = "",
) -> None:
"""Initializes the RandomLoadText class with given parameters."""
self.prompt_format = prompt_format
self.neg_samples = neg_samples
self.max_samples = max_samples
self.padding = padding
self.padding_value = padding_value
def __call__(self, labels: dict) -> dict:
"""Return updated classes and texts."""
assert "texts" in labels, "No texts found in labels."
class_texts = labels["texts"]
num_classes = len(class_texts)
cls = np.asarray(labels.pop("cls"), dtype=int)
pos_labels = np.unique(cls).tolist()
if len(pos_labels) > self.max_samples:
pos_labels = set(random.sample(pos_labels, k=self.max_samples))
neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
neg_labels = []
for i in range(num_classes):
if i not in pos_labels:
neg_labels.append(i)
neg_labels = random.sample(neg_labels, k=neg_samples)
sampled_labels = pos_labels + neg_labels
random.shuffle(sampled_labels)
label2ids = {label: i for i, label in enumerate(sampled_labels)}
valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
new_cls = []
for i, label in enumerate(cls.squeeze(-1).tolist()):
if label not in label2ids:
continue
valid_idx[i] = True
new_cls.append([label2ids[label]])
labels["instances"] = labels["instances"][valid_idx]
labels["cls"] = np.array(new_cls)
# Randomly select one prompt when there's more than one prompts
texts = []
for label in sampled_labels:
prompts = class_texts[label]
assert len(prompts) > 0
prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
texts.append(prompt)
if self.padding:
valid_labels = len(pos_labels) + len(neg_labels)
num_padding = self.max_samples - valid_labels
if num_padding > 0:
texts += [self.padding_value] * num_padding
labels["texts"] = texts
return labels
def v8_transforms(dataset, imgsz, hyp, stretch=False): def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training.""" """Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose( pre_transform = Compose(

View file

@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file from ultralytics.utils.checks import check_file
from .dataset import YOLODataset from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
from .utils import PIN_MEMORY from .utils import PIN_MEMORY
@ -82,9 +82,10 @@ def seed_worker(worker_id): # noqa
random.seed(worker_seed) random.seed(worker_seed)
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32): def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
"""Build YOLO Dataset.""" """Build YOLO Dataset."""
return YOLODataset( dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
return dataset(
img_path=img_path, img_path=img_path,
imgsz=cfg.imgsz, imgsz=cfg.imgsz,
batch_size=batch, batch_size=batch,
@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
) )
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
"""Build YOLO Dataset."""
return GroundingDataset(
img_path=img_path,
json_file=json_file,
imgsz=cfg.imgsz,
batch_size=batch,
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect or rect, # rectangular batches
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
task=cfg.task,
classes=cfg.classes,
fraction=cfg.fraction if mode == "train" else 1.0,
)
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
"""Return an InfiniteDataLoader or DataLoader for training or validation set.""" """Return an InfiniteDataLoader or DataLoader for training or validation set."""
batch = min(batch, len(dataset)) batch = min(batch, len(dataset))

View file

@ -219,6 +219,7 @@ def convert_coco(
use_segments=False, use_segments=False,
use_keypoints=False, use_keypoints=False,
cls91to80=True, cls91to80=True,
lvis=False,
): ):
""" """
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models. Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@ -229,12 +230,14 @@ def convert_coco(
use_segments (bool, optional): Whether to include segmentation masks in the output. use_segments (bool, optional): Whether to include segmentation masks in the output.
use_keypoints (bool, optional): Whether to include keypoint annotations in the output. use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs. cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
lvis (bool, optional): Whether to convert data in lvis dataset way.
Example: Example:
```python ```python
from ultralytics.data.converter import convert_coco from ultralytics.data.converter import convert_coco
convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True) convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
``` ```
Output: Output:
@ -251,8 +254,14 @@ def convert_coco(
# Import json # Import json
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")): for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name lname = "" if lvis else json_file.stem.replace("instances_", "")
fn = Path(save_dir) / "labels" / lname # folder name
fn.mkdir(parents=True, exist_ok=True) fn.mkdir(parents=True, exist_ok=True)
if lvis:
# NOTE: create folders for both train and val in advance,
# since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
(fn / "train2017").mkdir(parents=True, exist_ok=True)
(fn / "val2017").mkdir(parents=True, exist_ok=True)
with open(json_file) as f: with open(json_file) as f:
data = json.load(f) data = json.load(f)
@ -263,16 +272,20 @@ def convert_coco(
for ann in data["annotations"]: for ann in data["annotations"]:
imgToAnns[ann["image_id"]].append(ann) imgToAnns[ann["image_id"]].append(ann)
image_txt = []
# Write labels file # Write labels file
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"): for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
img = images[f"{img_id:d}"] img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"] h, w = img["height"], img["width"]
f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
if lvis:
image_txt.append(str(Path("./images") / f))
bboxes = [] bboxes = []
segments = [] segments = []
keypoints = [] keypoints = []
for ann in anns: for ann in anns:
if ann["iscrowd"]: if ann.get("iscrowd", False):
continue continue
# The COCO box format is [top left x, top left y, width, height] # The COCO box format is [top left x, top left y, width, height]
box = np.array(ann["bbox"], dtype=np.float64) box = np.array(ann["bbox"], dtype=np.float64)
@ -314,7 +327,12 @@ def convert_coco(
) # cls, box or segments ) # cls, box or segments
file.write(("%g " * len(line)).rstrip() % line + "\n") file.write(("%g " * len(line)).rstrip() % line + "\n")
LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}") if lvis:
with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
for l in image_txt:
f.write(f"{l}\n")
LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
def convert_dota_to_yolo_obb(dota_root_path: str): def convert_dota_to_yolo_obb(dota_root_path: str):

View file

@ -1,20 +1,41 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib import contextlib
from itertools import repeat from itertools import repeat
from collections import defaultdict
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from pathlib import Path from pathlib import Path
import cv2 import cv2
import json
import numpy as np import numpy as np
import torch import torch
import torchvision import torchvision
from PIL import Image from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
verify_image,
verify_image_label,
load_dataset_cache_file,
save_dataset_cache_file,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3" DATASET_CACHE_VERSION = "1.0.3"
@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
x["hash"] = get_hash(self.label_files + self.im_files) x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files) x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x) save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return x return x
def get_labels(self): def get_labels(self):
@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
x["hash"] = get_hash([x[0] for x in self.samples]) x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x) save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples return samples
def load_dataset_cache_file(path): class YOLOMultiModalDataset(YOLODataset):
"""Load an Ultralytics *.cache dictionary from path.""" """
import gc Dataset class for loading object detection and/or segmentation labels in YOLO format.
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 Args:
cache = np.load(str(path), allow_pickle=True).item() # load dict data (dict, optional): A dataset YAML dictionary. Defaults to None.
gc.enable() task (str): An explicit arg to point current task, Defaults to 'detect'.
return cache
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes a dataset object for object detection tasks with optional specifications."""
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label):
"""Add texts information for multi modal model training."""
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
return labels
def build_transforms(self, hyp=None):
"""Enhances data transformations with optional text augmentation for multi-modal training."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
return transforms
def save_dataset_cache_file(prefix, path, x): class GroundingDataset(YOLODataset):
"""Save an Ultralytics dataset *.cache dictionary x to path.""" def __init__(self, *args, task="detect", json_file, **kwargs):
x["version"] = DATASET_CACHE_VERSION # add cache version """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
if is_dir_writeable(path.parent): assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
if path.exists(): self.json_file = json_file
path.unlink() # remove *.cache file if exists super().__init__(*args, task=task, data={}, **kwargs)
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix def get_img_files(self, img_path):
LOGGER.info(f"{prefix}New cache created: {path}") """The image files would be read in `get_labels` function, return empty list here."""
else: return []
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
def get_labels(self):
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
labels = []
LOGGER.info("Loading annotation file...")
with open(self.json_file, "r") as f:
annotations = json.load(f)
images = {f'{x["id"]:d}': x for x in annotations["images"]}
imgToAnns = defaultdict(list)
for ann in annotations["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
if not im_file.exists():
continue
self.im_files.append(str(im_file))
bboxes = []
cat2id = {}
texts = []
for ann in anns:
if ann["iscrowd"]:
continue
box = np.array(ann["bbox"], dtype=np.float32)
box[:2] += box[2:] / 2
box[[0, 2]] /= float(w)
box[[1, 3]] /= float(h)
if box[2] <= 0 or box[3] <= 0:
continue
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
cls = cat2id[cat_name] # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
labels.append(
dict(
im_file=im_file,
shape=(h, w),
cls=lb[:, 0:1], # n, 1
bboxes=lb[:, 1:], # n, 4
normalized=True,
bbox_format="xywh",
texts=texts,
)
)
return labels
def build_transforms(self, hyp=None):
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
return transforms
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
"""
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
return YOLODataset.collate_fn(batch)
# TODO: support semantic segmentation # TODO: support semantic segmentation

View file

@ -29,6 +29,7 @@ from ultralytics.utils import (
emojis, emojis,
yaml_load, yaml_load,
yaml_save, yaml_save,
is_dir_writeable,
) )
from ultralytics.utils.checks import check_file, check_font, is_ascii from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file from ultralytics.utils.downloads import download, safe_download, unzip_file
@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True):
# Set paths # Set paths
data["path"] = path # download scripts data["path"] = path # download scripts
for k in "train", "val", "test": for k in "train", "val", "test", "minival":
if data.get(k): # prepend path if data.get(k): # prepend path
if isinstance(data[k], str): if isinstance(data[k], str):
x = (path / data[k]).resolve() x = (path / data[k]).resolve()
@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path.parent / txt[i], "a") as f: with open(path.parent / txt[i], "a") as f:
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
return cache
def save_dataset_cache_file(prefix, path, x, version):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = version # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")

View file

@ -126,22 +126,7 @@ class BaseTrainer:
# Model and Dataset # Model and Dataset
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try: self.trainset, self.testset = self.get_dataset()
if self.args.task == "classify":
self.data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
"detect",
"segment",
"pose",
"obb",
):
self.data = check_det_dataset(self.args.data)
if "yaml_file" in self.data:
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.trainset, self.testset = self.get_dataset(self.data)
self.ema = None self.ema = None
# Optimization utils init # Optimization utils init
@ -509,13 +494,27 @@ class BaseTrainer:
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
@staticmethod def get_dataset(self):
def get_dataset(data):
""" """
Get train, val path from data dict if it exists. Get train, val path from data dict if it exists.
Returns None if data format is not recognized. Returns None if data format is not recognized.
""" """
try:
if self.args.task == "classify":
data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
"detect",
"segment",
"pose",
"obb",
):
data = check_det_dataset(self.args.data)
if "yaml_file" in data:
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.data = data
return data["train"], data.get("val") or data.get("test") return data["train"], data.get("val") or data.get("test")
def setup_model(self): def setup_model(self):
@ -666,8 +665,8 @@ class BaseTrainer:
if ckpt is None: if ckpt is None:
return return
best_fitness = 0.0 best_fitness = 0.0
start_epoch = ckpt["epoch"] + 1 start_epoch = ckpt.get("epoch", -1) + 1
if ckpt["optimizer"] is not None: if ckpt.get("optimizer", None) is not None:
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
best_fitness = ckpt["best_fitness"] best_fitness = ckpt["best_fitness"]
if self.ema and ckpt.get("ema"): if self.ema and ckpt.get("ema"):

View file

@ -35,7 +35,7 @@ class FastSAMPrompt:
except ImportError: except ImportError:
from ultralytics.utils.checks import check_requirements from ultralytics.utils.checks import check_requirements
check_requirements("git+https://github.com/openai/CLIP.git") check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip import clip
self.clip = clip self.clip = clip

View file

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo import classify, detect, obb, pose, segment from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
from .model import YOLO, YOLOWorld from .model import YOLO, YOLOWorld
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld" __all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"

View file

@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.nt_per_class = None self.nt_per_class = None
self.is_coco = False self.is_coco = False
self.is_lvis = False
self.class_map = None self.class_map = None
self.args.task = "detect" self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
"""Initialize evaluation metrics for YOLO.""" """Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, "") # validation path val = self.data.get(self.args.split, "") # validation path
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training # run on final val if training COCO
self.names = model.names self.names = model.names
self.nc = len(model.names) self.nc = len(model.names)
self.metrics.names = self.names self.metrics.names = self.names
@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
self.jdict.append( self.jdict.append(
{ {
"image_id": image_id, "image_id": image_id,
"category_id": self.class_map[int(p[5])], "category_id": self.class_map[int(p[5])]
+ (1 if self.is_lvis else 0), # index starts from 1 if it's lvis
"bbox": [round(x, 3) for x in b], "bbox": [round(x, 3) for x in b],
"score": round(p[4], 5), "score": round(p[4], 5),
} }
@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
def eval_json(self, stats): def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics.""" """Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
pkg = "pycocotools" if self.is_coco else "lvis"
LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements("pycocotools>=2.0.6") for x in pred_json, anno_json:
assert x.is_file(), f"{x} file not found"
check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
if self.is_coco:
from pycocotools.coco import COCO # noqa from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, "bbox") eval = COCOeval(anno, pred, "bbox")
if self.is_coco: else:
from lvis import LVIS, LVISEval
anno = LVIS(str(anno_json)) # init annotations api
pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
eval = LVISEval(anno, pred, "bbox")
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate() eval.evaluate()
eval.accumulate() eval.accumulate()
eval.summarize() eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50 if self.is_lvis:
eval.print_results() # explicitly call print_results
# update mAP50-95 and mAP50
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
)
except Exception as e: except Exception as e:
LOGGER.warning(f"pycocotools unable to run: {e}") LOGGER.warning(f"{pkg} unable to run: {e}")
return stats return stats

View file

@ -83,6 +83,7 @@ class YOLOWorld(Model):
"model": WorldModel, "model": WorldModel,
"validator": yolo.detect.DetectionValidator, "validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor, "predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.world.WorldTrainer,
} }
} }

View file

@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .train import WorldTrainer
__all__ = ["WorldTrainer"]

View file

@ -0,0 +1,91 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models import yolo
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.data import build_yolo_dataset
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils.checks import check_requirements
import itertools
try:
import clip
except ImportError:
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
def on_pretrain_routine_end(trainer):
"""Callback."""
if RANK in (-1, 0):
# NOTE: for evaluation
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
device = next(trainer.model.parameters()).device
text_model, _ = clip.load("ViT-B/32", device=device)
for p in text_model.parameters():
p.requires_grad_(False)
trainer.text_model = text_model
class WorldTrainer(yolo.detect.DetectionTrainer):
"""
A class to fine-tune a world model on a close-set dataset.
Example:
```python
from ultralytics.models.yolo.world import WorldModel
args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a WorldTrainer object with given arguments."""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return WorldModel initialized with specified config and weights."""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = WorldModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=3,
nc=min(self.data["nc"], 80),
verbose=verbose and RANK == -1,
)
if weights:
model.load(weights)
self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
return model
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
)
def preprocess_batch(self, batch):
"""Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
batch = super().preprocess_batch(batch)
# NOTE: add text features
texts = list(itertools.chain(*batch["texts"]))
text_token = clip.tokenize(texts).to(batch["img"].device)
txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
return batch

View file

@ -0,0 +1,108 @@
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils import DEFAULT_CFG
class WorldTrainerFromScratch(WorldTrainer):
"""
A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
Example:
```python
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
from ultralytics import YOLOWorld
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
grounding_data=[
dict(
img_path="../datasets/flickr30k/images",
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
),
dict(
img_path="../datasets/GQA/images",
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
),
],
),
val=dict(yolo_data=["lvis.yaml"]),
)
model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, trainer=WorldTrainerFromScratch)
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a WorldTrainer object with given arguments."""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (List[str] | str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
if mode == "train":
dataset = [
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
if isinstance(im_path, str)
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
for im_path in img_path
]
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
else:
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataset(self):
"""
Get train, val path from data dict if it exists.
Returns None if data format is not recognized.
"""
final_data = dict()
data_yaml = self.args.data
assert data_yaml.get("train", False) # object365.yaml
assert data_yaml.get("val", False) # lvis.yaml
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
for d in data["val"]:
if d.get("minival") is None: # for lvis dataset
continue
d["minival"] = str(d["path"] / d["minival"])
for s in ["train", "val"]:
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
# save grounding data if there's one
grounding_data = data_yaml[s].get("grounding_data")
if grounding_data is None:
continue
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
for g in grounding_data:
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
final_data[s] += grounding_data
# NOTE: to make training work properly, set `nc` and `names`
final_data["nc"] = data["val"][0]["nc"]
final_data["names"] = data["val"][0]["names"]
self.data = final_data
return final_data["train"], final_data["val"][0]
def plot_training_labels(self):
"""DO NOT plot labels."""
pass
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO-World model."""
val = self.args.data["val"]["yolo_data"][0]
self.validator.args.data = val
self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
return super().final_eval()

View file

@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
def __init__(self): def __init__(self):
"""Initializes ContrastiveHead with specified region-text similarity parameters.""" """Initializes ContrastiveHead with specified region-text similarity parameters."""
super().__init__() super().__init__()
self.bias = nn.Parameter(torch.zeros([])) # NOTE: use -10.0 to keep the init cls loss consistency with other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log()) self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
def forward(self, x, w): def forward(self, x, w):
@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
"""Initialize ContrastiveHead with region-text similarity parameters.""" """Initialize ContrastiveHead with region-text similarity parameters."""
super().__init__() super().__init__()
self.norm = nn.BatchNorm2d(embed_dims) self.norm = nn.BatchNorm2d(embed_dims)
self.bias = nn.Parameter(torch.zeros([])) # NOTE: use -10.0 to keep the init cls loss consistency with other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
# use -1.0 is more stable # use -1.0 is more stable
self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

View file

@ -250,6 +250,15 @@ class WorldDetect(Detect):
y = torch.cat((dbox, cls.sigmoid()), 1) y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x) return y if self.export else (y, x)
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
m = self # self.model[-1] # Detect() module
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
# b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
class RTDETRDecoder(nn.Module): class RTDETRDecoder(nn.Module):
""" """

View file

@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
self.clip_model = None # CLIP model placeholder self.clip_model = None # CLIP model placeholder
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def set_classes(self, text): def set_classes(self, text, batch=80, cache_clip_model=True):
"""Perform a forward pass with optional profiling, visualization, and embedding extraction.""" """Set classes in advance so that model could do offline-inference without clip model."""
try: try:
import clip import clip
except ImportError: except ImportError:
check_requirements("git+https://github.com/openai/CLIP.git") check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip import clip
if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute if (
not getattr(self, "clip_model", None) and cache_clip_model
): # for backwards compatibility of models lacking clip_model attribute
self.clip_model = clip.load("ViT-B/32")[0] self.clip_model = clip.load("ViT-B/32")[0]
device = next(self.clip_model.parameters()).device model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
device = next(model.parameters()).device
text_token = clip.tokenize(text).to(device) text_token = clip.tokenize(text).to(device)
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32) txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach() self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
self.model[-1].nc = len(text) self.model[-1].nc = len(text)
def init_criterion(self): def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
"""Initialize the loss criterion for the model."""
raise NotImplementedError
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
""" """
Perform a forward pass through the model. Perform a forward pass through the model.
@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
x (torch.Tensor): The input tensor. x (torch.Tensor): The input tensor.
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return. embed (list, optional): A list of feature vectors/embeddings to return.
Returns: Returns:
(torch.Tensor): Model's output tensor. (torch.Tensor): Model's output tensor.
""" """
txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype) txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
if len(txt_feats) != len(x): if len(txt_feats) != len(x):
txt_feats = txt_feats.repeat(len(x), 1, 1) txt_feats = txt_feats.repeat(len(x), 1, 1)
ori_txt_feats = txt_feats.clone() ori_txt_feats = txt_feats.clone()
@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
return torch.unbind(torch.cat(embeddings, 1), dim=0) return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x return x
def loss(self, batch, preds=None):
"""
Compute loss.
Args:
batch (dict): Batch to compute loss on.
preds (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
if preds is None:
preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
return self.criterion(preds, batch)
class Ensemble(nn.ModuleList): class Ensemble(nn.ModuleList):
"""Ensemble of models.""" """Ensemble of models."""

View file

@ -157,7 +157,7 @@ class v8DetectionLoss:
self.hyp = h self.hyp = h
self.stride = m.stride # model strides self.stride = m.stride # model strides
self.nc = m.nc # number of classes self.nc = m.nc # number of classes
self.no = m.no self.no = m.nc + m.reg_max * 4
self.reg_max = m.reg_max self.reg_max = m.reg_max
self.device = device self.device = device