ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -74,6 +74,7 @@ Here is a list of the supported datasets and a brief description for each:
|
||||||
|
|
||||||
- [**Argoverse**](argoverse.md): A collection of sensor data collected from autonomous vehicles. It contains 3D tracking annotations for car objects.
|
- [**Argoverse**](argoverse.md): A collection of sensor data collected from autonomous vehicles. It contains 3D tracking annotations for car objects.
|
||||||
- [**COCO**](coco.md): Common Objects in Context (COCO) is a large-scale object detection, segmentation, and captioning dataset with 80 object categories.
|
- [**COCO**](coco.md): Common Objects in Context (COCO) is a large-scale object detection, segmentation, and captioning dataset with 80 object categories.
|
||||||
|
- [**LVIS**](lvis.md): LVIS is a large-scale object detection, segmentation, and captioning dataset with 1203 object categories.
|
||||||
- [**COCO8**](coco8.md): A smaller subset of the COCO dataset, COCO8 is more lightweight and faster to train.
|
- [**COCO8**](coco8.md): A smaller subset of the COCO dataset, COCO8 is more lightweight and faster to train.
|
||||||
- [**GlobalWheat2020**](globalwheat2020.md): A dataset containing images of wheat heads for the Global Wheat Challenge 2020.
|
- [**GlobalWheat2020**](globalwheat2020.md): A dataset containing images of wheat heads for the Global Wheat Challenge 2020.
|
||||||
- [**Objects365**](objects365.md): A large-scale object detection dataset with 365 object categories and 600k images, aimed at advancing object detection research.
|
- [**Objects365**](objects365.md): A large-scale object detection dataset with 365 object categories and 600k images, aimed at advancing object detection research.
|
||||||
|
|
|
||||||
96
docs/en/datasets/detect/lvis.md
Normal file
96
docs/en/datasets/detect/lvis.md
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
---
|
||||||
|
comments: true
|
||||||
|
description: Learn how LVIS, a leading dataset for object detection and segmentation, integrates with Ultralytics. Discover ways to use it for training YOLO models.
|
||||||
|
keywords: Ultralytics, LVIS dataset, object detection, YOLO, YOLO model training, image segmentation, computer vision, deep learning models
|
||||||
|
---
|
||||||
|
|
||||||
|
# LVIS Dataset
|
||||||
|
|
||||||
|
The [LVIS](https://www.lvisdataset.org/dataset) dataset is a large-scale, fine-grained vocabulary-level annotation dataset developed and released by Facebook AI Research (FAIR). It is primarily used as a research benchmark for object detection and instance segmentation with a large vocabulary of categories, aiming to drive further advancements in computer vision field.
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
- LVIS contains 160k images and 2M instance annotations for object detection, segmentation, and captioning tasks.
|
||||||
|
- The dataset comprises 1203 object categories, including common objects like cars, bicycles, and animals, as well as more specific categories such as umbrellas, handbags, and sports equipment.
|
||||||
|
- Annotations include object bounding boxes, segmentation masks, and captions for each image.
|
||||||
|
- LVIS provides standardized evaluation metrics like mean Average Precision (mAP) for object detection, and mean Average Recall (mAR) for segmentation tasks, making it suitable for comparing model performance.
|
||||||
|
- LVIS uses the exactly the same images as [COCO](./coco.md) dataset, but with different splits and different annotations.
|
||||||
|
|
||||||
|
## Dataset Structure
|
||||||
|
|
||||||
|
The LVIS dataset is split into three subsets:
|
||||||
|
|
||||||
|
1. **Train**: This subset contains 100k images for training object detection, segmentation, and captioning models.
|
||||||
|
2. **Val**: This subset has 20k images used for validation purposes during model training.
|
||||||
|
3. **Minival**: This subset is exactly the same as COCO val2017 set which has 5k images used for validation purposes during model training.
|
||||||
|
4. **Test**: This subset consists of 20k images used for testing and benchmarking the trained models. Ground truth annotations for this subset are not publicly available, and the results are submitted to the [LVIS evaluation server](https://eval.ai/web/challenges/challenge-page/675/overview) for performance evaluation.
|
||||||
|
|
||||||
|
|
||||||
|
## Applications
|
||||||
|
|
||||||
|
The LVIS dataset is widely used for training and evaluating deep learning models in object detection (such as YOLO, Faster R-CNN, and SSD), instance segmentation (such as Mask R-CNN). The dataset's diverse set of object categories, large number of annotated images, and standardized evaluation metrics make it an essential resource for computer vision researchers and practitioners.
|
||||||
|
|
||||||
|
## Dataset YAML
|
||||||
|
|
||||||
|
A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the LVIS dataset, the `lvis.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml).
|
||||||
|
|
||||||
|
!!! Example "ultralytics/cfg/datasets/lvis.yaml"
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
--8<-- "ultralytics/cfg/datasets/lvis.yaml"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To train a YOLOv8n model on the LVIS dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page.
|
||||||
|
|
||||||
|
!!! Example "Train Example"
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# Load a model
|
||||||
|
model = YOLO('yolov8n.pt') # load a pretrained model (recommended for training)
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
results = model.train(data='lvis.yaml', epochs=100, imgsz=640)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "CLI"
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start training from a pretrained *.pt model
|
||||||
|
yolo detect train data=lvis.yaml model=yolov8n.pt epochs=100 imgsz=640
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sample Images and Annotations
|
||||||
|
|
||||||
|
The LVIS dataset contains a diverse set of images with various object categories and complex scenes. Here are some examples of images from the dataset, along with their corresponding annotations:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts.
|
||||||
|
|
||||||
|
The example showcases the variety and complexity of the images in the LVIS dataset and the benefits of using mosaicing during the training process.
|
||||||
|
|
||||||
|
## Citations and Acknowledgments
|
||||||
|
|
||||||
|
If you use the LVIS dataset in your research or development work, please cite the following paper:
|
||||||
|
|
||||||
|
!!! Quote ""
|
||||||
|
|
||||||
|
=== "BibTeX"
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{gupta2019lvis,
|
||||||
|
title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation},
|
||||||
|
author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross},
|
||||||
|
booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition},
|
||||||
|
year={2019}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
We would like to acknowledge the LVIS Consortium for creating and maintaining this valuable resource for the computer vision community. For more information about the LVIS dataset and its creators, visit the [LVIS dataset website](https://www.lvisdataset.org/dataset).
|
||||||
|
|
@ -36,6 +36,7 @@ Bounding box object detection is a computer vision technique that involves detec
|
||||||
|
|
||||||
- [Argoverse](detect/argoverse.md): A dataset containing 3D tracking and motion forecasting data from urban environments with rich annotations.
|
- [Argoverse](detect/argoverse.md): A dataset containing 3D tracking and motion forecasting data from urban environments with rich annotations.
|
||||||
- [COCO](detect/coco.md): A large-scale dataset designed for object detection, segmentation, and captioning with over 200K labeled images.
|
- [COCO](detect/coco.md): A large-scale dataset designed for object detection, segmentation, and captioning with over 200K labeled images.
|
||||||
|
- [LVIS](lvis.md): A large-scale object detection, segmentation, and captioning dataset with 1203 object categories.
|
||||||
- [COCO8](detect/coco8.md): Contains the first 4 images from COCO train and COCO val, suitable for quick tests.
|
- [COCO8](detect/coco8.md): Contains the first 4 images from COCO train and COCO val, suitable for quick tests.
|
||||||
- [Global Wheat 2020](detect/globalwheat2020.md): A dataset of wheat head images collected from around the world for object detection and localization tasks.
|
- [Global Wheat 2020](detect/globalwheat2020.md): A dataset of wheat head images collected from around the world for object detection and localization tasks.
|
||||||
- [Objects365](detect/objects365.md): A high-quality, large-scale dataset for object detection with 365 object categories and over 600K annotated images.
|
- [Objects365](detect/objects365.md): A high-quality, large-scale dataset for object detection with 365 object categories and over 600K annotated images.
|
||||||
|
|
|
||||||
|
|
@ -147,7 +147,7 @@ FastSAM is also available directly from the [https://github.com/CASIA-IVA-Lab/Fa
|
||||||
|
|
||||||
4. Install the CLIP model:
|
4. Install the CLIP model:
|
||||||
```shell
|
```shell
|
||||||
pip install git+https://github.com/openai/CLIP.git
|
pip install git+https://github.com/ultralytics/CLIP.git
|
||||||
```
|
```
|
||||||
|
|
||||||
### Example Usage
|
### Example Usage
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,39 @@ This section details the models available with their specific pre-trained weight
|
||||||
|
|
||||||
The YOLO-World models are easy to integrate into your Python applications. Ultralytics provides user-friendly Python API and CLI commands to streamline development.
|
The YOLO-World models are easy to integrate into your Python applications. Ultralytics provides user-friendly Python API and CLI commands to streamline development.
|
||||||
|
|
||||||
|
### Train Usage
|
||||||
|
|
||||||
|
!!! Tip "Tip"
|
||||||
|
|
||||||
|
We strongly recommend to use `yolov8-worldv2` model for custom training, because it supports deterministic training and also easy to export other formats i.e onnx/tensorrt.
|
||||||
|
|
||||||
|
Object detection is straightforward with the `train` method, as illustrated below:
|
||||||
|
|
||||||
|
!!! Example
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
PyTorch pretrained `*.pt` models as well as configuration `*.yaml` files can be passed to the `YOLOWorld()` class to create a model instance in python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ultralytics import YOLOWorld
|
||||||
|
|
||||||
|
# Load a pretrained YOLOv8s-worldv2 model
|
||||||
|
model = YOLOWorld('yolov8s-worldv2.pt')
|
||||||
|
|
||||||
|
# Train the model on the COCO8 example dataset for 100 epochs
|
||||||
|
results = model.train(data='coco8.yaml', epochs=100, imgsz=640)
|
||||||
|
|
||||||
|
# Run inference with the YOLOv8n model on the 'bus.jpg' image
|
||||||
|
results = model('path/to/bus.jpg')
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "CLI"
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Load a pretrained YOLOv8s-worldv2 model and train it on the COCO8 example dataset for 100 epochs
|
||||||
|
yolo train model=yolov8s-worldv2.yaml data=coco8.yaml epochs=100 imgsz=640
|
||||||
|
```
|
||||||
|
|
||||||
### Predict Usage
|
### Predict Usage
|
||||||
|
|
||||||
Object detection is straightforward with the `predict` method, as illustrated below:
|
Object detection is straightforward with the `predict` method, as illustrated below:
|
||||||
|
|
@ -196,6 +229,59 @@ You can also save a model after setting custom classes. By doing this you create
|
||||||
|
|
||||||
This approach provides a powerful means of customizing state-of-the-art object detection models for specific tasks, making advanced AI more accessible and applicable to a broader range of practical applications.
|
This approach provides a powerful means of customizing state-of-the-art object detection models for specific tasks, making advanced AI more accessible and applicable to a broader range of practical applications.
|
||||||
|
|
||||||
|
## Reproduce official results from scratch(Experimental)
|
||||||
|
|
||||||
|
### Prepare datasets
|
||||||
|
|
||||||
|
- Train data
|
||||||
|
|
||||||
|
| Dataset | Type | Samples | Boxes | Annotation Files |
|
||||||
|
|-------------------------------------------------------------------|-----------|---------|-------|--------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| [Objects365v1](https://opendatalab.com/OpenDataLab/Objects365_v1) | Detection | 609k | 9621k | [objects365_train.json](https://opendatalab.com/OpenDataLab/Objects365_v1) |
|
||||||
|
| [GQA](https://nlp.stanford.edu/data/gqa/images.zip) | Grounding | 621k | 3681k | [final_mixed_train_no_coco.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_mixed_train_no_coco.json) |
|
||||||
|
| [Flickr30k](https://shannon.cs.illinois.edu/DenotationGraph/) | Grounding | 149k | 641k | [final_flickr_separateGT_train.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_flickr_separateGT_train.json) |
|
||||||
|
|
||||||
|
- Val data
|
||||||
|
|
||||||
|
| Dataset | Type | Annotation Files |
|
||||||
|
|---------------------------------------------------------------------------------------------------------|-----------|--------------------------------------------------------------------------------------------------------|
|
||||||
|
| [LVIS minival](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) | Detection | [minival.txt](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) |
|
||||||
|
|
||||||
|
### Launch training from scratch
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
|
||||||
|
`WorldTrainerFromScratch` is highly customized to allow training yolo-world models on both detection datasets and grounding datasets simultaneously. More details please checkout [ultralytics.model.yolo.world.train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py).
|
||||||
|
|
||||||
|
!!! Example
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
|
||||||
|
from ultralytics import YOLOWorld
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
train=dict(
|
||||||
|
yolo_data=["Objects365.yaml"],
|
||||||
|
grounding_data=[
|
||||||
|
dict(
|
||||||
|
img_path="../datasets/flickr30k/images",
|
||||||
|
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
img_path="../datasets/GQA/images",
|
||||||
|
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
val=dict(yolo_data=["lvis.yaml"]),
|
||||||
|
)
|
||||||
|
model = YOLOWorld("yolov8s-worldv2.yaml")
|
||||||
|
model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
## Citations and Acknowledgements
|
## Citations and Acknowledgements
|
||||||
|
|
||||||
We extend our gratitude to the [Tencent AILab Computer Vision Center](https://ai.tencent.com/) for their pioneering work in real-time open-vocabulary object detection with YOLO-World:
|
We extend our gratitude to the [Tencent AILab Computer Vision Center](https://ai.tencent.com/) for their pioneering work in real-time open-vocabulary object detection with YOLO-World:
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,10 @@ keywords: Ultralytics, Data Augmentation, BaseTransform, MixUp, RandomHSV, Lette
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.augment.RandomLoadText
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
## ::: ultralytics.data.augment.ClassifyLetterBox
|
## ::: ultralytics.data.augment.ClassifyLetterBox
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@ keywords: Ultralytics, YOLO v3, Data build, DataLoader, InfiniteDataLoader, seed
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.build.build_grounding
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
## ::: ultralytics.data.build.build_dataloader
|
## ::: ultralytics.data.build.build_dataloader
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
|
||||||
|
|
@ -19,14 +19,18 @@ keywords: Ultralytics, YOLO, YOLODataset, SemanticDataset, data handling, data m
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.dataset.YOLOMultiModalDataset
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.dataset.GroundingDataset
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.dataset.YOLOConcatDataset
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
## ::: ultralytics.data.dataset.SemanticDataset
|
## ::: ultralytics.data.dataset.SemanticDataset
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
## ::: ultralytics.data.dataset.load_dataset_cache_file
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
## ::: ultralytics.data.dataset.save_dataset_cache_file
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
|
||||||
|
|
@ -66,3 +66,11 @@ keywords: Ultralytics, data utils, YOLO, img2label_paths, exif_size, polygon2mas
|
||||||
## ::: ultralytics.data.utils.autosplit
|
## ::: ultralytics.data.utils.autosplit
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.utils.load_dataset_cache_file
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.data.utils.save_dataset_cache_file
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
|
||||||
15
docs/en/reference/models/yolo/world/train.md
Normal file
15
docs/en/reference/models/yolo/world/train.md
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Reference for `ultralytics/models/yolo/world/train.py`
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
|
||||||
|
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train.py) 🛠️. Thank you 🙏!
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.models.yolo.world.train.WorldTrainer
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.models.yolo.world.train.on_pretrain_routine_end
|
||||||
|
|
||||||
|
<br><br>
|
||||||
11
docs/en/reference/models/yolo/world/train_world.md
Normal file
11
docs/en/reference/models/yolo/world/train_world.md
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Reference for `ultralytics/models/yolo/world/train_world.py`
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
|
||||||
|
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train_world.py) 🛠️. Thank you 🙏!
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
@ -18,6 +18,7 @@ chr043416@gmail.com: RizwanMunawar
|
||||||
glenn.jocher@ultralytics.com: glenn-jocher
|
glenn.jocher@ultralytics.com: glenn-jocher
|
||||||
muhammadrizwanmunawar123@gmail.com: RizwanMunawar
|
muhammadrizwanmunawar123@gmail.com: RizwanMunawar
|
||||||
not.committed.yet: null
|
not.committed.yet: null
|
||||||
|
plashchynski@gmail.com: plashchynski
|
||||||
priytosh.revolution@live.com: priytosh-tripathi
|
priytosh.revolution@live.com: priytosh-tripathi
|
||||||
shuizhuyuanluo@126.com: null
|
shuizhuyuanluo@126.com: null
|
||||||
xinwang614@gmail.com: GreatV
|
xinwang614@gmail.com: GreatV
|
||||||
|
|
|
||||||
|
|
@ -240,6 +240,7 @@ nav:
|
||||||
- datasets/detect/index.md
|
- datasets/detect/index.md
|
||||||
- Argoverse: datasets/detect/argoverse.md
|
- Argoverse: datasets/detect/argoverse.md
|
||||||
- COCO: datasets/detect/coco.md
|
- COCO: datasets/detect/coco.md
|
||||||
|
- LVIS: datasets/detect/lvis.md
|
||||||
- COCO8: datasets/detect/coco8.md
|
- COCO8: datasets/detect/coco8.md
|
||||||
- GlobalWheat2020: datasets/detect/globalwheat2020.md
|
- GlobalWheat2020: datasets/detect/globalwheat2020.md
|
||||||
- Objects365: datasets/detect/objects365.md
|
- Objects365: datasets/detect/objects365.md
|
||||||
|
|
@ -492,6 +493,9 @@ nav:
|
||||||
- predict: reference/models/yolo/segment/predict.md
|
- predict: reference/models/yolo/segment/predict.md
|
||||||
- train: reference/models/yolo/segment/train.md
|
- train: reference/models/yolo/segment/train.md
|
||||||
- val: reference/models/yolo/segment/val.md
|
- val: reference/models/yolo/segment/val.md
|
||||||
|
- world:
|
||||||
|
- train: reference/models/yolo/world/train.md
|
||||||
|
- train_world: reference/models/yolo/world/train_world.md
|
||||||
- nn:
|
- nn:
|
||||||
- autobackend: reference/nn/autobackend.md
|
- autobackend: reference/nn/autobackend.md
|
||||||
- modules:
|
- modules:
|
||||||
|
|
|
||||||
|
|
@ -643,3 +643,29 @@ def test_yolo_world():
|
||||||
model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet
|
model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet
|
||||||
model.set_classes(["tree", "window"])
|
model.set_classes(["tree", "window"])
|
||||||
model(ASSETS / "bus.jpg", conf=0.01)
|
model(ASSETS / "bus.jpg", conf=0.01)
|
||||||
|
|
||||||
|
# Training from yaml
|
||||||
|
model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet
|
||||||
|
model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world")
|
||||||
|
|
||||||
|
model = YOLO("yolov8s-worldv2.pt") # no YOLOv8n-world model yet
|
||||||
|
# val
|
||||||
|
model.val(data="coco8.yaml", imgsz=32, save_txt=True, save_json=True)
|
||||||
|
# Training from pretrain
|
||||||
|
model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world")
|
||||||
|
|
||||||
|
# test WorWorldTrainerFromScratch
|
||||||
|
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
|
||||||
|
|
||||||
|
model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet
|
||||||
|
data = dict(train=dict(yolo_data=["coco8.yaml"]), val=dict(yolo_data=["coco8.yaml"]))
|
||||||
|
model.train(
|
||||||
|
data=data,
|
||||||
|
epochs=2,
|
||||||
|
imgsz=32,
|
||||||
|
cache="disk",
|
||||||
|
batch=-1,
|
||||||
|
close_mosaic=1,
|
||||||
|
name="yolo-world",
|
||||||
|
trainer=WorldTrainerFromScratch,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.1.38"
|
__version__ = "8.1.39"
|
||||||
|
|
||||||
from ultralytics.data.explorer.explorer import Explorer
|
from ultralytics.data.explorer.explorer import Explorer
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||||
|
|
|
||||||
1239
ultralytics/cfg/datasets/lvis.yaml
Normal file
1239
ultralytics/cfg/datasets/lvis.yaml
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,15 +1,31 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
from .build import build_dataloader, build_yolo_dataset, load_inference_source
|
from .build import (
|
||||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
build_dataloader,
|
||||||
|
build_yolo_dataset,
|
||||||
|
build_grounding,
|
||||||
|
load_inference_source,
|
||||||
|
)
|
||||||
|
from .dataset import (
|
||||||
|
ClassificationDataset,
|
||||||
|
SemanticDataset,
|
||||||
|
YOLODataset,
|
||||||
|
YOLOMultiModalDataset,
|
||||||
|
GroundingDataset,
|
||||||
|
YOLOConcatDataset,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
"ClassificationDataset",
|
"ClassificationDataset",
|
||||||
"SemanticDataset",
|
"SemanticDataset",
|
||||||
"YOLODataset",
|
"YOLODataset",
|
||||||
|
"YOLOMultiModalDataset",
|
||||||
|
"YOLOConcatDataset",
|
||||||
|
"GroundingDataset",
|
||||||
"build_yolo_dataset",
|
"build_yolo_dataset",
|
||||||
|
"build_grounding",
|
||||||
"build_dataloader",
|
"build_dataloader",
|
||||||
"load_inference_source",
|
"load_inference_source",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -66,7 +67,7 @@ class Compose:
|
||||||
|
|
||||||
def __init__(self, transforms):
|
def __init__(self, transforms):
|
||||||
"""Initializes the Compose object with a list of transforms."""
|
"""Initializes the Compose object with a list of transforms."""
|
||||||
self.transforms = transforms
|
self.transforms = transforms if isinstance(transforms, list) else [transforms]
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
"""Applies a series of transformations to input data."""
|
"""Applies a series of transformations to input data."""
|
||||||
|
|
@ -78,6 +79,29 @@ class Compose:
|
||||||
"""Appends a new transform to the existing list of transforms."""
|
"""Appends a new transform to the existing list of transforms."""
|
||||||
self.transforms.append(transform)
|
self.transforms.append(transform)
|
||||||
|
|
||||||
|
def insert(self, index, transform):
|
||||||
|
"""Inserts a new transform to the existing list of transforms."""
|
||||||
|
self.transforms.insert(index, transform)
|
||||||
|
|
||||||
|
def __getitem__(self, index: Union[list, int]) -> "Compose":
|
||||||
|
"""Retrieve a specific transform or a set of transforms using indexing."""
|
||||||
|
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
|
||||||
|
index = [index] if isinstance(index, int) else index
|
||||||
|
return Compose([self.transforms[i] for i in index])
|
||||||
|
|
||||||
|
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
|
||||||
|
"""Retrieve a specific transform or a set of transforms using indexing."""
|
||||||
|
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
|
||||||
|
if isinstance(index, list):
|
||||||
|
assert isinstance(
|
||||||
|
value, list
|
||||||
|
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
|
||||||
|
if isinstance(index, int):
|
||||||
|
index, value = [index], [value]
|
||||||
|
for i, v in zip(index, value):
|
||||||
|
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
|
||||||
|
self.transforms[i] = v
|
||||||
|
|
||||||
def tolist(self):
|
def tolist(self):
|
||||||
"""Converts the list of transforms to a standard Python list."""
|
"""Converts the list of transforms to a standard Python list."""
|
||||||
return self.transforms
|
return self.transforms
|
||||||
|
|
@ -118,6 +142,8 @@ class BaseMixTransform:
|
||||||
mix_labels[i] = self.pre_transform(data)
|
mix_labels[i] = self.pre_transform(data)
|
||||||
labels["mix_labels"] = mix_labels
|
labels["mix_labels"] = mix_labels
|
||||||
|
|
||||||
|
# Update cls and texts
|
||||||
|
labels = self._update_label_text(labels)
|
||||||
# Mosaic or MixUp
|
# Mosaic or MixUp
|
||||||
labels = self._mix_transform(labels)
|
labels = self._mix_transform(labels)
|
||||||
labels.pop("mix_labels", None)
|
labels.pop("mix_labels", None)
|
||||||
|
|
@ -131,6 +157,22 @@ class BaseMixTransform:
|
||||||
"""Gets a list of shuffled indexes for mosaic augmentation."""
|
"""Gets a list of shuffled indexes for mosaic augmentation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _update_label_text(self, labels):
|
||||||
|
"""Update label text."""
|
||||||
|
if "texts" not in labels:
|
||||||
|
return labels
|
||||||
|
|
||||||
|
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
|
||||||
|
mix_texts = list({tuple(x) for x in mix_texts})
|
||||||
|
text2id = {text: i for i, text in enumerate(mix_texts)}
|
||||||
|
|
||||||
|
for label in [labels] + labels["mix_labels"]:
|
||||||
|
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
|
||||||
|
text = label["texts"][int(l)]
|
||||||
|
label["cls"][i] = text2id[tuple(text)]
|
||||||
|
label["texts"] = mix_texts
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
class Mosaic(BaseMixTransform):
|
class Mosaic(BaseMixTransform):
|
||||||
"""
|
"""
|
||||||
|
|
@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
|
||||||
final_labels["instances"].clip(imgsz, imgsz)
|
final_labels["instances"].clip(imgsz, imgsz)
|
||||||
good = final_labels["instances"].remove_zero_area_boxes()
|
good = final_labels["instances"].remove_zero_area_boxes()
|
||||||
final_labels["cls"] = final_labels["cls"][good]
|
final_labels["cls"] = final_labels["cls"][good]
|
||||||
|
if "texts" in mosaic_labels[0]:
|
||||||
|
final_labels["texts"] = mosaic_labels[0]["texts"]
|
||||||
return final_labels
|
return final_labels
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -970,6 +1014,83 @@ class Format:
|
||||||
return masks, instances, cls
|
return masks, instances, cls
|
||||||
|
|
||||||
|
|
||||||
|
class RandomLoadText:
|
||||||
|
"""
|
||||||
|
Randomly sample positive texts and negative texts and update the class indices accordingly to the number of samples.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
prompt_format (str): Format for prompt. Default is '{}'.
|
||||||
|
neg_samples (tuple[int]): A ranger to randomly sample negative texts, Default is (80, 80).
|
||||||
|
max_samples (int): The max number of different text samples in one image, Default is 80.
|
||||||
|
padding (bool): Whether to pad texts to max_samples. Default is False.
|
||||||
|
padding_value (str): The padding text. Default is "".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prompt_format: str = "{}",
|
||||||
|
neg_samples: Tuple[int, int] = (80, 80),
|
||||||
|
max_samples: int = 80,
|
||||||
|
padding: bool = False,
|
||||||
|
padding_value: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Initializes the RandomLoadText class with given parameters."""
|
||||||
|
self.prompt_format = prompt_format
|
||||||
|
self.neg_samples = neg_samples
|
||||||
|
self.max_samples = max_samples
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_value = padding_value
|
||||||
|
|
||||||
|
def __call__(self, labels: dict) -> dict:
|
||||||
|
"""Return updated classes and texts."""
|
||||||
|
assert "texts" in labels, "No texts found in labels."
|
||||||
|
class_texts = labels["texts"]
|
||||||
|
num_classes = len(class_texts)
|
||||||
|
cls = np.asarray(labels.pop("cls"), dtype=int)
|
||||||
|
pos_labels = np.unique(cls).tolist()
|
||||||
|
|
||||||
|
if len(pos_labels) > self.max_samples:
|
||||||
|
pos_labels = set(random.sample(pos_labels, k=self.max_samples))
|
||||||
|
|
||||||
|
neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
|
||||||
|
neg_labels = []
|
||||||
|
for i in range(num_classes):
|
||||||
|
if i not in pos_labels:
|
||||||
|
neg_labels.append(i)
|
||||||
|
neg_labels = random.sample(neg_labels, k=neg_samples)
|
||||||
|
|
||||||
|
sampled_labels = pos_labels + neg_labels
|
||||||
|
random.shuffle(sampled_labels)
|
||||||
|
|
||||||
|
label2ids = {label: i for i, label in enumerate(sampled_labels)}
|
||||||
|
valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
|
||||||
|
new_cls = []
|
||||||
|
for i, label in enumerate(cls.squeeze(-1).tolist()):
|
||||||
|
if label not in label2ids:
|
||||||
|
continue
|
||||||
|
valid_idx[i] = True
|
||||||
|
new_cls.append([label2ids[label]])
|
||||||
|
labels["instances"] = labels["instances"][valid_idx]
|
||||||
|
labels["cls"] = np.array(new_cls)
|
||||||
|
|
||||||
|
# Randomly select one prompt when there's more than one prompts
|
||||||
|
texts = []
|
||||||
|
for label in sampled_labels:
|
||||||
|
prompts = class_texts[label]
|
||||||
|
assert len(prompts) > 0
|
||||||
|
prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
|
||||||
|
texts.append(prompt)
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
valid_labels = len(pos_labels) + len(neg_labels)
|
||||||
|
num_padding = self.max_samples - valid_labels
|
||||||
|
if num_padding > 0:
|
||||||
|
texts += [self.padding_value] * num_padding
|
||||||
|
|
||||||
|
labels["texts"] = texts
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
def v8_transforms(dataset, imgsz, hyp, stretch=False):
|
def v8_transforms(dataset, imgsz, hyp, stretch=False):
|
||||||
"""Convert images to a size suitable for YOLOv8 training."""
|
"""Convert images to a size suitable for YOLOv8 training."""
|
||||||
pre_transform = Compose(
|
pre_transform = Compose(
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
|
||||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
from ultralytics.utils import RANK, colorstr
|
from ultralytics.utils import RANK, colorstr
|
||||||
from ultralytics.utils.checks import check_file
|
from ultralytics.utils.checks import check_file
|
||||||
from .dataset import YOLODataset
|
from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
|
||||||
from .utils import PIN_MEMORY
|
from .utils import PIN_MEMORY
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -82,9 +82,10 @@ def seed_worker(worker_id): # noqa
|
||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
|
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
|
||||||
"""Build YOLO Dataset."""
|
"""Build YOLO Dataset."""
|
||||||
return YOLODataset(
|
dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
|
||||||
|
return dataset(
|
||||||
img_path=img_path,
|
img_path=img_path,
|
||||||
imgsz=cfg.imgsz,
|
imgsz=cfg.imgsz,
|
||||||
batch_size=batch,
|
batch_size=batch,
|
||||||
|
|
@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
|
||||||
|
"""Build YOLO Dataset."""
|
||||||
|
return GroundingDataset(
|
||||||
|
img_path=img_path,
|
||||||
|
json_file=json_file,
|
||||||
|
imgsz=cfg.imgsz,
|
||||||
|
batch_size=batch,
|
||||||
|
augment=mode == "train", # augmentation
|
||||||
|
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||||
|
rect=cfg.rect or rect, # rectangular batches
|
||||||
|
cache=cfg.cache or None,
|
||||||
|
single_cls=cfg.single_cls or False,
|
||||||
|
stride=int(stride),
|
||||||
|
pad=0.0 if mode == "train" else 0.5,
|
||||||
|
prefix=colorstr(f"{mode}: "),
|
||||||
|
task=cfg.task,
|
||||||
|
classes=cfg.classes,
|
||||||
|
fraction=cfg.fraction if mode == "train" else 1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
||||||
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
|
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
|
||||||
batch = min(batch, len(dataset))
|
batch = min(batch, len(dataset))
|
||||||
|
|
|
||||||
|
|
@ -219,6 +219,7 @@ def convert_coco(
|
||||||
use_segments=False,
|
use_segments=False,
|
||||||
use_keypoints=False,
|
use_keypoints=False,
|
||||||
cls91to80=True,
|
cls91to80=True,
|
||||||
|
lvis=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
|
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
|
||||||
|
|
@ -229,12 +230,14 @@ def convert_coco(
|
||||||
use_segments (bool, optional): Whether to include segmentation masks in the output.
|
use_segments (bool, optional): Whether to include segmentation masks in the output.
|
||||||
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
|
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
|
||||||
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
|
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
|
||||||
|
lvis (bool, optional): Whether to convert data in lvis dataset way.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
from ultralytics.data.converter import convert_coco
|
from ultralytics.data.converter import convert_coco
|
||||||
|
|
||||||
convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
|
convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
|
||||||
|
convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
@ -251,8 +254,14 @@ def convert_coco(
|
||||||
|
|
||||||
# Import json
|
# Import json
|
||||||
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
|
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
|
||||||
fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name
|
lname = "" if lvis else json_file.stem.replace("instances_", "")
|
||||||
|
fn = Path(save_dir) / "labels" / lname # folder name
|
||||||
fn.mkdir(parents=True, exist_ok=True)
|
fn.mkdir(parents=True, exist_ok=True)
|
||||||
|
if lvis:
|
||||||
|
# NOTE: create folders for both train and val in advance,
|
||||||
|
# since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
|
||||||
|
(fn / "train2017").mkdir(parents=True, exist_ok=True)
|
||||||
|
(fn / "val2017").mkdir(parents=True, exist_ok=True)
|
||||||
with open(json_file) as f:
|
with open(json_file) as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
|
|
@ -263,16 +272,20 @@ def convert_coco(
|
||||||
for ann in data["annotations"]:
|
for ann in data["annotations"]:
|
||||||
imgToAnns[ann["image_id"]].append(ann)
|
imgToAnns[ann["image_id"]].append(ann)
|
||||||
|
|
||||||
|
image_txt = []
|
||||||
# Write labels file
|
# Write labels file
|
||||||
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
|
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
|
||||||
img = images[f"{img_id:d}"]
|
img = images[f"{img_id:d}"]
|
||||||
h, w, f = img["height"], img["width"], img["file_name"]
|
h, w = img["height"], img["width"]
|
||||||
|
f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
|
||||||
|
if lvis:
|
||||||
|
image_txt.append(str(Path("./images") / f))
|
||||||
|
|
||||||
bboxes = []
|
bboxes = []
|
||||||
segments = []
|
segments = []
|
||||||
keypoints = []
|
keypoints = []
|
||||||
for ann in anns:
|
for ann in anns:
|
||||||
if ann["iscrowd"]:
|
if ann.get("iscrowd", False):
|
||||||
continue
|
continue
|
||||||
# The COCO box format is [top left x, top left y, width, height]
|
# The COCO box format is [top left x, top left y, width, height]
|
||||||
box = np.array(ann["bbox"], dtype=np.float64)
|
box = np.array(ann["bbox"], dtype=np.float64)
|
||||||
|
|
@ -314,7 +327,12 @@ def convert_coco(
|
||||||
) # cls, box or segments
|
) # cls, box or segments
|
||||||
file.write(("%g " * len(line)).rstrip() % line + "\n")
|
file.write(("%g " * len(line)).rstrip() % line + "\n")
|
||||||
|
|
||||||
LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
|
if lvis:
|
||||||
|
with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
|
||||||
|
for l in image_txt:
|
||||||
|
f.write(f"{l}\n")
|
||||||
|
|
||||||
|
LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
|
||||||
|
|
||||||
|
|
||||||
def convert_dota_to_yolo_obb(dota_root_path: str):
|
def convert_dota_to_yolo_obb(dota_root_path: str):
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,41 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
import contextlib
|
import contextlib
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
|
from collections import defaultdict
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
|
from torch.utils.data import ConcatDataset
|
||||||
|
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
||||||
from ultralytics.utils.ops import resample_segments
|
from ultralytics.utils.ops import resample_segments
|
||||||
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
|
from .augment import (
|
||||||
|
Compose,
|
||||||
|
Format,
|
||||||
|
Instances,
|
||||||
|
LetterBox,
|
||||||
|
RandomLoadText,
|
||||||
|
classify_augmentations,
|
||||||
|
classify_transforms,
|
||||||
|
v8_transforms,
|
||||||
|
)
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
|
from .utils import (
|
||||||
|
HELP_URL,
|
||||||
|
LOGGER,
|
||||||
|
get_hash,
|
||||||
|
img2label_paths,
|
||||||
|
verify_image,
|
||||||
|
verify_image_label,
|
||||||
|
load_dataset_cache_file,
|
||||||
|
save_dataset_cache_file,
|
||||||
|
)
|
||||||
|
|
||||||
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
|
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
|
||||||
DATASET_CACHE_VERSION = "1.0.3"
|
DATASET_CACHE_VERSION = "1.0.3"
|
||||||
|
|
@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
|
||||||
x["hash"] = get_hash(self.label_files + self.im_files)
|
x["hash"] = get_hash(self.label_files + self.im_files)
|
||||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||||
x["msgs"] = msgs # warnings
|
x["msgs"] = msgs # warnings
|
||||||
save_dataset_cache_file(self.prefix, path, x)
|
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
|
|
@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||||
x["hash"] = get_hash([x[0] for x in self.samples])
|
x["hash"] = get_hash([x[0] for x in self.samples])
|
||||||
x["results"] = nf, nc, len(samples), samples
|
x["results"] = nf, nc, len(samples), samples
|
||||||
x["msgs"] = msgs # warnings
|
x["msgs"] = msgs # warnings
|
||||||
save_dataset_cache_file(self.prefix, path, x)
|
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_cache_file(path):
|
class YOLOMultiModalDataset(YOLODataset):
|
||||||
"""Load an Ultralytics *.cache dictionary from path."""
|
"""
|
||||||
import gc
|
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
||||||
|
|
||||||
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
|
Args:
|
||||||
cache = np.load(str(path), allow_pickle=True).item() # load dict
|
data (dict, optional): A dataset YAML dictionary. Defaults to None.
|
||||||
gc.enable()
|
task (str): An explicit arg to point current task, Defaults to 'detect'.
|
||||||
return cache
|
|
||||||
|
Returns:
|
||||||
|
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, data=None, task="detect", **kwargs):
|
||||||
|
"""Initializes a dataset object for object detection tasks with optional specifications."""
|
||||||
|
super().__init__(*args, data=data, task=task, **kwargs)
|
||||||
|
|
||||||
|
def update_labels_info(self, label):
|
||||||
|
"""Add texts information for multi modal model training."""
|
||||||
|
labels = super().update_labels_info(label)
|
||||||
|
# NOTE: some categories are concatenated with its synonyms by `/`.
|
||||||
|
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def build_transforms(self, hyp=None):
|
||||||
|
"""Enhances data transformations with optional text augmentation for multi-modal training."""
|
||||||
|
transforms = super().build_transforms(hyp)
|
||||||
|
if self.augment:
|
||||||
|
# NOTE: hard-coded the args for now.
|
||||||
|
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
|
||||||
|
return transforms
|
||||||
|
|
||||||
|
|
||||||
def save_dataset_cache_file(prefix, path, x):
|
class GroundingDataset(YOLODataset):
|
||||||
"""Save an Ultralytics dataset *.cache dictionary x to path."""
|
def __init__(self, *args, task="detect", json_file, **kwargs):
|
||||||
x["version"] = DATASET_CACHE_VERSION # add cache version
|
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
|
||||||
if is_dir_writeable(path.parent):
|
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
|
||||||
if path.exists():
|
self.json_file = json_file
|
||||||
path.unlink() # remove *.cache file if exists
|
super().__init__(*args, task=task, data={}, **kwargs)
|
||||||
np.save(str(path), x) # save cache for next time
|
|
||||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
def get_img_files(self, img_path):
|
||||||
LOGGER.info(f"{prefix}New cache created: {path}")
|
"""The image files would be read in `get_labels` function, return empty list here."""
|
||||||
else:
|
return []
|
||||||
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
|
||||||
|
labels = []
|
||||||
|
LOGGER.info("Loading annotation file...")
|
||||||
|
with open(self.json_file, "r") as f:
|
||||||
|
annotations = json.load(f)
|
||||||
|
images = {f'{x["id"]:d}': x for x in annotations["images"]}
|
||||||
|
imgToAnns = defaultdict(list)
|
||||||
|
for ann in annotations["annotations"]:
|
||||||
|
imgToAnns[ann["image_id"]].append(ann)
|
||||||
|
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
|
||||||
|
img = images[f"{img_id:d}"]
|
||||||
|
h, w, f = img["height"], img["width"], img["file_name"]
|
||||||
|
im_file = Path(self.img_path) / f
|
||||||
|
if not im_file.exists():
|
||||||
|
continue
|
||||||
|
self.im_files.append(str(im_file))
|
||||||
|
bboxes = []
|
||||||
|
cat2id = {}
|
||||||
|
texts = []
|
||||||
|
for ann in anns:
|
||||||
|
if ann["iscrowd"]:
|
||||||
|
continue
|
||||||
|
box = np.array(ann["bbox"], dtype=np.float32)
|
||||||
|
box[:2] += box[2:] / 2
|
||||||
|
box[[0, 2]] /= float(w)
|
||||||
|
box[[1, 3]] /= float(h)
|
||||||
|
if box[2] <= 0 or box[3] <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
|
||||||
|
if cat_name not in cat2id:
|
||||||
|
cat2id[cat_name] = len(cat2id)
|
||||||
|
texts.append([cat_name])
|
||||||
|
cls = cat2id[cat_name] # class
|
||||||
|
box = [cls] + box.tolist()
|
||||||
|
if box not in bboxes:
|
||||||
|
bboxes.append(box)
|
||||||
|
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
|
||||||
|
labels.append(
|
||||||
|
dict(
|
||||||
|
im_file=im_file,
|
||||||
|
shape=(h, w),
|
||||||
|
cls=lb[:, 0:1], # n, 1
|
||||||
|
bboxes=lb[:, 1:], # n, 4
|
||||||
|
normalized=True,
|
||||||
|
bbox_format="xywh",
|
||||||
|
texts=texts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def build_transforms(self, hyp=None):
|
||||||
|
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
|
||||||
|
transforms = super().build_transforms(hyp)
|
||||||
|
if self.augment:
|
||||||
|
# NOTE: hard-coded the args for now.
|
||||||
|
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
|
||||||
|
return transforms
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOConcatDataset(ConcatDataset):
|
||||||
|
"""
|
||||||
|
Dataset as a concatenation of multiple datasets.
|
||||||
|
|
||||||
|
This class is useful to assemble different existing datasets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collate_fn(batch):
|
||||||
|
"""Collates data samples into batches."""
|
||||||
|
return YOLODataset.collate_fn(batch)
|
||||||
|
|
||||||
|
|
||||||
# TODO: support semantic segmentation
|
# TODO: support semantic segmentation
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from ultralytics.utils import (
|
||||||
emojis,
|
emojis,
|
||||||
yaml_load,
|
yaml_load,
|
||||||
yaml_save,
|
yaml_save,
|
||||||
|
is_dir_writeable,
|
||||||
)
|
)
|
||||||
from ultralytics.utils.checks import check_file, check_font, is_ascii
|
from ultralytics.utils.checks import check_file, check_font, is_ascii
|
||||||
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
||||||
|
|
@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||||
|
|
||||||
# Set paths
|
# Set paths
|
||||||
data["path"] = path # download scripts
|
data["path"] = path # download scripts
|
||||||
for k in "train", "val", "test":
|
for k in "train", "val", "test", "minival":
|
||||||
if data.get(k): # prepend path
|
if data.get(k): # prepend path
|
||||||
if isinstance(data[k], str):
|
if isinstance(data[k], str):
|
||||||
x = (path / data[k]).resolve()
|
x = (path / data[k]).resolve()
|
||||||
|
|
@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
|
||||||
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
||||||
with open(path.parent / txt[i], "a") as f:
|
with open(path.parent / txt[i], "a") as f:
|
||||||
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
|
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_cache_file(path):
|
||||||
|
"""Load an Ultralytics *.cache dictionary from path."""
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
|
||||||
|
cache = np.load(str(path), allow_pickle=True).item() # load dict
|
||||||
|
gc.enable()
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
def save_dataset_cache_file(prefix, path, x, version):
|
||||||
|
"""Save an Ultralytics dataset *.cache dictionary x to path."""
|
||||||
|
x["version"] = version # add cache version
|
||||||
|
if is_dir_writeable(path.parent):
|
||||||
|
if path.exists():
|
||||||
|
path.unlink() # remove *.cache file if exists
|
||||||
|
np.save(str(path), x) # save cache for next time
|
||||||
|
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||||
|
LOGGER.info(f"{prefix}New cache created: {path}")
|
||||||
|
else:
|
||||||
|
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
||||||
|
|
|
||||||
|
|
@ -126,22 +126,7 @@ class BaseTrainer:
|
||||||
|
|
||||||
# Model and Dataset
|
# Model and Dataset
|
||||||
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||||
try:
|
self.trainset, self.testset = self.get_dataset()
|
||||||
if self.args.task == "classify":
|
|
||||||
self.data = check_cls_dataset(self.args.data)
|
|
||||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
|
||||||
"detect",
|
|
||||||
"segment",
|
|
||||||
"pose",
|
|
||||||
"obb",
|
|
||||||
):
|
|
||||||
self.data = check_det_dataset(self.args.data)
|
|
||||||
if "yaml_file" in self.data:
|
|
||||||
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
|
||||||
|
|
||||||
self.trainset, self.testset = self.get_dataset(self.data)
|
|
||||||
self.ema = None
|
self.ema = None
|
||||||
|
|
||||||
# Optimization utils init
|
# Optimization utils init
|
||||||
|
|
@ -509,13 +494,27 @@ class BaseTrainer:
|
||||||
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||||
|
|
||||||
@staticmethod
|
def get_dataset(self):
|
||||||
def get_dataset(data):
|
|
||||||
"""
|
"""
|
||||||
Get train, val path from data dict if it exists.
|
Get train, val path from data dict if it exists.
|
||||||
|
|
||||||
Returns None if data format is not recognized.
|
Returns None if data format is not recognized.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
if self.args.task == "classify":
|
||||||
|
data = check_cls_dataset(self.args.data)
|
||||||
|
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||||
|
"detect",
|
||||||
|
"segment",
|
||||||
|
"pose",
|
||||||
|
"obb",
|
||||||
|
):
|
||||||
|
data = check_det_dataset(self.args.data)
|
||||||
|
if "yaml_file" in data:
|
||||||
|
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||||
|
self.data = data
|
||||||
return data["train"], data.get("val") or data.get("test")
|
return data["train"], data.get("val") or data.get("test")
|
||||||
|
|
||||||
def setup_model(self):
|
def setup_model(self):
|
||||||
|
|
@ -666,8 +665,8 @@ class BaseTrainer:
|
||||||
if ckpt is None:
|
if ckpt is None:
|
||||||
return
|
return
|
||||||
best_fitness = 0.0
|
best_fitness = 0.0
|
||||||
start_epoch = ckpt["epoch"] + 1
|
start_epoch = ckpt.get("epoch", -1) + 1
|
||||||
if ckpt["optimizer"] is not None:
|
if ckpt.get("optimizer", None) is not None:
|
||||||
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
||||||
best_fitness = ckpt["best_fitness"]
|
best_fitness = ckpt["best_fitness"]
|
||||||
if self.ema and ckpt.get("ema"):
|
if self.ema and ckpt.get("ema"):
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class FastSAMPrompt:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ultralytics.utils.checks import check_requirements
|
from ultralytics.utils.checks import check_requirements
|
||||||
|
|
||||||
check_requirements("git+https://github.com/openai/CLIP.git")
|
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||||
import clip
|
import clip
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment
|
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
|
||||||
|
|
||||||
from .model import YOLO, YOLOWorld
|
from .model import YOLO, YOLOWorld
|
||||||
|
|
||||||
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
|
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
|
||||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||||
self.nt_per_class = None
|
self.nt_per_class = None
|
||||||
self.is_coco = False
|
self.is_coco = False
|
||||||
|
self.is_lvis = False
|
||||||
self.class_map = None
|
self.class_map = None
|
||||||
self.args.task = "detect"
|
self.args.task = "detect"
|
||||||
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||||
|
|
@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
|
||||||
"""Initialize evaluation metrics for YOLO."""
|
"""Initialize evaluation metrics for YOLO."""
|
||||||
val = self.data.get(self.args.split, "") # validation path
|
val = self.data.get(self.args.split, "") # validation path
|
||||||
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
|
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
|
||||||
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
|
||||||
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
|
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
|
||||||
|
self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training # run on final val if training COCO
|
||||||
self.names = model.names
|
self.names = model.names
|
||||||
self.nc = len(model.names)
|
self.nc = len(model.names)
|
||||||
self.metrics.names = self.names
|
self.metrics.names = self.names
|
||||||
|
|
@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
|
||||||
self.jdict.append(
|
self.jdict.append(
|
||||||
{
|
{
|
||||||
"image_id": image_id,
|
"image_id": image_id,
|
||||||
"category_id": self.class_map[int(p[5])],
|
"category_id": self.class_map[int(p[5])]
|
||||||
|
+ (1 if self.is_lvis else 0), # index starts from 1 if it's lvis
|
||||||
"bbox": [round(x, 3) for x in b],
|
"bbox": [round(x, 3) for x in b],
|
||||||
"score": round(p[4], 5),
|
"score": round(p[4], 5),
|
||||||
}
|
}
|
||||||
|
|
@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
|
||||||
|
|
||||||
def eval_json(self, stats):
|
def eval_json(self, stats):
|
||||||
"""Evaluates YOLO output in JSON format and returns performance statistics."""
|
"""Evaluates YOLO output in JSON format and returns performance statistics."""
|
||||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
|
||||||
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
|
|
||||||
pred_json = self.save_dir / "predictions.json" # predictions
|
pred_json = self.save_dir / "predictions.json" # predictions
|
||||||
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
|
anno_json = (
|
||||||
|
self.data["path"]
|
||||||
|
/ "annotations"
|
||||||
|
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
|
||||||
|
) # annotations
|
||||||
|
pkg = "pycocotools" if self.is_coco else "lvis"
|
||||||
|
LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
|
||||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||||
check_requirements("pycocotools>=2.0.6")
|
for x in pred_json, anno_json:
|
||||||
|
assert x.is_file(), f"{x} file not found"
|
||||||
|
check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
|
||||||
|
if self.is_coco:
|
||||||
from pycocotools.coco import COCO # noqa
|
from pycocotools.coco import COCO # noqa
|
||||||
from pycocotools.cocoeval import COCOeval # noqa
|
from pycocotools.cocoeval import COCOeval # noqa
|
||||||
|
|
||||||
for x in anno_json, pred_json:
|
|
||||||
assert x.is_file(), f"{x} file not found"
|
|
||||||
anno = COCO(str(anno_json)) # init annotations api
|
anno = COCO(str(anno_json)) # init annotations api
|
||||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||||
eval = COCOeval(anno, pred, "bbox")
|
eval = COCOeval(anno, pred, "bbox")
|
||||||
if self.is_coco:
|
else:
|
||||||
|
from lvis import LVIS, LVISEval
|
||||||
|
|
||||||
|
anno = LVIS(str(anno_json)) # init annotations api
|
||||||
|
pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||||
|
eval = LVISEval(anno, pred, "bbox")
|
||||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||||
eval.evaluate()
|
eval.evaluate()
|
||||||
eval.accumulate()
|
eval.accumulate()
|
||||||
eval.summarize()
|
eval.summarize()
|
||||||
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
|
if self.is_lvis:
|
||||||
|
eval.print_results() # explicitly call print_results
|
||||||
|
# update mAP50-95 and mAP50
|
||||||
|
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
|
||||||
|
eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f"pycocotools unable to run: {e}")
|
LOGGER.warning(f"{pkg} unable to run: {e}")
|
||||||
return stats
|
return stats
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,7 @@ class YOLOWorld(Model):
|
||||||
"model": WorldModel,
|
"model": WorldModel,
|
||||||
"validator": yolo.detect.DetectionValidator,
|
"validator": yolo.detect.DetectionValidator,
|
||||||
"predictor": yolo.detect.DetectionPredictor,
|
"predictor": yolo.detect.DetectionPredictor,
|
||||||
|
"trainer": yolo.world.WorldTrainer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
5
ultralytics/models/yolo/world/__init__.py
Normal file
5
ultralytics/models/yolo/world/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
from .train import WorldTrainer
|
||||||
|
|
||||||
|
__all__ = ["WorldTrainer"]
|
||||||
91
ultralytics/models/yolo/world/train.py
Normal file
91
ultralytics/models/yolo/world/train.py
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
from ultralytics.models import yolo
|
||||||
|
from ultralytics.nn.tasks import WorldModel
|
||||||
|
from ultralytics.utils import DEFAULT_CFG, RANK
|
||||||
|
from ultralytics.data import build_yolo_dataset
|
||||||
|
from ultralytics.utils.torch_utils import de_parallel
|
||||||
|
from ultralytics.utils.checks import check_requirements
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
try:
|
||||||
|
import clip
|
||||||
|
except ImportError:
|
||||||
|
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||||
|
import clip
|
||||||
|
|
||||||
|
|
||||||
|
def on_pretrain_routine_end(trainer):
|
||||||
|
"""Callback."""
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
# NOTE: for evaluation
|
||||||
|
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
|
||||||
|
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
|
||||||
|
device = next(trainer.model.parameters()).device
|
||||||
|
text_model, _ = clip.load("ViT-B/32", device=device)
|
||||||
|
for p in text_model.parameters():
|
||||||
|
p.requires_grad_(False)
|
||||||
|
trainer.text_model = text_model
|
||||||
|
|
||||||
|
|
||||||
|
class WorldTrainer(yolo.detect.DetectionTrainer):
|
||||||
|
"""
|
||||||
|
A class to fine-tune a world model on a close-set dataset.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.models.yolo.world import WorldModel
|
||||||
|
|
||||||
|
args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
|
||||||
|
trainer = WorldTrainer(overrides=args)
|
||||||
|
trainer.train()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||||
|
"""Initialize a WorldTrainer object with given arguments."""
|
||||||
|
if overrides is None:
|
||||||
|
overrides = {}
|
||||||
|
super().__init__(cfg, overrides, _callbacks)
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
|
"""Return WorldModel initialized with specified config and weights."""
|
||||||
|
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
|
||||||
|
# NOTE: Following the official config, nc hard-coded to 80 for now.
|
||||||
|
model = WorldModel(
|
||||||
|
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
|
||||||
|
ch=3,
|
||||||
|
nc=min(self.data["nc"], 80),
|
||||||
|
verbose=verbose and RANK == -1,
|
||||||
|
)
|
||||||
|
if weights:
|
||||||
|
model.load(weights)
|
||||||
|
self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
|
"""
|
||||||
|
Build YOLO Dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path (str): Path to the folder containing images.
|
||||||
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||||
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||||
|
"""
|
||||||
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
|
return build_yolo_dataset(
|
||||||
|
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
"""Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
|
||||||
|
batch = super().preprocess_batch(batch)
|
||||||
|
|
||||||
|
# NOTE: add text features
|
||||||
|
texts = list(itertools.chain(*batch["texts"]))
|
||||||
|
text_token = clip.tokenize(texts).to(batch["img"].device)
|
||||||
|
txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
|
||||||
|
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
||||||
|
batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
|
||||||
|
return batch
|
||||||
108
ultralytics/models/yolo/world/train_world.py
Normal file
108
ultralytics/models/yolo/world/train_world.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
|
||||||
|
from ultralytics.data.utils import check_det_dataset
|
||||||
|
from ultralytics.models.yolo.world import WorldTrainer
|
||||||
|
from ultralytics.utils.torch_utils import de_parallel
|
||||||
|
from ultralytics.utils import DEFAULT_CFG
|
||||||
|
|
||||||
|
|
||||||
|
class WorldTrainerFromScratch(WorldTrainer):
|
||||||
|
"""
|
||||||
|
A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
|
||||||
|
from ultralytics import YOLOWorld
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
train=dict(
|
||||||
|
yolo_data=["Objects365.yaml"],
|
||||||
|
grounding_data=[
|
||||||
|
dict(
|
||||||
|
img_path="../datasets/flickr30k/images",
|
||||||
|
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
img_path="../datasets/GQA/images",
|
||||||
|
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
val=dict(yolo_data=["lvis.yaml"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
model = YOLOWorld("yolov8s-worldv2.yaml")
|
||||||
|
model.train(data=data, trainer=WorldTrainerFromScratch)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||||
|
"""Initialize a WorldTrainer object with given arguments."""
|
||||||
|
if overrides is None:
|
||||||
|
overrides = {}
|
||||||
|
super().__init__(cfg, overrides, _callbacks)
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
|
"""
|
||||||
|
Build YOLO Dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path (List[str] | str): Path to the folder containing images.
|
||||||
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||||
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||||
|
"""
|
||||||
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
|
if mode == "train":
|
||||||
|
dataset = [
|
||||||
|
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
|
||||||
|
if isinstance(im_path, str)
|
||||||
|
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
|
||||||
|
for im_path in img_path
|
||||||
|
]
|
||||||
|
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
|
||||||
|
else:
|
||||||
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||||
|
|
||||||
|
def get_dataset(self):
|
||||||
|
"""
|
||||||
|
Get train, val path from data dict if it exists.
|
||||||
|
|
||||||
|
Returns None if data format is not recognized.
|
||||||
|
"""
|
||||||
|
final_data = dict()
|
||||||
|
data_yaml = self.args.data
|
||||||
|
assert data_yaml.get("train", False) # object365.yaml
|
||||||
|
assert data_yaml.get("val", False) # lvis.yaml
|
||||||
|
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
|
||||||
|
assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
|
||||||
|
val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
|
||||||
|
for d in data["val"]:
|
||||||
|
if d.get("minival") is None: # for lvis dataset
|
||||||
|
continue
|
||||||
|
d["minival"] = str(d["path"] / d["minival"])
|
||||||
|
for s in ["train", "val"]:
|
||||||
|
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
|
||||||
|
# save grounding data if there's one
|
||||||
|
grounding_data = data_yaml[s].get("grounding_data")
|
||||||
|
if grounding_data is None:
|
||||||
|
continue
|
||||||
|
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
|
||||||
|
for g in grounding_data:
|
||||||
|
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
|
||||||
|
final_data[s] += grounding_data
|
||||||
|
# NOTE: to make training work properly, set `nc` and `names`
|
||||||
|
final_data["nc"] = data["val"][0]["nc"]
|
||||||
|
final_data["names"] = data["val"][0]["names"]
|
||||||
|
self.data = final_data
|
||||||
|
return final_data["train"], final_data["val"][0]
|
||||||
|
|
||||||
|
def plot_training_labels(self):
|
||||||
|
"""DO NOT plot labels."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def final_eval(self):
|
||||||
|
"""Performs final evaluation and validation for object detection YOLO-World model."""
|
||||||
|
val = self.args.data["val"]["yolo_data"][0]
|
||||||
|
self.validator.args.data = val
|
||||||
|
self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
|
||||||
|
return super().final_eval()
|
||||||
|
|
@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initializes ContrastiveHead with specified region-text similarity parameters."""
|
"""Initializes ContrastiveHead with specified region-text similarity parameters."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.bias = nn.Parameter(torch.zeros([]))
|
# NOTE: use -10.0 to keep the init cls loss consistency with other losses
|
||||||
|
self.bias = nn.Parameter(torch.tensor([-10.0]))
|
||||||
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
|
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
|
||||||
|
|
||||||
def forward(self, x, w):
|
def forward(self, x, w):
|
||||||
|
|
@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
|
||||||
"""Initialize ContrastiveHead with region-text similarity parameters."""
|
"""Initialize ContrastiveHead with region-text similarity parameters."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = nn.BatchNorm2d(embed_dims)
|
self.norm = nn.BatchNorm2d(embed_dims)
|
||||||
self.bias = nn.Parameter(torch.zeros([]))
|
# NOTE: use -10.0 to keep the init cls loss consistency with other losses
|
||||||
|
self.bias = nn.Parameter(torch.tensor([-10.0]))
|
||||||
# use -1.0 is more stable
|
# use -1.0 is more stable
|
||||||
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
|
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -250,6 +250,15 @@ class WorldDetect(Detect):
|
||||||
y = torch.cat((dbox, cls.sigmoid()), 1)
|
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||||
return y if self.export else (y, x)
|
return y if self.export else (y, x)
|
||||||
|
|
||||||
|
def bias_init(self):
|
||||||
|
"""Initialize Detect() biases, WARNING: requires stride availability."""
|
||||||
|
m = self # self.model[-1] # Detect() module
|
||||||
|
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
||||||
|
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
||||||
|
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
|
||||||
|
a[-1].bias.data[:] = 1.0 # box
|
||||||
|
# b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
||||||
|
|
||||||
|
|
||||||
class RTDETRDecoder(nn.Module):
|
class RTDETRDecoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
|
||||||
self.clip_model = None # CLIP model placeholder
|
self.clip_model = None # CLIP model placeholder
|
||||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||||
|
|
||||||
def set_classes(self, text):
|
def set_classes(self, text, batch=80, cache_clip_model=True):
|
||||||
"""Perform a forward pass with optional profiling, visualization, and embedding extraction."""
|
"""Set classes in advance so that model could do offline-inference without clip model."""
|
||||||
try:
|
try:
|
||||||
import clip
|
import clip
|
||||||
except ImportError:
|
except ImportError:
|
||||||
check_requirements("git+https://github.com/openai/CLIP.git")
|
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||||
import clip
|
import clip
|
||||||
|
|
||||||
if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute
|
if (
|
||||||
|
not getattr(self, "clip_model", None) and cache_clip_model
|
||||||
|
): # for backwards compatibility of models lacking clip_model attribute
|
||||||
self.clip_model = clip.load("ViT-B/32")[0]
|
self.clip_model = clip.load("ViT-B/32")[0]
|
||||||
device = next(self.clip_model.parameters()).device
|
model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
|
||||||
|
device = next(model.parameters()).device
|
||||||
text_token = clip.tokenize(text).to(device)
|
text_token = clip.tokenize(text).to(device)
|
||||||
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
|
txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
|
||||||
|
txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
|
||||||
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
||||||
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
|
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
|
||||||
self.model[-1].nc = len(text)
|
self.model[-1].nc = len(text)
|
||||||
|
|
||||||
def init_criterion(self):
|
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
|
||||||
"""Initialize the loss criterion for the model."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
|
||||||
"""
|
"""
|
||||||
Perform a forward pass through the model.
|
Perform a forward pass through the model.
|
||||||
|
|
||||||
|
|
@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
|
||||||
x (torch.Tensor): The input tensor.
|
x (torch.Tensor): The input tensor.
|
||||||
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
|
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
|
||||||
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
|
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
|
||||||
|
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
|
||||||
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
|
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
|
||||||
embed (list, optional): A list of feature vectors/embeddings to return.
|
embed (list, optional): A list of feature vectors/embeddings to return.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): Model's output tensor.
|
(torch.Tensor): Model's output tensor.
|
||||||
"""
|
"""
|
||||||
txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
|
txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
|
||||||
if len(txt_feats) != len(x):
|
if len(txt_feats) != len(x):
|
||||||
txt_feats = txt_feats.repeat(len(x), 1, 1)
|
txt_feats = txt_feats.repeat(len(x), 1, 1)
|
||||||
ori_txt_feats = txt_feats.clone()
|
ori_txt_feats = txt_feats.clone()
|
||||||
|
|
@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
|
||||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def loss(self, batch, preds=None):
|
||||||
|
"""
|
||||||
|
Compute loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (dict): Batch to compute loss on.
|
||||||
|
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "criterion"):
|
||||||
|
self.criterion = self.init_criterion()
|
||||||
|
|
||||||
|
if preds is None:
|
||||||
|
preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
|
||||||
|
return self.criterion(preds, batch)
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.ModuleList):
|
class Ensemble(nn.ModuleList):
|
||||||
"""Ensemble of models."""
|
"""Ensemble of models."""
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,7 @@ class v8DetectionLoss:
|
||||||
self.hyp = h
|
self.hyp = h
|
||||||
self.stride = m.stride # model strides
|
self.stride = m.stride # model strides
|
||||||
self.nc = m.nc # number of classes
|
self.nc = m.nc # number of classes
|
||||||
self.no = m.no
|
self.no = m.nc + m.reg_max * 4
|
||||||
self.reg_max = m.reg_max
|
self.reg_max = m.reg_max
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue