+| Model                                                                                         | size<br><sup>(pixels) | mAP<sup>box<br>50 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |
+ |----------------------------------------------------------------------------------------------|-----------------------|-------------------|--------------------------------|-------------------------------------|--------------------|-------------------|
+ | [YOLOv8n-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-obb.pt) | 1024 | <++> | <++> | <++> | 3.2 | 23.3 |
+ | [YOLOv8s-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-obb.pt) | 1024 | <++> | <++> | <++> | 11.4 | 76.3 |
+ | [YOLOv8m-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-obb.pt) | 1024 | <++> | <++> | <++> | 26.4 | 208.6 |
+ | [YOLOv8l-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-obb.pt) | 1024 | <++> | <++> | <++> | 44.5 | 433.8 |
+ | [YOLOv8x-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-obb.pt) | 1024 | <++> | <++> | <++> | 69.5 | 676.7 |
+
## Usage Examples
This example provides simple YOLOv8 training and inference examples. For full documentation on these and other [modes](../modes/index.md) see the [Predict](../modes/predict.md), [Train](../modes/train.md), [Val](../modes/val.md) and [Export](../modes/export.md) docs pages.
-Note the below example is for YOLOv8 [Detect](../tasks/detect.md) models for object detection. For additional supported tasks see the [Segment](../tasks/segment.md), [Classify](../tasks/classify.md) and [Pose](../tasks/pose.md) docs.
+Note the below example is for YOLOv8 [Detect](../tasks/detect.md) models for object detection. For additional supported tasks see the [Segment](../tasks/segment.md), [Classify](../tasks/classify.md), [OBB](../tasks/obb.md) and [Pose](../tasks/pose.md) docs.
!!! Example
diff --git a/docs/en/reference/data/split_dota.md b/docs/en/reference/data/split_dota.md
new file mode 100644
index 00000000..1c2cf23d
--- /dev/null
+++ b/docs/en/reference/data/split_dota.md
@@ -0,0 +1,39 @@
+# Reference for `ultralytics/data/split_dota.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/split_dota.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/split_dota.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/data/split_dota.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.data.split_dota.bbox_iof
+
+
+
+## ::: ultralytics.data.split_dota.load_yolo_dota
+
+
+
+## ::: ultralytics.data.split_dota.get_windows
+
+
+
+## ::: ultralytics.data.split_dota.get_window_obj
+
+
+
+## ::: ultralytics.data.split_dota.crop_and_save
+
+
+
+## ::: ultralytics.data.split_dota.split_images_and_labels
+
+
+
+## ::: ultralytics.data.split_dota.split_trainval
+
+
+
+## ::: ultralytics.data.split_dota.split_test
+
+
diff --git a/docs/en/reference/engine/results.md b/docs/en/reference/engine/results.md
index b4b709e8..9b389ee3 100644
--- a/docs/en/reference/engine/results.md
+++ b/docs/en/reference/engine/results.md
@@ -34,3 +34,7 @@ keywords: Ultralytics, engine, results, base tensor, boxes, keypoints
## ::: ultralytics.engine.results.Probs
+
+## ::: ultralytics.engine.results.OBB
+
+
diff --git a/docs/en/reference/models/yolo/obb/predict.md b/docs/en/reference/models/yolo/obb/predict.md
new file mode 100644
index 00000000..8279a641
--- /dev/null
+++ b/docs/en/reference/models/yolo/obb/predict.md
@@ -0,0 +1,11 @@
+# Reference for `ultralytics/models/yolo/obb/predict.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/predict.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/predict.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/obb/predict.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.yolo.obb.predict.OBBPredictor
+
+
diff --git a/docs/en/reference/models/yolo/obb/train.md b/docs/en/reference/models/yolo/obb/train.md
new file mode 100644
index 00000000..3888aff8
--- /dev/null
+++ b/docs/en/reference/models/yolo/obb/train.md
@@ -0,0 +1,11 @@
+# Reference for `ultralytics/models/yolo/obb/train.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/train.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/train.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/obb/train.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.yolo.obb.train.OBBTrainer
+
+
diff --git a/docs/en/reference/models/yolo/obb/val.md b/docs/en/reference/models/yolo/obb/val.md
new file mode 100644
index 00000000..aeeccea1
--- /dev/null
+++ b/docs/en/reference/models/yolo/obb/val.md
@@ -0,0 +1,11 @@
+# Reference for `ultralytics/models/yolo/obb/val.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/val.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/val.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/obb/val.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.yolo.obb.val.OBBValidator
+
+
diff --git a/docs/en/reference/nn/modules/head.md b/docs/en/reference/nn/modules/head.md
index 40ffb25b..00a5ea6a 100644
--- a/docs/en/reference/nn/modules/head.md
+++ b/docs/en/reference/nn/modules/head.md
@@ -19,6 +19,10 @@ keywords: Ultralytics, YOLO, Detection, Pose, RTDETRDecoder, nn modules, guides
+## ::: ultralytics.nn.modules.head.OBB
+
+
+
## ::: ultralytics.nn.modules.head.Pose
diff --git a/docs/en/reference/nn/tasks.md b/docs/en/reference/nn/tasks.md
index aa84d881..d7908f56 100644
--- a/docs/en/reference/nn/tasks.md
+++ b/docs/en/reference/nn/tasks.md
@@ -19,6 +19,10 @@ keywords: Ultralytics, YOLO, nn tasks, DetectionModel, PoseModel, RTDETRDetectio
+## ::: ultralytics.nn.tasks.OBBModel
+
+
+
## ::: ultralytics.nn.tasks.SegmentationModel
diff --git a/docs/en/reference/utils/loss.md b/docs/en/reference/utils/loss.md
index 922ad6c3..94ae7918 100644
--- a/docs/en/reference/utils/loss.md
+++ b/docs/en/reference/utils/loss.md
@@ -23,6 +23,10 @@ keywords: Ultralytics, Loss functions, VarifocalLoss, BboxLoss, v8DetectionLoss,
+## ::: ultralytics.utils.loss.RotatedBboxLoss
+
+
+
## ::: ultralytics.utils.loss.KeypointLoss
@@ -42,3 +46,7 @@ keywords: Ultralytics, Loss functions, VarifocalLoss, BboxLoss, v8DetectionLoss,
## ::: ultralytics.utils.loss.v8ClassificationLoss
+
+## ::: ultralytics.utils.loss.v8OBBLoss
+
+
diff --git a/docs/en/reference/utils/metrics.md b/docs/en/reference/utils/metrics.md
index 3154a739..6851ae98 100644
--- a/docs/en/reference/utils/metrics.md
+++ b/docs/en/reference/utils/metrics.md
@@ -35,6 +35,10 @@ keywords: Ultralytics, YOLO, YOLOv3, YOLOv4, metrics, confusion matrix, detectio
+## ::: ultralytics.utils.metrics.OBBMetrics
+
+
+
## ::: ultralytics.utils.metrics.bbox_ioa
@@ -55,6 +59,18 @@ keywords: Ultralytics, YOLO, YOLOv3, YOLOv4, metrics, confusion matrix, detectio
+## ::: ultralytics.utils.metrics._get_covariance_matrix
+
+
+
+## ::: ultralytics.utils.metrics.probiou
+
+
+
+## ::: ultralytics.utils.metrics.batch_probiou
+
+
+
## ::: ultralytics.utils.metrics.smooth_BCE
diff --git a/docs/en/reference/utils/ops.md b/docs/en/reference/utils/ops.md
index c366fd94..43a3e0ef 100644
--- a/docs/en/reference/utils/ops.md
+++ b/docs/en/reference/utils/ops.md
@@ -27,6 +27,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
+## ::: ultralytics.utils.ops.nms_rotated
+
+
+
## ::: ultralytics.utils.ops.non_max_suppression
diff --git a/docs/en/reference/utils/plotting.md b/docs/en/reference/utils/plotting.md
index b465af43..bf00d2fb 100644
--- a/docs/en/reference/utils/plotting.md
+++ b/docs/en/reference/utils/plotting.md
@@ -47,6 +47,10 @@ keywords: Ultralytics, plotting, utils, color annotation, label plotting, image
+## ::: ultralytics.utils.plotting.output_to_rotated_target
+
+
+
## ::: ultralytics.utils.plotting.feature_visualization
diff --git a/docs/en/reference/utils/tal.md b/docs/en/reference/utils/tal.md
index 6519d35d..9f832aad 100644
--- a/docs/en/reference/utils/tal.md
+++ b/docs/en/reference/utils/tal.md
@@ -15,11 +15,7 @@ keywords: Ultralytics, task aligned assigner, select highest overlaps, make anch
-## ::: ultralytics.utils.tal.select_candidates_in_gts
-
-
-
-## ::: ultralytics.utils.tal.select_highest_overlaps
+## ::: ultralytics.utils.tal.RotatedTaskAlignedAssigner
@@ -34,3 +30,7 @@ keywords: Ultralytics, task aligned assigner, select highest overlaps, make anch
## ::: ultralytics.utils.tal.bbox2dist
+
+## ::: ultralytics.utils.tal.dist2rbox
+
+
diff --git a/docs/en/tasks/index.md b/docs/en/tasks/index.md
index a5de2754..77521170 100644
--- a/docs/en/tasks/index.md
+++ b/docs/en/tasks/index.md
@@ -1,7 +1,7 @@
---
comments: true
description: Learn about the cornerstone computer vision tasks YOLOv8 can perform including detection, segmentation, classification, and pose estimation. Understand their uses in your AI projects.
-keywords: Ultralytics, YOLOv8, Detection, Segmentation, Classification, Pose Estimation, AI Framework, Computer Vision Tasks
+keywords: Ultralytics, YOLOv8, Detection, Segmentation, Classification, Pose Estimation, Oriented Object Detection, AI Framework, Computer Vision Tasks
---
# Ultralytics YOLOv8 Tasks
@@ -9,7 +9,7 @@ keywords: Ultralytics, YOLOv8, Detection, Segmentation, Classification, Pose Est
-YOLOv8 is an AI framework that supports multiple computer vision **tasks**. The framework can be used to perform [detection](detect.md), [segmentation](segment.md), [classification](classify.md), and [pose](pose.md) estimation. Each of these tasks has a different objective and use case.
+YOLOv8 is an AI framework that supports multiple computer vision **tasks**. The framework can be used to perform [detection](detect.md), [segmentation](segment.md), [OBB](obb.md), [classification](classify.md), and [pose](pose.md) estimation. Each of these tasks has a different objective and use case.
@@ -19,7 +19,7 @@ YOLOv8 is an AI framework that supports multiple computer vision **tasks**. The
allowfullscreen>
- Watch: Explore Ultralytics YOLO Tasks: Object Detection, Segmentation, Tracking, and Pose Estimation.
+ Watch: Explore Ultralytics YOLO Tasks: Object Detection, Segmentation, OBB, Tracking, and Pose Estimation.
## [Detection](detect.md)
@@ -46,6 +46,12 @@ Pose/keypoint detection is a task that involves detecting specific points in an
[Pose Examples](pose.md){ .md-button }
+## [OBB](obb.md)
+
+Oriented object detection goes a step further than regular object detection by introducing an extra angle to locate objects more accurately in an image. YOLOv8 can detect rotated objects in an image or video frame with high accuracy and speed.
+
+[OBB Examples](obb.md){ .md-button }
+
## Conclusion
-YOLOv8 supports multiple tasks, including detection, segmentation, classification, and keypoints detection. Each of these tasks has different objectives and use cases. By understanding the differences between these tasks, you can choose the appropriate task for your computer vision application.
+YOLOv8 supports multiple tasks, including detection, segmentation, classification, oriented object detection, and keypoint detection. Each of these tasks has different objectives and use cases. By understanding the differences between these tasks, you can choose the appropriate task for your computer vision application.
diff --git a/docs/en/tasks/obb.md b/docs/en/tasks/obb.md
new file mode 100644
index 00000000..cc6ed72c
--- /dev/null
+++ b/docs/en/tasks/obb.md
@@ -0,0 +1,181 @@
+---
+comments: true
+description: Learn how to use oriented object detection models with Ultralytics YOLO. Instructions on training, validation, image prediction, and model export.
+keywords: yolov8, oriented object detection, Ultralytics, DOTA dataset, rotated object detection, object detection, model training, model validation, image prediction, model export
+---
+
+# Oriented Object Detection
+
+
+
+Oriented object detection goes a step further than object detection by introducing an extra angle to locate objects more accurately in an image.
+
+The output of an oriented object detector is a set of rotated bounding boxes that exactly enclose the objects in the image, along with class labels and confidence scores for each box. Object detection is a good choice when you need to identify objects of interest in a scene, but don't need to know exactly where the object is or its exact shape.
+
+
+
+
+
+!!! Tip "Tip"
+
+    YOLOv8 OBB models use the `-obb` suffix, i.e. `yolov8n-obb.pt`, and are pretrained on [DOTAv1](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/DOTAv1.yaml).
+
+## [Models](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/models/v8)
+
+YOLOv8 pretrained OBB models are shown here, pretrained on the [DOTAv1](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/DOTAv1.yaml) dataset.
+
+[Models](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/models) download automatically from the latest Ultralytics [release](https://github.com/ultralytics/assets/releases) on first use.
+
+| Model                                                                                         | size<br><sup>(pixels) | mAP<sup>box<br>50 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |
+|----------------------------------------------------------------------------------------------|-----------------------|-------------------|--------------------------------|-------------------------------------|--------------------|-------------------|
+| [YOLOv8n-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-obb.pt) | 1024 | <++> | <++> | <++> | 3.2 | 23.3 |
+| [YOLOv8s-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-obb.pt) | 1024 | <++> | <++> | <++> | 11.4 | 76.3 |
+| [YOLOv8m-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-obb.pt) | 1024 | <++> | <++> | <++> | 26.4 | 208.6 |
+| [YOLOv8l-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-obb.pt) | 1024 | <++> | <++> | <++> | 44.5 | 433.8 |
+| [YOLOv8x-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-obb.pt) | 1024 | <++> | <++> | <++> | 69.5 | 676.7 |
+
+
+- **mAP** values are for single-model single-scale on the [DOTAv1 test](https://captain-whu.github.io/DOTA/index.html) dataset. <br>Reproduce by `yolo val obb data=DOTAv1.yaml device=0`
+- **Speed** averaged over DOTAv1 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. <br>Reproduce by `yolo val obb data=DOTAv1.yaml batch=1 device=0|cpu`
+
+## Train
+
+
+Train YOLOv8n-obb on the dota128-obb.yaml dataset for 100 epochs at image size 640. For a full list of available arguments see the [Configuration](../usage/cfg.md) page.
+
+!!! Example
+
+ === "Python"
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load a model
+ model = YOLO('yolov8n-obb.yaml') # build a new model from YAML
+ model = YOLO('yolov8n-obb.pt') # load a pretrained model (recommended for training)
+ model = YOLO('yolov8n-obb.yaml').load('yolov8n.pt') # build from YAML and transfer weights
+
+ # Train the model
+ results = model.train(data='dota128-obb.yaml', epochs=100, imgsz=640)
+ ```
+ === "CLI"
+
+ ```bash
+ # Build a new model from YAML and start training from scratch
+ yolo obb train data=dota128-obb.yaml model=yolov8n-obb.yaml epochs=100 imgsz=640
+
+ # Start training from a pretrained *.pt model
+ yolo obb train data=dota128-obb.yaml model=yolov8n-obb.pt epochs=100 imgsz=640
+
+ # Build a new model from YAML, transfer pretrained weights to it and start training
+ yolo obb train data=dota128-obb.yaml model=yolov8n-obb.yaml pretrained=yolov8n-obb.pt epochs=100 imgsz=640
+ ```
+
+### Dataset format
+
+The YOLO OBB dataset format is described in detail in the [Dataset Guide](../datasets/obb/index.md).
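+
+DOTA images are typically very large, so the `ultralytics.data.split_dota` utilities can be used to tile the original images and labels into training-sized crops. Below is a minimal sketch, assuming a DOTA-style layout with `images/{train,val,test}` and `labels/{train,val}` folders under `data_root`; the paths are placeholders, and the defaults (1024-pixel crops with a 200-pixel overlap gap) match the function signatures.
+
+```python
+from ultralytics.data.split_dota import split_test, split_trainval
+
+# Split the train and val sets into 1024x1024 crops with a 200-pixel overlap gap
+split_trainval(data_root='path/to/DOTAv1', save_dir='path/to/DOTAv1-split', crop_size=1024, gap=200)
+
+# Split the (unlabeled) test set the same way
+split_test(data_root='path/to/DOTAv1', save_dir='path/to/DOTAv1-split', crop_size=1024, gap=200)
+```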
+
+## Val
+
+Validate trained YOLOv8n-obb model accuracy on the dota128-obb dataset. No arguments need to be passed as the `model`
+retains its training `data` and arguments as model attributes.
+
+!!! Example
+
+ === "Python"
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load a model
+ model = YOLO('yolov8n-obb.pt') # load an official model
+ model = YOLO('path/to/best.pt') # load a custom model
+
+ # Validate the model
+ metrics = model.val() # no arguments needed, dataset and settings remembered
+ metrics.box.map # map50-95(B)
+ metrics.box.map50 # map50(B)
+ metrics.box.map75 # map75(B)
+        metrics.box.maps  # a list containing map50-95(B) for each category
+ ```
+ === "CLI"
+
+ ```bash
+ yolo obb val model=yolov8n-obb.pt # val official model
+ yolo obb val model=path/to/best.pt # val custom model
+ ```
+
+## Predict
+
+Use a trained YOLOv8n-obb model to run predictions on images.
+
+!!! Example
+
+ === "Python"
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load a model
+ model = YOLO('yolov8n-obb.pt') # load an official model
+ model = YOLO('path/to/best.pt') # load a custom model
+
+ # Predict with the model
+ results = model('https://ultralytics.com/images/bus.jpg') # predict on an image
+ ```
+ === "CLI"
+
+ ```bash
+ yolo obb predict model=yolov8n-obb.pt source='https://ultralytics.com/images/bus.jpg' # predict with official model
+ yolo obb predict model=path/to/best.pt source='https://ultralytics.com/images/bus.jpg' # predict with custom model
+ ```
+
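+OBB predictions are returned through the `obb` attribute of each `Results` object instead of `boxes`. Below is a minimal sketch of reading the rotated boxes; the attribute names follow the `OBB` results class (`obb` is `None` for non-OBB models).
+
+```python
+from ultralytics import YOLO
+
+model = YOLO('yolov8n-obb.pt')  # load an official OBB model
+results = model('https://ultralytics.com/images/bus.jpg')  # run prediction
+
+for r in results:
+    obb = r.obb  # OBB results object, or None for non-OBB models
+    if obb is not None:
+        print(obb.xywhr)  # rotated boxes as (x_center, y_center, width, height, rotation)
+        print(obb.xyxyxyxy)  # four-corner polygon form, shape (N, 4, 2)
+        print(obb.conf)  # confidence scores
+        print(obb.cls)  # class indices
+```
+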
+See full `predict` mode details in the [Predict](https://docs.ultralytics.com/modes/predict/) page.
+
+## Export
+
+Export a YOLOv8n-obb model to a different format like ONNX, CoreML, etc.
+
+!!! Example
+
+ === "Python"
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load a model
+ model = YOLO('yolov8n-obb.pt') # load an official model
+ model = YOLO('path/to/best.pt') # load a custom trained model
+
+ # Export the model
+ model.export(format='onnx')
+ ```
+ === "CLI"
+
+ ```bash
+ yolo export model=yolov8n-obb.pt format=onnx # export official model
+ yolo export model=path/to/best.pt format=onnx # export custom trained model
+ ```
+
+Available YOLOv8-obb export formats are in the table below. You can predict or validate directly on exported models, i.e. `yolo predict model=yolov8n-obb.onnx`. Usage examples are shown for your model after export completes.
+
+| Format | `format` Argument | Model | Metadata | Arguments |
+|--------------------------------------------------------------------|-------------------|-------------------------------|----------|-----------------------------------------------------|
+| [PyTorch](https://pytorch.org/) | - | `yolov8n-obb.pt` | ✅ | - |
+| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n-obb.torchscript` | ✅ | `imgsz`, `optimize` |
+| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n-obb.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset` |
+| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `openvino` | `yolov8n-obb_openvino_model/` | ✅ | `imgsz`, `half` |
+| [TensorRT](https://developer.nvidia.com/tensorrt) | `engine` | `yolov8n-obb.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace` |
+| [CoreML](https://github.com/apple/coremltools) | `coreml` | `yolov8n-obb.mlpackage` | ✅ | `imgsz`, `half`, `int8`, `nms` |
+| [TF SavedModel](https://www.tensorflow.org/guide/saved_model) | `saved_model` | `yolov8n-obb_saved_model/` | ✅ | `imgsz`, `keras` |
+| [TF GraphDef](https://www.tensorflow.org/api_docs/python/tf/Graph) | `pb` | `yolov8n-obb.pb` | ❌ | `imgsz` |
+| [TF Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n-obb.tflite` | ✅ | `imgsz`, `half`, `int8` |
+| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-obb_edgetpu.tflite` | ✅ | `imgsz` |
+| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-obb_web_model/` | ✅ | `imgsz`, `half`, `int8` |
+| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-obb_paddle_model/` | ✅ | `imgsz` |
+| [ncnn](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-obb_ncnn_model/` | ✅ | `imgsz`, `half` |
+
+See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 61b06003..9401924c 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -170,6 +170,7 @@ nav:
- Segment: tasks/segment.md
- Classify: tasks/classify.md
- Pose: tasks/pose.md
+ - Obb: tasks/obb.md
- Guides:
- guides/index.md
- Models:
@@ -203,6 +204,7 @@ nav:
- Segment: tasks/segment.md
- Classify: tasks/classify.md
- Pose: tasks/pose.md
+ - Obb: tasks/obb.md
- Models:
- models/index.md
- YOLOv3: models/yolov3.md
@@ -335,131 +337,137 @@ nav:
- 'iOS': hub/app/ios.md
- 'Android': hub/app/android.md
- Inference API: hub/inference_api.md
+
- Reference:
- - cfg:
- - __init__: reference/cfg/__init__.md
- - data:
- - annotator: reference/data/annotator.md
- - augment: reference/data/augment.md
- - base: reference/data/base.md
- - build: reference/data/build.md
- - converter: reference/data/converter.md
- - dataset: reference/data/dataset.md
- - loaders: reference/data/loaders.md
- - utils: reference/data/utils.md
- - engine:
- - exporter: reference/engine/exporter.md
- - model: reference/engine/model.md
- - predictor: reference/engine/predictor.md
- - results: reference/engine/results.md
- - trainer: reference/engine/trainer.md
- - tuner: reference/engine/tuner.md
- - validator: reference/engine/validator.md
- - hub:
- - __init__: reference/hub/__init__.md
- - auth: reference/hub/auth.md
- - session: reference/hub/session.md
- - utils: reference/hub/utils.md
- - models:
- - fastsam:
- - model: reference/models/fastsam/model.md
- - predict: reference/models/fastsam/predict.md
- - prompt: reference/models/fastsam/prompt.md
- - utils: reference/models/fastsam/utils.md
- - val: reference/models/fastsam/val.md
- - nas:
- - model: reference/models/nas/model.md
- - predict: reference/models/nas/predict.md
- - val: reference/models/nas/val.md
- - rtdetr:
- - model: reference/models/rtdetr/model.md
- - predict: reference/models/rtdetr/predict.md
- - train: reference/models/rtdetr/train.md
- - val: reference/models/rtdetr/val.md
- - sam:
- - amg: reference/models/sam/amg.md
- - build: reference/models/sam/build.md
- - model: reference/models/sam/model.md
- - modules:
- - decoders: reference/models/sam/modules/decoders.md
- - encoders: reference/models/sam/modules/encoders.md
- - sam: reference/models/sam/modules/sam.md
- - tiny_encoder: reference/models/sam/modules/tiny_encoder.md
- - transformer: reference/models/sam/modules/transformer.md
- - predict: reference/models/sam/predict.md
- - utils:
- - loss: reference/models/utils/loss.md
- - ops: reference/models/utils/ops.md
- - yolo:
- - classify:
- - predict: reference/models/yolo/classify/predict.md
- - train: reference/models/yolo/classify/train.md
- - val: reference/models/yolo/classify/val.md
- - detect:
- - predict: reference/models/yolo/detect/predict.md
- - train: reference/models/yolo/detect/train.md
- - val: reference/models/yolo/detect/val.md
- - model: reference/models/yolo/model.md
- - pose:
- - predict: reference/models/yolo/pose/predict.md
- - train: reference/models/yolo/pose/train.md
- - val: reference/models/yolo/pose/val.md
- - segment:
- - predict: reference/models/yolo/segment/predict.md
- - train: reference/models/yolo/segment/train.md
- - val: reference/models/yolo/segment/val.md
- - nn:
- - autobackend: reference/nn/autobackend.md
- - modules:
- - block: reference/nn/modules/block.md
- - conv: reference/nn/modules/conv.md
- - head: reference/nn/modules/head.md
- - transformer: reference/nn/modules/transformer.md
- - utils: reference/nn/modules/utils.md
- - tasks: reference/nn/tasks.md
- - solutions:
- - ai_gym: reference/solutions/ai_gym.md
- - heatmap: reference/solutions/heatmap.md
- - object_counter: reference/solutions/object_counter.md
- - trackers:
- - basetrack: reference/trackers/basetrack.md
- - bot_sort: reference/trackers/bot_sort.md
- - byte_tracker: reference/trackers/byte_tracker.md
- - track: reference/trackers/track.md
- - utils:
- - gmc: reference/trackers/utils/gmc.md
- - kalman_filter: reference/trackers/utils/kalman_filter.md
- - matching: reference/trackers/utils/matching.md
+ - cfg:
+ - __init__: reference/cfg/__init__.md
+ - data:
+ - annotator: reference/data/annotator.md
+ - augment: reference/data/augment.md
+ - base: reference/data/base.md
+ - build: reference/data/build.md
+ - converter: reference/data/converter.md
+ - dataset: reference/data/dataset.md
+ - loaders: reference/data/loaders.md
+ - split_dota: reference/data/split_dota.md
+ - utils: reference/data/utils.md
+ - engine:
+ - exporter: reference/engine/exporter.md
+ - model: reference/engine/model.md
+ - predictor: reference/engine/predictor.md
+ - results: reference/engine/results.md
+ - trainer: reference/engine/trainer.md
+ - tuner: reference/engine/tuner.md
+ - validator: reference/engine/validator.md
+ - hub:
+ - __init__: reference/hub/__init__.md
+ - auth: reference/hub/auth.md
+ - session: reference/hub/session.md
+ - utils: reference/hub/utils.md
+ - models:
+ - fastsam:
+ - model: reference/models/fastsam/model.md
+ - predict: reference/models/fastsam/predict.md
+ - prompt: reference/models/fastsam/prompt.md
+ - utils: reference/models/fastsam/utils.md
+ - val: reference/models/fastsam/val.md
+ - nas:
+ - model: reference/models/nas/model.md
+ - predict: reference/models/nas/predict.md
+ - val: reference/models/nas/val.md
+ - rtdetr:
+ - model: reference/models/rtdetr/model.md
+ - predict: reference/models/rtdetr/predict.md
+ - train: reference/models/rtdetr/train.md
+ - val: reference/models/rtdetr/val.md
+ - sam:
+ - amg: reference/models/sam/amg.md
+ - build: reference/models/sam/build.md
+ - model: reference/models/sam/model.md
+ - modules:
+ - decoders: reference/models/sam/modules/decoders.md
+ - encoders: reference/models/sam/modules/encoders.md
+ - sam: reference/models/sam/modules/sam.md
+ - tiny_encoder: reference/models/sam/modules/tiny_encoder.md
+ - transformer: reference/models/sam/modules/transformer.md
+ - predict: reference/models/sam/predict.md
- utils:
- - __init__: reference/utils/__init__.md
- - autobatch: reference/utils/autobatch.md
- - benchmarks: reference/utils/benchmarks.md
- - callbacks:
- - base: reference/utils/callbacks/base.md
- - clearml: reference/utils/callbacks/clearml.md
- - comet: reference/utils/callbacks/comet.md
- - dvc: reference/utils/callbacks/dvc.md
- - hub: reference/utils/callbacks/hub.md
- - mlflow: reference/utils/callbacks/mlflow.md
- - neptune: reference/utils/callbacks/neptune.md
- - raytune: reference/utils/callbacks/raytune.md
- - tensorboard: reference/utils/callbacks/tensorboard.md
- - wb: reference/utils/callbacks/wb.md
- - checks: reference/utils/checks.md
- - dist: reference/utils/dist.md
- - downloads: reference/utils/downloads.md
- - errors: reference/utils/errors.md
- - files: reference/utils/files.md
- - instance: reference/utils/instance.md
- - loss: reference/utils/loss.md
- - metrics: reference/utils/metrics.md
- - ops: reference/utils/ops.md
- - patches: reference/utils/patches.md
- - plotting: reference/utils/plotting.md
- - tal: reference/utils/tal.md
- - torch_utils: reference/utils/torch_utils.md
- - triton: reference/utils/triton.md
- - tuner: reference/utils/tuner.md
+ - loss: reference/models/utils/loss.md
+ - ops: reference/models/utils/ops.md
+ - yolo:
+ - classify:
+ - predict: reference/models/yolo/classify/predict.md
+ - train: reference/models/yolo/classify/train.md
+ - val: reference/models/yolo/classify/val.md
+ - detect:
+ - predict: reference/models/yolo/detect/predict.md
+ - train: reference/models/yolo/detect/train.md
+ - val: reference/models/yolo/detect/val.md
+ - model: reference/models/yolo/model.md
+ - obb:
+ - predict: reference/models/yolo/obb/predict.md
+ - train: reference/models/yolo/obb/train.md
+ - val: reference/models/yolo/obb/val.md
+ - pose:
+ - predict: reference/models/yolo/pose/predict.md
+ - train: reference/models/yolo/pose/train.md
+ - val: reference/models/yolo/pose/val.md
+ - segment:
+ - predict: reference/models/yolo/segment/predict.md
+ - train: reference/models/yolo/segment/train.md
+ - val: reference/models/yolo/segment/val.md
+ - nn:
+ - autobackend: reference/nn/autobackend.md
+ - modules:
+ - block: reference/nn/modules/block.md
+ - conv: reference/nn/modules/conv.md
+ - head: reference/nn/modules/head.md
+ - transformer: reference/nn/modules/transformer.md
+ - utils: reference/nn/modules/utils.md
+ - tasks: reference/nn/tasks.md
+ - solutions:
+ - ai_gym: reference/solutions/ai_gym.md
+ - heatmap: reference/solutions/heatmap.md
+ - object_counter: reference/solutions/object_counter.md
+ - trackers:
+ - basetrack: reference/trackers/basetrack.md
+ - bot_sort: reference/trackers/bot_sort.md
+ - byte_tracker: reference/trackers/byte_tracker.md
+ - track: reference/trackers/track.md
+ - utils:
+ - gmc: reference/trackers/utils/gmc.md
+ - kalman_filter: reference/trackers/utils/kalman_filter.md
+ - matching: reference/trackers/utils/matching.md
+ - utils:
+ - __init__: reference/utils/__init__.md
+ - autobatch: reference/utils/autobatch.md
+ - benchmarks: reference/utils/benchmarks.md
+ - callbacks:
+ - base: reference/utils/callbacks/base.md
+ - clearml: reference/utils/callbacks/clearml.md
+ - comet: reference/utils/callbacks/comet.md
+ - dvc: reference/utils/callbacks/dvc.md
+ - hub: reference/utils/callbacks/hub.md
+ - mlflow: reference/utils/callbacks/mlflow.md
+ - neptune: reference/utils/callbacks/neptune.md
+ - raytune: reference/utils/callbacks/raytune.md
+ - tensorboard: reference/utils/callbacks/tensorboard.md
+ - wb: reference/utils/callbacks/wb.md
+ - checks: reference/utils/checks.md
+ - dist: reference/utils/dist.md
+ - downloads: reference/utils/downloads.md
+ - errors: reference/utils/errors.md
+ - files: reference/utils/files.md
+ - instance: reference/utils/instance.md
+ - loss: reference/utils/loss.md
+ - metrics: reference/utils/metrics.md
+ - ops: reference/utils/ops.md
+ - patches: reference/utils/patches.md
+ - plotting: reference/utils/plotting.md
+ - tal: reference/utils/tal.md
+ - torch_utils: reference/utils/torch_utils.md
+ - triton: reference/utils/triton.md
+ - tuner: reference/utils/tuner.md
- Help:
- Help: help/index.md
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 5063086a..56709178 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.234'
+__version__ = '8.0.235'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index d892b519..04984e27 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -13,18 +13,25 @@ from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CF
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
-TASKS = 'detect', 'segment', 'classify', 'pose'
-TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet10', 'pose': 'coco8-pose.yaml'}
+TASKS = 'detect', 'segment', 'classify', 'pose', 'obb'
+TASK2DATA = {
+ 'detect': 'coco8.yaml',
+ 'segment': 'coco8-seg.yaml',
+ 'classify': 'imagenet10',
+ 'pose': 'coco8-pose.yaml',
+ 'obb': 'dota8-obb.yaml'} # not implemented yet
TASK2MODEL = {
'detect': 'yolov8n.pt',
'segment': 'yolov8n-seg.pt',
'classify': 'yolov8n-cls.pt',
- 'pose': 'yolov8n-pose.pt'}
+ 'pose': 'yolov8n-pose.pt',
+ 'obb': 'yolov8n-obb.pt'}
TASK2METRIC = {
'detect': 'metrics/mAP50-95(B)',
'segment': 'metrics/mAP50-95(M)',
'classify': 'metrics/accuracy_top1',
- 'pose': 'metrics/mAP50-95(P)'}
+ 'pose': 'metrics/mAP50-95(P)',
+ 'obb': 'metrics/mAP50-95(OBB)'}
CLI_HELP_MSG = \
f"""
@@ -72,7 +79,7 @@ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic'
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
'save_frames', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks',
- 'show_boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
+ 'show_boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile', 'multi_scale')
def cfg2dict(cfg):
diff --git a/ultralytics/cfg/datasets/DOTAv2.yaml b/ultralytics/cfg/datasets/DOTAv1.5.yaml
similarity index 84%
rename from ultralytics/cfg/datasets/DOTAv2.yaml
rename to ultralytics/cfg/datasets/DOTAv1.5.yaml
index c663bdd5..39f89a00 100644
--- a/ultralytics/cfg/datasets/DOTAv2.yaml
+++ b/ultralytics/cfg/datasets/DOTAv1.5.yaml
@@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-# DOTA 2.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University
+# DOTA 1.5 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University
# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv2.yaml
# parent
# ├── ultralytics
@@ -7,12 +7,12 @@
# └── dota2 ← downloads here (2GB)
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
-path: ../datasets/DOTAv2 # dataset root dir
+path: ../datasets/DOTAv1.5 # dataset root dir
train: images/train # train images (relative to 'path') 1411 images
val: images/val # val images (relative to 'path') 458 images
test: images/test # test images (optional) 937 images
-# Classes for DOTA 2.0
+# Classes for DOTA 1.5
names:
0: plane
1: ship
@@ -30,8 +30,6 @@ names:
13: soccer ball field
14: swimming pool
15: container crane
- 16: airport
- 17: helipad
# Download script/URL (optional)
-download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv2.zip
+download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv1.5.zip
diff --git a/ultralytics/cfg/datasets/DOTAv1.yaml b/ultralytics/cfg/datasets/DOTAv1.yaml
new file mode 100644
index 00000000..bdec4925
--- /dev/null
+++ b/ultralytics/cfg/datasets/DOTAv1.yaml
@@ -0,0 +1,34 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# DOTA 1.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University
+# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.yaml
+# parent
+# ├── ultralytics
+# └── datasets
+# └── dota2 ← downloads here (2GB)
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/DOTAv1 # dataset root dir
+train: images/train # train images (relative to 'path') 1411 images
+val: images/val # val images (relative to 'path') 458 images
+test: images/test # test images (optional) 937 images
+
+# Classes for DOTA 1.0
+names:
+ 0: plane
+ 1: ship
+ 2: storage tank
+ 3: baseball diamond
+ 4: tennis court
+ 5: basketball court
+ 6: ground track field
+ 7: harbor
+ 8: bridge
+ 9: large vehicle
+ 10: small vehicle
+ 11: helicopter
+ 12: roundabout
+ 13: soccer ball field
+ 14: swimming pool
+
+# Download script/URL (optional)
+download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv1.zip
diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml
index ead9735e..a05edb34 100644
--- a/ultralytics/cfg/default.yaml
+++ b/ultralytics/cfg/default.yaml
@@ -34,6 +34,7 @@ amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, Fal
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training
+multi_scale: False # (bool) Whether to use multi-scale during training
# Segmentation
overlap_mask: True # (bool) masks should overlap during training (segment train only)
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
diff --git a/ultralytics/cfg/models/v8/yolov8-obb.yaml b/ultralytics/cfg/models/v8/yolov8-obb.yaml
new file mode 100644
index 00000000..049b9273
--- /dev/null
+++ b/ultralytics/cfg/models/v8/yolov8-obb.yaml
@@ -0,0 +1,46 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 Oriented Bounding Boxes (OBB) model with P3-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/obb
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+ # [depth, width, max_channels]
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+ - [-1, 3, C2f, [512]] # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, OBB, [nc, 1]] # OBB(P3, P4, P5)
diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py
index 73728d73..a1cd1c3c 100644
--- a/ultralytics/data/augment.py
+++ b/ultralytics/data/augment.py
@@ -13,7 +13,7 @@ from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
-from ultralytics.utils.ops import segment2box
+from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
from .utils import polygons2masks, polygons2masks_overlap
@@ -485,6 +485,8 @@ class RandomPerspective:
xy = xy[:, :2] / xy[:, 2:3]
segments = xy.reshape(n, -1, 2)
bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
+ segments[..., 0] = segments[..., 0].clip(bboxes[:, 0:1], bboxes[:, 2:3])
+ segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4])
return bboxes, segments
def apply_keypoints(self, keypoints, M):
@@ -891,6 +893,7 @@ class Format:
normalize=True,
return_mask=False,
return_keypoint=False,
+ return_obb=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True):
@@ -899,6 +902,7 @@ class Format:
self.normalize = normalize
self.return_mask = return_mask # set False when training detection only
self.return_keypoint = return_keypoint
+ self.return_obb = return_obb
self.mask_ratio = mask_ratio
self.mask_overlap = mask_overlap
self.batch_idx = batch_idx # keep the batch indexes
@@ -928,6 +932,9 @@ class Format:
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels['keypoints'] = torch.from_numpy(instances.keypoints)
+ if self.return_obb:
+ labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(
+ instances.segments) else torch.zeros((0, 5))
# Then we can use collate_fn
if self.batch_idx:
labels['batch_idx'] = torch.zeros(nl)
diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py
index 1d961aee..99fed0c4 100644
--- a/ultralytics/data/build.py
+++ b/ultralytics/data/build.py
@@ -89,8 +89,7 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, str
stride=int(stride),
pad=0.0 if mode == 'train' else 0.5,
prefix=colorstr(f'{mode}: '),
- use_segments=cfg.task == 'segment',
- use_keypoints=cfg.task == 'pose',
+ task=cfg.task,
classes=cfg.classes,
data=data,
fraction=cfg.fraction if mode == 'train' else 1.0)
diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py
index 538a1ff2..ad0ba56e 100644
--- a/ultralytics/data/dataset.py
+++ b/ultralytics/data/dataset.py
@@ -11,6 +11,7 @@ import torchvision
from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
+from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .base import BaseDataset
@@ -26,17 +27,17 @@ class YOLODataset(BaseDataset):
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
- use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
- use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
+        task (str): An explicit argument specifying the current task. Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
- def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
+ def __init__(self, *args, data=None, task='detect', **kwargs):
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
- self.use_segments = use_segments
- self.use_keypoints = use_keypoints
+ self.use_segments = task == 'segment'
+ self.use_keypoints = task == 'pose'
+ self.use_obb = task == 'obb'
self.data = data
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
super().__init__(*args, **kwargs)
@@ -148,6 +149,7 @@ class YOLODataset(BaseDataset):
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
+ return_obb=self.use_obb,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
@@ -165,10 +167,19 @@ class YOLODataset(BaseDataset):
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
# We can make it also support classification and semantic segmentation by add or remove some dict keys there.
bboxes = label.pop('bboxes')
- segments = label.pop('segments')
+ segments = label.pop('segments', [])
keypoints = label.pop('keypoints', None)
bbox_format = label.pop('bbox_format')
normalized = label.pop('normalized')
+
+ # NOTE: do NOT resample oriented boxes
+ segment_resamples = 100 if self.use_obb else 1000
+ if len(segments) > 0:
+ # list[np.array(1000, 2)] * num_samples
+ # (N, 1000, 2)
+ segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
+ else:
+ segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
return label
@@ -182,7 +193,7 @@ class YOLODataset(BaseDataset):
value = values[i]
if k == 'img':
value = torch.stack(value, 0)
- if k in ['masks', 'keypoints', 'bboxes', 'cls']:
+ if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']:
value = torch.cat(value, 0)
new_batch[k] = value
new_batch['batch_idx'] = list(new_batch['batch_idx'])
diff --git a/ultralytics/data/split_dota.py b/ultralytics/data/split_dota.py
new file mode 100644
index 00000000..9cabd5bc
--- /dev/null
+++ b/ultralytics/data/split_dota.py
@@ -0,0 +1,288 @@
+import itertools
+import os
+from glob import glob
+from math import ceil
+from pathlib import Path
+
+import cv2
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from ultralytics.data.utils import exif_size, img2label_paths
+from ultralytics.utils.checks import check_requirements
+
+check_requirements('shapely')
+from shapely.geometry import Polygon
+
+
+def bbox_iof(polygon1, bbox2, eps=1e-6):
+ """
+    Calculate the intersection over foreground (IoF) between polygons and bounding boxes.
+
+ Args:
+ polygon1 (np.ndarray): Polygon coordinates, (n, 8).
+        bbox2 (np.ndarray): Bounding boxes, (n, 4).
+ """
+ polygon1 = polygon1.reshape(-1, 4, 2)
+ lt_point = np.min(polygon1, axis=-2)
+ rb_point = np.max(polygon1, axis=-2)
+ bbox1 = np.concatenate([lt_point, rb_point], axis=-1)
+
+ lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])
+ rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])
+ wh = np.clip(rb - lt, 0, np.inf)
+ h_overlaps = wh[..., 0] * wh[..., 1]
+
+ l, t, r, b = (bbox2[..., i] for i in range(4))
+ polygon2 = np.stack([l, t, r, t, r, b, l, b], axis=-1).reshape(-1, 4, 2)
+
+ sg_polys1 = [Polygon(p) for p in polygon1]
+ sg_polys2 = [Polygon(p) for p in polygon2]
+ overlaps = np.zeros(h_overlaps.shape)
+ for p in zip(*np.nonzero(h_overlaps)):
+ overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
+ unions = np.array([p.area for p in sg_polys1], dtype=np.float32)
+ unions = unions[..., None]
+
+ unions = np.clip(unions, eps, np.inf)
+ outputs = overlaps / unions
+ if outputs.ndim == 1:
+ outputs = outputs[..., None]
+ return outputs
+
+
+def load_yolo_dota(data_root, split='train'):
+ """Load DOTA dataset.
+ Args:
+ data_root (str): Data root.
+ split (str): The split data set, could be train or val.
+ Notes:
+ The directory structure assumed for the DOTA dataset:
+ - data_root
+ - images
+ - train
+ - val
+ - labels
+ - train
+ - val
+ """
+ assert split in ['train', 'val']
+ im_dir = os.path.join(data_root, f'images/{split}')
+ assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root."
+ im_files = glob(os.path.join(data_root, f'images/{split}/*'))
+ lb_files = img2label_paths(im_files)
+ annos = []
+ for im_file, lb_file in zip(im_files, lb_files):
+ w, h = exif_size(Image.open(im_file))
+ with open(lb_file) as f:
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
+ lb = np.array(lb, dtype=np.float32)
+ annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))
+ return annos
+
+
+def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.01):
+ """
+ Get the coordinates of windows.
+
+ Args:
+ im_size (tuple): Original image size, (h, w).
+ crop_sizes (List(int)): Crop size of windows.
+        gaps (List(int)): Gap between crops.
+        im_rate_thr (float): Threshold of window areas divided by image areas.
+ """
+ h, w = im_size
+ windows = []
+ for crop_size, gap in zip(crop_sizes, gaps):
+            assert crop_size > gap, f'invalid crop_size gap pair [{crop_size} {gap}]'
+ step = crop_size - gap
+
+ xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
+ xs = [step * i for i in range(xn)]
+ if len(xs) > 1 and xs[-1] + crop_size > w:
+ xs[-1] = w - crop_size
+
+ yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
+ ys = [step * i for i in range(yn)]
+ if len(ys) > 1 and ys[-1] + crop_size > h:
+ ys[-1] = h - crop_size
+
+ start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
+ stop = start + crop_size
+ windows.append(np.concatenate([start, stop], axis=1))
+ windows = np.concatenate(windows, axis=0)
+
+ im_in_wins = windows.copy()
+ im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
+ im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)
+ im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
+ win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
+ im_rates = im_areas / win_areas
+ if not (im_rates > im_rate_thr).any():
+ max_rate = im_rates.max()
+ im_rates[abs(im_rates - max_rate) < eps] = 1
+ return windows[im_rates > im_rate_thr]
+
+
+def get_window_obj(anno, windows, iof_thr=0.7):
+ """Get objects for each window."""
+ h, w = anno['ori_size']
+ label = anno['label']
+ if len(label):
+ label[:, 1::2] *= w
+ label[:, 2::2] *= h
+ iofs = bbox_iof(label[:, 1:], windows)
+ # unnormalized and misaligned coordinates
+ window_anns = [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]
+ else:
+ window_anns = [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]
+ return window_anns
+
+
+def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
+ """Crop images and save new labels.
+ Args:
+ anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
+ windows (list): A list of windows coordinates.
+ window_objs (list): A list of labels inside each window.
+ im_dir (str): The output directory path of images.
+ lb_dir (str): The output directory path of labels.
+ Notes:
+ The directory structure assumed for the DOTA dataset:
+ - data_root
+ - images
+ - train
+ - val
+ - labels
+ - train
+ - val
+ """
+ im = cv2.imread(anno['filepath'])
+ name = Path(anno['filepath']).stem
+ for i, window in enumerate(windows):
+ x_start, y_start, x_stop, y_stop = window.tolist()
+ new_name = name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start)
+ patch_im = im[y_start:y_stop, x_start:x_stop]
+ ph, pw = patch_im.shape[:2]
+
+ cv2.imwrite(os.path.join(im_dir, f'{new_name}.jpg'), patch_im)
+ label = window_objs[i]
+ if len(label) == 0:
+ continue
+ label[:, 1::2] -= x_start
+ label[:, 2::2] -= y_start
+ label[:, 1::2] /= pw
+ label[:, 2::2] /= ph
+
+ with open(os.path.join(lb_dir, f'{new_name}.txt'), 'w') as f:
+ for lb in label:
+ formatted_coords = ['{:.6g}'.format(coord) for coord in lb[1:]]
+ f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
+
+
+def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024], gaps=[200]):
+ """
+ Split both images and labels.
+
+ NOTES:
+ The directory structure assumed for the DOTA dataset:
+ - data_root
+ - images
+ - split
+ - labels
+ - split
+ and the output directory structure is:
+ - save_dir
+ - images
+ - split
+ - labels
+ - split
+ """
+ im_dir = Path(save_dir) / 'images' / split
+ im_dir.mkdir(parents=True, exist_ok=True)
+ lb_dir = Path(save_dir) / 'labels' / split
+ lb_dir.mkdir(parents=True, exist_ok=True)
+
+ annos = load_yolo_dota(data_root, split=split)
+ for anno in tqdm(annos, total=len(annos), desc=split):
+ windows = get_windows(anno['ori_size'], crop_sizes, gaps)
+ window_objs = get_window_obj(anno, windows)
+ crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
+
+
+def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
+ """
+ Split train and val set of DOTA.
+
+ NOTES:
+ The directory structure assumed for the DOTA dataset:
+ - data_root
+ - images
+ - train
+ - val
+ - labels
+ - train
+ - val
+ and the output directory structure is:
+ - save_dir
+ - images
+ - train
+ - val
+ - labels
+ - train
+ - val
+ """
+ crop_sizes, gaps = [], []
+ for r in rates:
+ crop_sizes.append(int(crop_size / r))
+ gaps.append(int(gap / r))
+ for split in ['train', 'val']:
+ split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
+
+
+def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
+ """
+    Split the test set of DOTA; labels are not included within this set.
+
+ NOTES:
+ The directory structure assumed for the DOTA dataset:
+ - data_root
+ - images
+ - test
+ and the output directory structure is:
+ - save_dir
+ - images
+ - test
+ """
+ crop_sizes, gaps = [], []
+ for r in rates:
+ crop_sizes.append(int(crop_size / r))
+ gaps.append(int(gap / r))
+ save_dir = Path(save_dir) / 'images' / 'test'
+ save_dir.mkdir(parents=True, exist_ok=True)
+
+ im_dir = Path(os.path.join(data_root, 'images/test'))
+ assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
+ im_files = glob(str(im_dir / '*'))
+ for im_file in tqdm(im_files, total=len(im_files), desc='test'):
+ w, h = exif_size(Image.open(im_file))
+ windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
+ im = cv2.imread(im_file)
+ name = Path(im_file).stem
+ for window in windows:
+ x_start, y_start, x_stop, y_stop = window.tolist()
+ new_name = (name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start))
+ patch_im = im[y_start:y_stop, x_start:x_stop]
+ cv2.imwrite(os.path.join(str(save_dir), f'{new_name}.jpg'), patch_im)
+
+
+if __name__ == '__main__':
+ split_trainval(
+ data_root='DOTAv2',
+ save_dir='DOTAv2-split',
+ )
+ split_test(
+ data_root='DOTAv2',
+ save_dir='DOTAv2-split',
+ )
diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py
index a0b6f3b8..c3447394 100644
--- a/ultralytics/data/utils.py
+++ b/ultralytics/data/utils.py
@@ -516,10 +516,7 @@ class HUBDatasetStats:
else:
from ultralytics.data import YOLODataset
- dataset = YOLODataset(img_path=self.data[split],
- data=self.data,
- use_segments=self.task == 'segment',
- use_keypoints=self.task == 'pose')
+ dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
x = np.array([
np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py
index 3feabbc1..4444ef1c 100644
--- a/ultralytics/engine/results.py
+++ b/ultralytics/engine/results.py
@@ -89,7 +89,7 @@ class Results(SimpleClass):
_keys (tuple): A tuple of attribute names for non-empty attributes.
"""
- def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
+ def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None) -> None:
"""Initialize the Results class."""
self.orig_img = orig_img
self.orig_shape = orig_img.shape[:2]
@@ -97,11 +97,12 @@ class Results(SimpleClass):
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = Probs(probs) if probs is not None else None
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
+ self.obb = OBB(obb, self.orig_shape) if obb is not None else None
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
self.names = names
self.path = path
self.save_dir = None
- self._keys = 'boxes', 'masks', 'probs', 'keypoints'
+ self._keys = 'boxes', 'masks', 'probs', 'keypoints', 'obb'
def __getitem__(self, idx):
"""Return a Results object for the specified index."""
@@ -218,7 +219,8 @@ class Results(SimpleClass):
img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()
names = self.names
- pred_boxes, show_boxes = self.boxes, boxes
+ is_obb = self.obb is not None
+ pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
annotator = Annotator(
@@ -239,12 +241,13 @@ class Results(SimpleClass):
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
# Plot Detect results
- if pred_boxes and show_boxes:
+ if pred_boxes is not None and show_boxes:
for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
name = ('' if id is None else f'id:{id} ') + names[c]
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
- annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
+ box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
+ annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)
# Plot Classify results
if pred_probs is not None and show_probs:
@@ -390,7 +393,7 @@ class Boxes(BaseTensor):
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
- assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, track_id, conf, cls
+ assert n in (6, 7), f'expected 6 or 7 values but got {n}' # xyxy, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 7
self.orig_shape = orig_shape
@@ -571,3 +574,77 @@ class Probs(BaseTensor):
def top5conf(self):
"""Return the confidences of top 5."""
return self.data[self.top5]
+
+
+class OBB(BaseTensor):
+ """
+ A class for storing and manipulating Oriented Bounding Boxes (OBB).
+
+ Args:
+ boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
+            with shape (num_boxes, 7) or (num_boxes, 8). The first five columns contain the box in xywhr format,
+            the last two columns contain confidence and class values, and if tracking is enabled the
+            third-to-last column contains track IDs.
+ orig_shape (tuple): Original image size, in the format (height, width).
+
+ Attributes:
+ xywhr (torch.Tensor | numpy.ndarray): The boxes in [x_center, y_center, width, height, rotation] format.
+ conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
+ cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
+ id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
+        xyxyxyxy (torch.Tensor | numpy.ndarray): The rotated boxes as four corner points, shape (N, 4, 2).
+ data (torch.Tensor): The raw OBB tensor (alias for `boxes`).
+
+ Methods:
+ cpu(): Move the object to CPU memory.
+ numpy(): Convert the object to a numpy array.
+ cuda(): Move the object to CUDA memory.
+ to(*args, **kwargs): Move the object to the specified device.
+ """
+
+ def __init__(self, boxes, orig_shape) -> None:
+        """Initialize an OBB instance with oriented box data and the original image shape."""
+ if boxes.ndim == 1:
+ boxes = boxes[None, :]
+ n = boxes.shape[-1]
+ assert n in (7, 8), f'expected 7 or 8 values but got {n}' # xywh, rotation, track_id, conf, cls
+ super().__init__(boxes, orig_shape)
+ self.is_track = n == 8
+ self.orig_shape = orig_shape
+
+ @property
+ def xywhr(self):
+ """Return the rotated boxes in xywhr format."""
+ return self.data[:, :5]
+
+ @property
+ def conf(self):
+ """Return the confidence values of the boxes."""
+ return self.data[:, -2]
+
+ @property
+ def cls(self):
+ """Return the class values of the boxes."""
+ return self.data[:, -1]
+
+ @property
+ def id(self):
+ """Return the track IDs of the boxes (if available)."""
+ return self.data[:, -3] if self.is_track else None
+
+ @property
+ @lru_cache(maxsize=2)
+ def xyxyxyxy(self):
+ """Return the boxes in xyxyxyxy format, (N, 4, 2)."""
+ return ops.xywhr2xyxyxyxy(self.xywhr)
+
+    @property
+    @lru_cache(maxsize=2)
+    def xyxy(self):
+        """Return the horizontal boxes in xyxy format, (N, 4)."""
+        # Handle both torch and numpy inputs (`.min(dim).values` is torch-only)
+        x, y = self.xyxyxyxy[..., 0], self.xyxyxyxy[..., 1]
+        if isinstance(x, np.ndarray):
+            return np.stack([x.min(1), y.min(1), x.max(1), y.max(1)], axis=-1)
+        return torch.stack([x.amin(1), y.amin(1), x.amax(1), y.amax(1)], dim=-1)
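
For readers of this patch, a minimal usage sketch (not part of the diff) of the new `Results.obb` attribute defined above; the checkpoint name and image path are placeholders:

```python
# Minimal sketch of the new OBB results API; 'yolov8n-obb.pt' and the image path are assumed placeholders.
from ultralytics import YOLO

model = YOLO('yolov8n-obb.pt')
results = model('aerial_scene.jpg')

for r in results:
    if r.obb is not None:
        print(r.obb.xywhr)      # (N, 5) x_center, y_center, width, height, rotation
        print(r.obb.xyxyxyxy)   # (N, 4, 2) four corner points per box
        print(r.obb.xyxy)       # (N, 4) axis-aligned boxes enclosing each rotated box
        print(r.obb.conf, r.obb.cls)
```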
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 7a458453..57fa03af 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -249,6 +249,7 @@ class BaseTrainer:
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
+ self.stride = gs # for multi-scale training
# Batch size
if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size
@@ -258,7 +259,11 @@ class BaseTrainer:
batch_size = self.batch_size // max(world_size, 1)
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):
- self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
+            # NOTE: When training on the DOTA dataset, doubling the batch size can cause OOM since some images contain more than 2000 objects.
+ self.test_loader = self.get_dataloader(self.testset,
+ batch_size=batch_size if self.args.task == 'obb' else batch_size * 2,
+ rank=-1,
+ mode='val')
self.validator = self.get_validator()
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py
index 468026be..f96d206a 100644
--- a/ultralytics/models/rtdetr/val.py
+++ b/ultralytics/models/rtdetr/val.py
@@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from pathlib import Path
-
import torch
from ultralytics.data import YOLODataset
@@ -22,7 +20,7 @@ class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
"""Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
- super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
+ super().__init__(*args, data=data, **kwargs)
# NOTE: add stretch version load_image for RTDETR mosaic
def load_image(self, i, rect_mode=False):
@@ -108,47 +106,22 @@ class RTDETRValidator(DetectionValidator):
return outputs
- def update_metrics(self, preds, batch):
- """Metrics."""
- for si, pred in enumerate(preds):
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx]
- bbox = batch['bboxes'][idx]
- nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
- shape = batch['ori_shape'][si]
- correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
- self.seen += 1
+ def _prepare_batch(self, si, batch):
+ idx = batch['batch_idx'] == si
+ cls = batch['cls'][idx].squeeze(-1)
+ bbox = batch['bboxes'][idx]
+ ori_shape = batch['ori_shape'][si]
+ imgsz = batch['img'].shape[2:]
+ ratio_pad = batch['ratio_pad'][si]
+ if len(cls):
+ bbox = ops.xywh2xyxy(bbox) # target boxes
+ bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
+ bbox[..., [1, 3]] *= ori_shape[0] # native-space pred
+ prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
+ return prepared_batch
- if npr == 0:
- if nl:
- self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
- if self.args.plots:
- self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
- continue
-
- # Predictions
- if self.args.single_cls:
- pred[:, 5] = 0
- predn = pred.clone()
- predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred
- predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred
-
- # Evaluate
- if nl:
- tbox = ops.xywh2xyxy(bbox) # target boxes
- tbox[..., [0, 2]] *= shape[1] # native-space pred
- tbox[..., [1, 3]] *= shape[0] # native-space pred
- labelsn = torch.cat((cls, tbox), 1) # native-space labels
- # NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
- correct_bboxes = self._process_batch(predn.float(), labelsn)
- # TODO: maybe remove these `self.` arguments as they already are member variable
- if self.args.plots:
- self.confusion_matrix.process_batch(predn, labelsn)
- self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
-
- # Save
- if self.args.save_json:
- self.pred_to_json(predn, batch['im_file'][si])
- if self.args.save_txt:
- file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
- self.save_one_txt(predn, self.args.save_conf, shape, file)
+ def _prepare_pred(self, pred, pbatch):
+ predn = pred.clone()
+ predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred
+ predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred
+ return predn.float()
diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py
index c66e3762..602307b8 100644
--- a/ultralytics/models/yolo/__init__.py
+++ b/ultralytics/models/yolo/__init__.py
@@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from ultralytics.models.yolo import classify, detect, pose, segment
+from ultralytics.models.yolo import classify, detect, obb, pose, segment
from .model import YOLO
-__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'
+__all__ = 'classify', 'segment', 'detect', 'pose', 'obb', 'YOLO'
diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py
index d0028c6e..5cfaa9f4 100644
--- a/ultralytics/models/yolo/detect/train.py
+++ b/ultralytics/models/yolo/detect/train.py
@@ -1,8 +1,11 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+import math
+import random
from copy import copy
import numpy as np
+import torch.nn as nn
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
@@ -54,6 +57,16 @@ class DetectionTrainer(BaseTrainer):
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
+ if self.args.multi_scale:
+ imgs = batch['img']
+            sz = (random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)) //
+                  self.stride * self.stride)  # size
+ sf = sz / max(imgs.shape[2:]) # scale factor
+ if sf != 1:
+ ns = [math.ceil(x * sf / self.stride) * self.stride
+ for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
+ imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
+ batch['img'] = imgs
return batch
def set_model_attributes(self):
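
As a standalone illustration of the multi-scale size selection added to `preprocess_batch` above, here is a sketch assuming `imgsz=640` and `stride=32`:

```python
# Sketch of the multi-scale target-size computation; imgsz=640 and stride=32 are assumptions.
import math
import random

imgsz, stride = 640, 32
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride
print(sz)  # a multiple of 32 somewhere in [320, 960]

# A (720, 1280) batch would be rescaled by sf and each side rounded up to a stride multiple
# before nn.functional.interpolate is applied, mirroring the trainer code.
h, w = 720, 1280
sf = sz / max(h, w)
ns = [math.ceil(x * sf / stride) * stride for x in (h, w)]
print(sf, ns)
```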
diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py
index 4d439330..e794931b 100644
--- a/ultralytics/models/yolo/detect/val.py
+++ b/ultralytics/models/yolo/detect/val.py
@@ -70,7 +70,7 @@ class DetectionValidator(BaseValidator):
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
self.seen = 0
self.jdict = []
- self.stats = []
+ self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
@@ -86,51 +86,68 @@ class DetectionValidator(BaseValidator):
agnostic=self.args.single_cls,
max_det=self.args.max_det)
+ def _prepare_batch(self, si, batch):
+ idx = batch['batch_idx'] == si
+ cls = batch['cls'][idx].squeeze(-1)
+ bbox = batch['bboxes'][idx]
+ ori_shape = batch['ori_shape'][si]
+ imgsz = batch['img'].shape[2:]
+ ratio_pad = batch['ratio_pad'][si]
+ if len(cls):
+ bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
+ ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
+ prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
+ return prepared_batch
+
+ def _prepare_pred(self, pred, pbatch):
+ predn = pred.clone()
+ ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'],
+ ratio_pad=pbatch['ratio_pad']) # native-space pred
+ return predn
+
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx]
- bbox = batch['bboxes'][idx]
- nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
- shape = batch['ori_shape'][si]
- correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
-
+ npr = len(pred)
+ stat = dict(conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+ pbatch = self._prepare_batch(si, batch)
+ cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+ nl = len(cls)
+ stat['target_cls'] = cls
if npr == 0:
if nl:
- self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
- if self.args.plots:
- self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
+                # TODO: confusion_matrix is not yet supported for the obb task.
+ if self.args.plots and self.args.task != 'obb':
+ self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
- predn = pred.clone()
- ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
- ratio_pad=batch['ratio_pad'][si]) # native-space pred
+ predn = self._prepare_pred(pred, pbatch)
+ stat['conf'] = predn[:, 4]
+ stat['pred_cls'] = predn[:, 5]
# Evaluate
if nl:
- height, width = batch['img'].shape[2:]
- tbox = ops.xywh2xyxy(bbox) * torch.tensor(
- (width, height, width, height), device=self.device) # target boxes
- ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
- ratio_pad=batch['ratio_pad'][si]) # native-space labels
- labelsn = torch.cat((cls, tbox), 1) # native-space labels
- correct_bboxes = self._process_batch(predn, labelsn)
- # TODO: maybe remove these `self.` arguments as they already are member variable
- if self.args.plots:
- self.confusion_matrix.process_batch(predn, labelsn)
- self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
+ stat['tp'] = self._process_batch(predn, bbox, cls)
+                # TODO: confusion_matrix is not yet supported for the obb task.
+ if self.args.plots and self.args.task != 'obb':
+ self.confusion_matrix.process_batch(predn, bbox, cls)
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
- self.save_one_txt(predn, self.args.save_conf, shape, file)
+ self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
@@ -139,10 +156,11 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
- stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
- if len(stats) and stats[0].any():
- self.metrics.process(*stats)
- self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
+ stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
+ if len(stats) and stats['tp'].any():
+ self.metrics.process(**stats)
+ self.nt_per_class = np.bincount(stats['target_cls'].astype(int),
+ minlength=self.nc) # number of targets per class
return self.metrics.results_dict
def print_results(self):
@@ -165,7 +183,7 @@ class DetectionValidator(BaseValidator):
normalize=normalize,
on_plot=self.on_plot)
- def _process_batch(self, detections, labels):
+ def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
Return correct prediction matrix.
@@ -178,8 +196,8 @@ class DetectionValidator(BaseValidator):
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
- iou = box_iou(labels[:, 1:], detections[:, :4])
- return self.match_predictions(detections[:, 5], labels[:, 0], iou)
+ iou = box_iou(gt_bboxes, detections[:, :4])
+ return self.match_predictions(detections[:, 5], gt_cls, iou)
def build_dataset(self, img_path, mode='val', batch=None):
"""
diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py
index ef1b41ab..eb2225d8 100644
--- a/ultralytics/models/yolo/model.py
+++ b/ultralytics/models/yolo/model.py
@@ -1,8 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.model import Model
-from ultralytics.models import yolo # noqa
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel
class YOLO(Model):
@@ -31,4 +31,9 @@ class YOLO(Model):
'model': PoseModel,
'trainer': yolo.pose.PoseTrainer,
'validator': yolo.pose.PoseValidator,
- 'predictor': yolo.pose.PosePredictor, }, }
+ 'predictor': yolo.pose.PosePredictor, },
+ 'obb': {
+ 'model': OBBModel,
+ 'trainer': yolo.obb.OBBTrainer,
+ 'validator': yolo.obb.OBBValidator,
+ 'predictor': yolo.obb.OBBPredictor, }, }
diff --git a/ultralytics/models/yolo/obb/__init__.py b/ultralytics/models/yolo/obb/__init__.py
new file mode 100644
index 00000000..09f10481
--- /dev/null
+++ b/ultralytics/models/yolo/obb/__init__.py
@@ -0,0 +1,7 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from .predict import OBBPredictor
+from .train import OBBTrainer
+from .val import OBBValidator
+
+__all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator'
diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py
new file mode 100644
index 00000000..227a7530
--- /dev/null
+++ b/ultralytics/models/yolo/obb/predict.py
@@ -0,0 +1,51 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+
+from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class OBBPredictor(DetectionPredictor):
+ """
+ A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+
+ Example:
+ ```python
+ from ultralytics.utils import ASSETS
+ from ultralytics.models.yolo.obb import OBBPredictor
+
+ args = dict(model='yolov8n-obb.pt', source=ASSETS)
+ predictor = OBBPredictor(overrides=args)
+ predictor.predict_cli()
+ ```
+ """
+
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ super().__init__(cfg, overrides, _callbacks)
+ self.args.task = 'obb'
+
+ def postprocess(self, preds, img, orig_imgs):
+ """Post-processes predictions and returns a list of Results objects."""
+ preds = ops.non_max_suppression(preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ nc=len(self.model.names),
+ classes=self.args.classes,
+ rotated=True)
+
+ if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
+ orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+ results = []
+ for i, pred in enumerate(preds):
+ orig_img = orig_imgs[i]
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
+ img_path = self.batch[0][i]
+ # xywh, r, conf, cls
+ obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
+ results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
+ return results
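
The concatenation at the end of `postprocess` reorders the NMS output into the `(xywh, angle, conf, cls)` layout that the `OBB` results class expects. A toy check of that reordering follows; the assumed NMS column order is an interpretation of this patch, not stated explicitly in it:

```python
# Toy check of the column reordering in OBBPredictor.postprocess; the input layout is assumed.
import torch

pred = torch.tensor([[100., 200., 50., 20., 0.9, 3., 0.4]])  # assumed [x, y, w, h, conf, cls, angle]
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
print(obb)  # tensor([[100., 200., 50., 20., 0.4, 0.9, 3.]]) -> xywh, angle, conf, cls
```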
diff --git a/ultralytics/models/yolo/obb/train.py b/ultralytics/models/yolo/obb/train.py
new file mode 100644
index 00000000..0d1284a7
--- /dev/null
+++ b/ultralytics/models/yolo/obb/train.py
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from copy import copy
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import OBBModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+
+
+class OBBTrainer(yolo.detect.DetectionTrainer):
+ """
+ A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.obb import OBBTrainer
+
+        args = dict(model='yolov8n-obb.pt', data='DOTAv2.yaml', epochs=3)
+ trainer = OBBTrainer(overrides=args)
+ trainer.train()
+ ```
+ """
+
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize an OBBTrainer object with the given arguments."""
+ if overrides is None:
+ overrides = {}
+ overrides['task'] = 'obb'
+ super().__init__(cfg, overrides, _callbacks)
+
+ def get_model(self, cfg=None, weights=None, verbose=True):
+ """Return OBBModel initialized with specified config and weights."""
+ model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ if weights:
+ model.load(weights)
+
+ return model
+
+ def get_validator(self):
+ """Return an instance of OBBValidator for validation of YOLO model."""
+ self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
+ return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
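
At the model-API level, training with the new trainer would look roughly like the sketch below; `yolov8n-obb.yaml` is the default config referenced by `OBBModel`, while the dataset name and hyperparameters are assumptions:

```python
# Rough training sketch; the dataset name, epochs and imgsz are assumptions.
from ultralytics import YOLO

model = YOLO('yolov8n-obb.yaml')            # build an OBB model from its config
model.train(data='DOTAv2.yaml', epochs=100, imgsz=1024)
```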
diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py
new file mode 100644
index 00000000..e16d75f2
--- /dev/null
+++ b/ultralytics/models/yolo/obb/val.py
@@ -0,0 +1,187 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from pathlib import Path
+
+import torch
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.metrics import OBBMetrics, batch_probiou
+from ultralytics.utils.plotting import output_to_rotated_target, plot_images
+
+
+class OBBValidator(DetectionValidator):
+ """
+ A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.obb import OBBValidator
+
+        args = dict(model='yolov8n-obb.pt', data='DOTAv2.yaml')
+ validator = OBBValidator(args=args)
+ validator(model=args['model'])
+ ```
+ """
+
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+ """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
+ super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+ self.args.task = 'obb'
+ self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
+
+ def init_metrics(self, model):
+ """Initialize evaluation metrics for YOLO."""
+ super().init_metrics(model)
+ val = self.data.get(self.args.split, '') # validation path
+        self.is_dota = isinstance(val, str) and 'DOTA' in val  # check if dataset is DOTA format
+
+ def postprocess(self, preds):
+ """Apply Non-maximum suppression to prediction outputs."""
+ return ops.non_max_suppression(preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ nc=self.nc,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ rotated=True)
+
+ def _process_batch(self, detections, gt_bboxes, gt_cls):
+ """
+ Return correct prediction matrix.
+
+ Args:
+            detections (torch.Tensor): Tensor of rotated detections, where each row contains the box
+                parameters followed by confidence, class and angle values.
+            gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing ground-truth boxes in xywhr format.
+            gt_cls (torch.Tensor): Tensor of shape [M] representing ground-truth class indices.
+
+ Returns:
+ (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
+ """
+ iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -2:-1]], dim=-1))
+ return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+ def _prepare_batch(self, si, batch):
+ idx = batch['batch_idx'] == si
+ cls = batch['cls'][idx].squeeze(-1)
+ bbox = batch['bboxes'][idx]
+ ori_shape = batch['ori_shape'][si]
+ imgsz = batch['img'].shape[2:]
+ ratio_pad = batch['ratio_pad'][si]
+ if len(cls):
+ bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
+ ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
+ prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
+ return prepared_batch
+
+ def _prepare_pred(self, pred, pbatch):
+ predn = pred.clone()
+ ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
+ xywh=True) # native-space pred
+ return predn
+
+ def plot_predictions(self, batch, preds, ni):
+ """Plots predicted bounding boxes on input images and saves the result."""
+ plot_images(batch['img'],
+ *output_to_rotated_target(preds, max_det=self.args.max_det),
+ paths=batch['im_file'],
+ fname=self.save_dir / f'val_batch{ni}_pred.jpg',
+ names=self.names,
+ on_plot=self.on_plot) # pred
+
+ def pred_to_json(self, predn, filename):
+ """Serialize YOLO predictions to COCO json format."""
+ stem = Path(filename).stem
+ image_id = int(stem) if stem.isnumeric() else stem
+ rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+ poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
+ for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+ self.jdict.append({
+ 'image_id': image_id,
+ 'category_id': self.class_map[int(predn[i, 5].item())],
+ 'score': round(predn[i, 4].item(), 5),
+ 'rbox': [round(x, 3) for x in r],
+ 'poly': [round(x, 3) for x in b]})
+
+ def eval_json(self, stats):
+ """Evaluates YOLO output in JSON format and returns performance statistics."""
+ if self.args.save_json and self.is_dota and len(self.jdict):
+ import json
+ import re
+ from collections import defaultdict
+ pred_json = self.save_dir / 'predictions.json' # predictions
+ pred_txt = self.save_dir / 'predictions_txt' # predictions
+ pred_txt.mkdir(parents=True, exist_ok=True)
+ data = json.load(open(pred_json))
+ # Save split results
+ LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
+ for d in data:
+ image_id = d['image_id']
+ score = d['score']
+ classname = self.names[d['category_id']].replace(' ', '-')
+
+ lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+ image_id,
+ score,
+ d['poly'][0],
+ d['poly'][1],
+ d['poly'][2],
+ d['poly'][3],
+ d['poly'][4],
+ d['poly'][5],
+ d['poly'][6],
+ d['poly'][7],
+ )
+ with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+ f.writelines(lines)
+            # Save merged results. This could give a slightly lower mAP than the official merging script,
+            # because of the probiou calculation.
+ pred_merged_txt = self.save_dir / 'predictions_merged_txt' # predictions
+ pred_merged_txt.mkdir(parents=True, exist_ok=True)
+ merged_results = defaultdict(list)
+ LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
+ for d in data:
+ image_id = d['image_id'].split('__')[0]
+ pattern = re.compile(r'\d+___\d+')
+ x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
+ bbox, score, cls = d['rbox'], d['score'], d['category_id']
+ bbox[0] += x
+ bbox[1] += y
+ bbox.extend([score, cls])
+ merged_results[image_id].append(bbox)
+ for image_id, bbox in merged_results.items():
+ bbox = torch.tensor(bbox)
+ max_wh = torch.max(bbox[:, :2]).item() * 2
+ c = bbox[:, 6:7] * max_wh # classes
+ scores = bbox[:, 5] # scores
+ b = bbox[:, :5].clone()
+ b[:, :2] += c
+                # An NMS threshold of 0.3 gives results close to (or slightly better than) the official merging script.
+ i = ops.nms_rotated(b, scores, 0.3)
+ bbox = bbox[i]
+
+ b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
+ for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
+ classname = self.names[int(x[-1])].replace(' ', '-')
+ poly = [round(i, 3) for i in x[:-2]]
+ score = round(x[-2], 3)
+
+ lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+ image_id,
+ score,
+ poly[0],
+ poly[1],
+ poly[2],
+ poly[3],
+ poly[4],
+ poly[5],
+ poly[6],
+ poly[7],
+ )
+ with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+ f.writelines(lines)
+
+ return stats
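
The merging logic above relies on the DOTA split naming convention `<original-id>__<split-size>__<x>___<y>` to recover each patch's offset in the original image. A small sketch of that parsing (the example name is made up):

```python
# Sketch of the patch-name parsing used when merging split DOTA predictions; the name is made up.
import re

image_id = 'P0001__1024__0___1536'
original_id = image_id.split('__')[0]
x, y = (int(c) for c in re.findall(r'\d+___\d+', image_id)[0].split('___'))
print(original_id, x, y)  # P0001 0 1536
```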
diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py
index b8ebf57e..69d32399 100644
--- a/ultralytics/models/yolo/pose/val.py
+++ b/ultralytics/models/yolo/pose/val.py
@@ -66,57 +66,63 @@ class PoseValidator(DetectionValidator):
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
+ self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
+
+ def _prepare_batch(self, si, batch):
+ pbatch = super()._prepare_batch(si, batch)
+ kpts = batch['keypoints'][batch['batch_idx'] == si]
+ h, w = pbatch['imgsz']
+ kpts = kpts.clone()
+ kpts[..., 0] *= w
+ kpts[..., 1] *= h
+ kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
+ pbatch['kpts'] = kpts
+ return pbatch
+
+ def _prepare_pred(self, pred, pbatch):
+ predn = super()._prepare_pred(pred, pbatch)
+ nk = pbatch['kpts'].shape[1]
+ pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
+ ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
+ return predn, pred_kpts
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx]
- bbox = batch['bboxes'][idx]
- kpts = batch['keypoints'][idx]
- nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
- nk = kpts.shape[1] # number of keypoints
- shape = batch['ori_shape'][si]
- correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
- correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
-
+ npr = len(pred)
+ stat = dict(conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+ pbatch = self._prepare_batch(si, batch)
+ cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+ nl = len(cls)
+ stat['target_cls'] = cls
if npr == 0:
if nl:
- self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
- (2, 0), device=self.device), cls.squeeze(-1)))
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
if self.args.plots:
- self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
+ self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
- predn = pred.clone()
- ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
- ratio_pad=batch['ratio_pad'][si]) # native-space pred
- pred_kpts = predn[:, 6:].view(npr, nk, -1)
- ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])
+ predn, pred_kpts = self._prepare_pred(pred, pbatch)
+ stat['conf'] = predn[:, 4]
+ stat['pred_cls'] = predn[:, 5]
# Evaluate
if nl:
- height, width = batch['img'].shape[2:]
- tbox = ops.xywh2xyxy(bbox) * torch.tensor(
- (width, height, width, height), device=self.device) # target boxes
- ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
- ratio_pad=batch['ratio_pad'][si]) # native-space labels
- tkpts = kpts.clone()
- tkpts[..., 0] *= width
- tkpts[..., 1] *= height
- tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
- labelsn = torch.cat((cls, tbox), 1) # native-space labels
- correct_bboxes = self._process_batch(predn[:, :6], labelsn)
- correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
+ stat['tp'] = self._process_batch(predn, bbox, cls)
+ stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
if self.args.plots:
- self.confusion_matrix.process_batch(predn, labelsn)
+ self.confusion_matrix.process_batch(predn, bbox, cls)
- # Append correct_masks, correct_boxes, pconf, pcls, tcls
- self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
# Save
if self.args.save_json:
@@ -124,7 +130,7 @@ class PoseValidator(DetectionValidator):
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
- def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
+ def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
"""
Return correct prediction matrix.
@@ -142,12 +148,12 @@ class PoseValidator(DetectionValidator):
"""
if pred_kpts is not None and gt_kpts is not None:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
- area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
+ area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
else: # boxes
- iou = box_iou(labels[:, 1:], detections[:, :4])
+ iou = box_iou(gt_bboxes, detections[:, :4])
- return self.match_predictions(detections[:, 5], labels[:, 0], iou)
+ return self.match_predictions(detections[:, 5], gt_cls, iou)
def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py
index 599b0d53..b1204ad8 100644
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -51,6 +51,7 @@ class SegmentationValidator(DetectionValidator):
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
+ self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[])
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
@@ -70,59 +71,62 @@ class SegmentationValidator(DetectionValidator):
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
+ def _prepare_batch(self, si, batch):
+ prepared_batch = super()._prepare_batch(si, batch)
+ midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
+ prepared_batch['masks'] = batch['masks'][midx]
+ return prepared_batch
+
+ def _prepare_pred(self, pred, pbatch, proto):
+ predn = super()._prepare_pred(pred, pbatch)
+ pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
+ return predn, pred_masks
+
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx]
- bbox = batch['bboxes'][idx]
- nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
- shape = batch['ori_shape'][si]
- correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
- correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
-
+ npr = len(pred)
+ stat = dict(conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+ pbatch = self._prepare_batch(si, batch)
+ cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+ nl = len(cls)
+ stat['target_cls'] = cls
if npr == 0:
if nl:
- self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
- (2, 0), device=self.device), cls.squeeze(-1)))
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
if self.args.plots:
- self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
+ self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
# Masks
- midx = [si] if self.args.overlap_mask else idx
- gt_masks = batch['masks'][midx]
- pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
-
+ gt_masks = pbatch.pop('masks')
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
- predn = pred.clone()
- ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
- ratio_pad=batch['ratio_pad'][si]) # native-space pred
+ predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
+ stat['conf'] = predn[:, 4]
+ stat['pred_cls'] = predn[:, 5]
# Evaluate
if nl:
- height, width = batch['img'].shape[2:]
- tbox = ops.xywh2xyxy(bbox) * torch.tensor(
- (width, height, width, height), device=self.device) # target boxes
- ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
- ratio_pad=batch['ratio_pad'][si]) # native-space labels
- labelsn = torch.cat((cls, tbox), 1) # native-space labels
- correct_bboxes = self._process_batch(predn, labelsn)
- # TODO: maybe remove these `self.` arguments as they already are member variable
- correct_masks = self._process_batch(predn,
- labelsn,
- pred_masks,
- gt_masks,
- overlap=self.args.overlap_mask,
- masks=True)
+ stat['tp'] = self._process_batch(predn, bbox, cls)
+ stat['tp_m'] = self._process_batch(predn,
+ bbox,
+ cls,
+ pred_masks,
+ gt_masks,
+ self.args.overlap_mask,
+ masks=True)
if self.args.plots:
- self.confusion_matrix.process_batch(predn, labelsn)
+ self.confusion_matrix.process_batch(predn, bbox, cls)
- # Append correct_masks, correct_boxes, pconf, pcls, tcls
- self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
@@ -131,7 +135,7 @@ class SegmentationValidator(DetectionValidator):
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
- shape,
+ pbatch['ori_shape'],
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
@@ -142,7 +146,7 @@ class SegmentationValidator(DetectionValidator):
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
- def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
+ def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix.
@@ -155,7 +159,7 @@ class SegmentationValidator(DetectionValidator):
"""
if masks:
if overlap:
- nl = len(labels)
+ nl = len(gt_cls)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
@@ -164,9 +168,9 @@ class SegmentationValidator(DetectionValidator):
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
- iou = box_iou(labels[:, 1:], detections[:, :4])
+ iou = box_iou(gt_bboxes, detections[:, :4])
- return self.match_predictions(detections[:, 5], labels[:, 0], iou)
+ return self.match_predictions(detections[:, 5], gt_cls, iou)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
@@ -174,7 +178,7 @@ class SegmentationValidator(DetectionValidator):
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
- batch['masks'],
+ masks=batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index dfcb0ec9..bbc7a5b1 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -21,7 +21,7 @@ from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP,
HGBlock, HGStem, Proto, RepC3, ResNetLayer)
from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
GhostConv, LightConv, RepConv, SpatialAttention)
-from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
+from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
@@ -30,4 +30,5 @@ __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d
'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
- 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer')
+ 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer',
+ 'OBB')
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
index e7c36ba2..35363fe2 100644
--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -7,14 +7,14 @@ import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
-from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors
+from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_
-__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
+__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder'
class Detect(nn.Module):
@@ -41,22 +41,24 @@ class Detect(nn.Module):
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
- shape = x[0].shape # BCHW
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training:
+ if self.training: # Training path
return x
- elif self.dynamic or self.shape != shape:
+
+ # Inference path
+ shape = x[0].shape # BCHW
+ x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
+ if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
- x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
box = x_cat[:, :self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4:]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
- dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
+ dbox = self.decode_bboxes(box)
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
@@ -79,6 +81,10 @@ class Detect(nn.Module):
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+ def decode_bboxes(self, bboxes):
+ """Decode bounding boxes."""
+ return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
+
class Segment(Detect):
"""YOLOv8 Segment head for segmentation models."""
@@ -106,6 +112,35 @@ class Segment(Detect):
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
+class OBB(Detect):
+    """YOLOv8 OBB detection head for rotated object detection models."""
+
+    def __init__(self, nc=80, ne=1, ch=()):
+        """Initialize the OBB head with `nc` classes, `ne` extra (angle) parameters and channel list `ch`."""
+        super().__init__(nc, ch)
+ self.ne = ne # number of extra parameters
+ self.detect = Detect.forward
+
+ c4 = max(ch[0] // 4, self.ne)
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
+
+    def forward(self, x):
+        """Concatenate and return predicted rotated bounding boxes and class probabilities."""
+ bs = x[0].shape[0] # batch size
+ angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
+        # NOTE: `angle` is stored as an attribute during inference so that `decode_bboxes` can use it.
+ angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
+ # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
+ if not self.training:
+ self.angle = angle
+ x = self.detect(self, x)
+ if self.training:
+ return x, angle
+ return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
+
+ def decode_bboxes(self, bboxes):
+ """Decode rotated bounding boxes."""
+ return dist2rbox(self.dfl(bboxes), self.angle, self.anchors.unsqueeze(0), dim=1) * self.strides
+
+
class Pose(Detect):
"""YOLOv8 Pose head for keypoints models."""
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index c7856ac1..0f4af06d 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -7,13 +7,13 @@ from pathlib import Path
import torch
import torch.nn as nn
-from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
- Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
- Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
- ResNetLayer, RTDETRDecoder, Segment)
+from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
+ C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
+ DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
+ RepConv, ResNetLayer, RTDETRDecoder, Segment)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
+from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
make_divisible, model_info, scale_img, time_sync)
@@ -241,10 +241,10 @@ class DetectionModel(BaseModel):
# Build strides
m = self.model[-1] # Detect()
- if isinstance(m, (Detect, Segment, Pose)):
+ if isinstance(m, (Detect, Segment, Pose, OBB)):
s = 256 # 2x min stride
m.inplace = self.inplace
- forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
+ forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
self.stride = m.stride
m.bias_init() # only run once
@@ -298,6 +298,17 @@ class DetectionModel(BaseModel):
return v8DetectionLoss(self)
+class OBBModel(DetectionModel):
+    """YOLOv8 Oriented Bounding Box (OBB) model."""
+
+ def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
+ """Initialize YOLOv8 OBB model with given config and parameters."""
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
+
+ def init_criterion(self):
+ return v8OBBLoss(self)
+
+
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
@@ -616,7 +627,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Module updates
for m in ensemble.modules():
t = type(m)
- if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
@@ -652,7 +663,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Module updates
for m in model.modules():
t = type(m)
- if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
@@ -717,7 +728,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
- elif m in (Detect, Segment, Pose):
+ elif m in (Detect, Segment, Pose, OBB):
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -801,6 +812,8 @@ def guess_model_task(model):
return 'segment'
if m == 'pose':
return 'pose'
+ if m == 'obb':
+ return 'obb'
# Guess from model cfg
if isinstance(model, dict):
@@ -825,6 +838,8 @@ def guess_model_task(model):
return 'classify'
elif isinstance(m, Pose):
return 'pose'
+ elif isinstance(m, OBB):
+ return 'obb'
# Guess from model filename
if isinstance(model, (str, Path)):
@@ -835,10 +850,12 @@ def guess_model_task(model):
return 'classify'
elif '-pose' in model.stem or 'pose' in model.parts:
return 'pose'
+ elif '-obb' in model.stem or 'obb' in model.parts:
+ return 'obb'
elif 'detect' in model.parts:
return 'detect'
# Unable to determine task from model
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
- "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
+                   "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', 'pose' or 'obb'.")
return 'detect' # assume detect
diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py
index ce6282a6..2ad4f5d7 100644
--- a/ultralytics/trackers/track.py
+++ b/ultralytics/trackers/track.py
@@ -24,6 +24,8 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
"""
+ if predictor.args.task == 'obb':
+ raise NotImplementedError('ERROR ❌ OBB task does not support track mode!')
if hasattr(predictor, 'trackers') and persist:
return
diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py
index 7df1453d..a1d85aaf 100644
--- a/ultralytics/utils/instance.py
+++ b/ultralytics/utils/instance.py
@@ -7,7 +7,7 @@ from typing import List
import numpy as np
-from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
+from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
def _ntuple(n):
@@ -212,19 +212,9 @@ class Instances:
segments (list | ndarray): segments.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
"""
- if segments is None:
- segments = []
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
self.keypoints = keypoints
self.normalized = normalized
-
- if len(segments) > 0:
- # List[np.array(1000, 2)] * num_samples
- segments = resample_segments(segments)
- # (N, 1000, 2)
- segments = np.stack(segments, axis=0)
- else:
- segments = np.zeros((0, 1000, 2), dtype=np.float32)
self.segments = segments
def convert_bbox(self, format):
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index da2e5842..38d633bf 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -6,9 +6,9 @@ import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
-from ultralytics.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
+from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
-from .metrics import bbox_iou
+from .metrics import bbox_iou, probiou
from .tal import bbox2dist
@@ -95,6 +95,30 @@ class BboxLoss(nn.Module):
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
+class RotatedBboxLoss(BboxLoss):
+    """Criterion class for computing training losses for rotated bounding boxes."""
+
+ def __init__(self, reg_max, use_dfl=False):
+ """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+ super().__init__(reg_max, use_dfl)
+
+ def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+ """IoU loss."""
+ weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+ iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
+ loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+ # DFL loss
+ if self.use_dfl:
+ target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
+ loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
+ loss_dfl = loss_dfl.sum() / target_scores_sum
+ else:
+ loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+ return loss_iou, loss_dfl
+
+
class KeypointLoss(nn.Module):
"""Criterion class for computing training losses."""
@@ -243,9 +267,9 @@ class v8SegmentationLoss(v8DetectionLoss):
except RuntimeError as e:
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
- "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
- 'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
+ "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+ 'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
@@ -526,3 +550,109 @@ class v8ClassificationLoss:
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
loss_items = loss.detach()
return loss, loss_items
+
+
+class v8OBBLoss(v8DetectionLoss):
+    """Criterion class for computing training losses for YOLOv8 OBB (rotated detection) models."""
+
+ def __init__(self, model): # model must be de-paralleled
+ super().__init__(model)
+ self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+ self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
+
+ def preprocess(self, targets, batch_size, scale_tensor):
+ """Preprocesses the target counts and matches with the input batch size to output a tensor."""
+ if targets.shape[0] == 0:
+ out = torch.zeros(batch_size, 0, 6, device=self.device)
+ else:
+ i = targets[:, 0] # image index
+ _, counts = i.unique(return_counts=True)
+ counts = counts.to(dtype=torch.int32)
+ out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
+ for j in range(batch_size):
+ matches = i == j
+ n = matches.sum()
+ if n:
+ bboxes = targets[matches, 2:]
+ bboxes[..., :4].mul_(scale_tensor)
+ out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
+ return out
+
+ def __call__(self, preds, batch):
+ """Calculate and return the loss for the YOLO model."""
+ loss = torch.zeros(3, device=self.device) # box, cls, dfl
+ feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
+        batch_size = pred_angle.shape[0]  # batch size
+ pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+ (self.reg_max * 4, self.nc), 1)
+
+ # b, grids, ..
+ pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+ pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+ pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+
+ dtype = pred_scores.dtype
+ imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
+ anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+ # targets
+ try:
+ batch_idx = batch['batch_idx'].view(-1, 1)
+ targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1)
+ rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
+ targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+ gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+ except RuntimeError as e:
+            raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not an OBB dataset.\n'
+                            "This error can occur when incorrectly training an 'OBB' model on a 'detect' dataset, "
+ "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
+ 'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e
+
+ # Pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywhr, (b, h*w, 5)
+
+ bboxes_for_assigner = pred_bboxes.clone().detach()
+ # Only the first four elements need to be scaled
+ bboxes_for_assigner[..., :4] *= stride_tensor
+ _, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(),
+ bboxes_for_assigner.type(gt_bboxes.dtype),
+ anchor_points * stride_tensor, gt_labels, gt_bboxes,
+ mask_gt)
+
+ target_scores_sum = max(target_scores.sum(), 1)
+
+ # Cls loss
+ # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
+ loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
+
+ # Bbox loss
+ if fg_mask.sum():
+ target_bboxes[..., :4] /= stride_tensor
+ loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
+ target_scores_sum, fg_mask)
+ else:
+ loss[0] += (pred_angle * 0).sum()
+
+ loss[0] *= self.hyp.box # box gain
+ loss[1] *= self.hyp.cls # cls gain
+ loss[2] *= self.hyp.dfl # dfl gain
+
+ return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
+
+ def bbox_decode(self, anchor_points, pred_dist, pred_angle):
+ """
+ Decode predicted object bounding box coordinates from anchor points and distribution.
+
+ Args:
+ anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+ pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+ pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+ Returns:
+ (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
+ """
+ if self.use_dfl:
+ b, a, c = pred_dist.shape # batch, anchors, channels
+ pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+ return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
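For readers following the decode path, here is a minimal standalone sketch of the rotated decode that `bbox_decode` delegates to `dist2rbox` (added to `ultralytics/utils/tal.py` later in this diff); tensor shapes and values are illustrative only.

```python
import torch

# Illustrative inputs: one anchor, one prediction (values are made up)
anchor_points = torch.tensor([[10.0, 10.0]])         # (h*w, 2)
pred_dist = torch.tensor([[[2.0, 1.0, 6.0, 5.0]]])   # (bs, h*w, 4) left, top, right, bottom distances
pred_angle = torch.tensor([[[0.5236]]])              # (bs, h*w, 1) ~30 degrees in radians

lt, rb = pred_dist.split(2, dim=-1)
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
xf, yf = ((rb - lt) / 2).split(1, dim=-1)            # centre offset before rotation
x, y = xf * cos - yf * sin, xf * sin + yf * cos      # rotate the offset by the predicted angle
xy = torch.cat([x, y], dim=-1) + anchor_points       # rotated centre
wh = lt + rb                                         # width, height
print(torch.cat([xy, wh], dim=-1))  # xywh; bbox_decode then appends pred_angle to form xywhr
```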
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 27e41b77..8d60b372 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -165,6 +165,92 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
+def _get_covariance_matrix(boxes):
+ """
+ Generate covariance matrices from oriented bounding boxes (OBBs).
+
+ Args:
+ boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
+
+ Returns:
+ (torch.Tensor): Covariance matrices corresponding to the original rotated bounding boxes.
+ """
+ # Gaussian bounding boxes; the center points (first two columns) are ignored since they are not needed here.
+ gbbs = torch.cat((torch.pow(boxes[:, 2:4], 2) / 12, boxes[:, 4:]), dim=-1)
+ a, b, c = gbbs.split(1, dim=-1)
+ return (
+ a * torch.cos(c) ** 2 + b * torch.sin(c) ** 2,
+ a * torch.sin(c) ** 2 + b * torch.cos(c) ** 2,
+ a * torch.cos(c) * torch.sin(c) - b * torch.sin(c) * torch.cos(c),
+ )
+
+
+def probiou(obb1, obb2, CIoU=False, eps=1e-7):
+ """
+ Calculate the probabilistic IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+
+ Args:
+ obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
+ obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
+ CIoU (bool, optional): If True, return CIoU with an additional aspect-ratio penalty term. Defaults to False.
+ eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+ Returns:
+ (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
+ """
+ x1, y1 = obb1[..., :2].split(1, dim=-1)
+ x2, y2 = obb2[..., :2].split(1, dim=-1)
+ a1, b1, c1 = _get_covariance_matrix(obb1)
+ a2, b2, c2 = _get_covariance_matrix(obb2)
+
+ t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
+ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
+ t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
+ t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
+ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
+ (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
+ bd = t1 + t2 + t3
+ bd = torch.clamp(bd, eps, 100.0)
+ hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
+ iou = 1 - hd
+ if CIoU: # only include the wh aspect ratio part
+ w1, h1 = obb1[..., 2:4].split(1, dim=-1)
+ w2, h2 = obb2[..., 2:4].split(1, dim=-1)
+ v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+ with torch.no_grad():
+ alpha = v / (v - iou + (1 + eps))
+ return iou - v * alpha # CIoU
+ return iou
+
+
+def batch_probiou(obb1, obb2, eps=1e-7):
+ """
+ Calculate the probabilistic IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+
+ Args:
+ obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
+ obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
+ eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+ Returns:
+ (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
+ """
+ x1, y1 = obb1[..., :2].split(1, dim=-1)
+ x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
+ a1, b1, c1 = _get_covariance_matrix(obb1)
+ a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
+
+ t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
+ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
+ t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
+ t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
+ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
+ (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
+ bd = t1 + t2 + t3
+ bd = torch.clamp(bd, eps, 100.0)
+ hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
+ return 1 - hd
+
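As a quick sanity check for the two functions above, a minimal sketch (assuming this diff is applied so they are importable from `ultralytics.utils.metrics`):

```python
import torch
from ultralytics.utils.metrics import probiou, batch_probiou

obb1 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3]])  # xywhr, angle in radians
obb2 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3],
                     [60.0, 55.0, 20.0, 10.0, 1.0]])

print(probiou(obb1, obb1))        # element-wise; identical boxes give a value close to 1
print(batch_probiou(obb1, obb2))  # pairwise (N, M) matrix, here shape (1, 2)
```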
+
def smooth_BCE(eps=0.1):
"""
Computes smoothed positive and negative Binary Cross-Entropy targets.
@@ -213,17 +299,17 @@ class ConfusionMatrix:
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
self.matrix[p][t] += 1
- def process_batch(self, detections, labels):
+ def process_batch(self, detections, gt_bboxes, gt_cls):
"""
Update confusion matrix for object detection task.
Args:
detections (Array[N, 6]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class).
- labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
- Each row should contain (class, x1, y1, x2, y2).
+ gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format.
+ gt_cls (Array[M]): The class labels.
"""
- if labels.size(0) == 0: # Check if labels is empty
+ if gt_cls.size(0) == 0: # Check if labels is empty
if detections is not None:
detections = detections[detections[:, 4] > self.conf]
detection_classes = detections[:, 5].int()
@@ -231,15 +317,15 @@ class ConfusionMatrix:
self.matrix[dc, self.nc] += 1 # false positives
return
if detections is None:
- gt_classes = labels.int()
+ gt_classes = gt_cls.int()
for gc in gt_classes:
self.matrix[self.nc, gc] += 1 # background FN
return
detections = detections[detections[:, 4] > self.conf]
- gt_classes = labels[:, 0].int()
+ gt_classes = gt_cls.int()
detection_classes = detections[:, 5].int()
- iou = box_iou(labels[:, 1:], detections[:, :4])
+ iou = box_iou(gt_bboxes, detections[:, :4])
x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
@@ -814,12 +900,12 @@ class SegmentMetrics(SimpleClass):
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.task = 'segment'
- def process(self, tp_b, tp_m, conf, pred_cls, target_cls):
+ def process(self, tp, tp_m, conf, pred_cls, target_cls):
"""
Processes the detection and segmentation metrics over the given set of predictions.
Args:
- tp_b (list): List of True Positive boxes.
+ tp (list): List of True Positive boxes.
tp_m (list): List of True Positive masks.
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
@@ -837,7 +923,7 @@ class SegmentMetrics(SimpleClass):
prefix='Mask')[2:]
self.seg.nc = len(self.names)
self.seg.update(results_mask)
- results_box = ap_per_class(tp_b,
+ results_box = ap_per_class(tp,
conf,
pred_cls,
target_cls,
@@ -938,12 +1024,12 @@ class PoseMetrics(SegmentMetrics):
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.task = 'pose'
- def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
+ def process(self, tp, tp_p, conf, pred_cls, target_cls):
"""
Processes the detection and pose metrics over the given set of predictions.
Args:
- tp_b (list): List of True Positive boxes.
+ tp (list): List of True Positive boxes.
tp_p (list): List of True Positive keypoints.
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
@@ -961,7 +1047,7 @@ class PoseMetrics(SegmentMetrics):
prefix='Pose')[2:]
self.pose.nc = len(self.names)
self.pose.update(results_pose)
- results_box = ap_per_class(tp_b,
+ results_box = ap_per_class(tp,
conf,
pred_cls,
target_cls,
@@ -1067,3 +1153,70 @@ class ClassifyMetrics(SimpleClass):
def curves_results(self):
"""Returns a list of curves for accessing specific metrics curves."""
return []
+
+
+class OBBMetrics(SimpleClass):
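+ """Metrics for evaluating oriented bounding box (OBB) detection, built on the existing Metric container."""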
+
+ def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
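+ """Initialize OBBMetrics with a save directory, plotting flag, optional plot callback, and class names."""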
+ self.save_dir = save_dir
+ self.plot = plot
+ self.on_plot = on_plot
+ self.names = names
+ self.box = Metric()
+ self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
+
+ def process(self, tp, conf, pred_cls, target_cls):
+ """Process predicted results for object detection and update metrics."""
+ results = ap_per_class(tp,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ on_plot=self.on_plot)[2:]
+ self.box.nc = len(self.names)
+ self.box.update(results)
+
+ @property
+ def keys(self):
+ """Returns a list of keys for accessing specific metrics."""
+ return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
+
+ def mean_results(self):
+ """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
+ return self.box.mean_results()
+
+ def class_result(self, i):
+ """Return the result of evaluating the performance of an object detection model on a specific class."""
+ return self.box.class_result(i)
+
+ @property
+ def maps(self):
+ """Returns mean Average Precision (mAP) scores per class."""
+ return self.box.maps
+
+ @property
+ def fitness(self):
+ """Returns the fitness of box object."""
+ return self.box.fitness()
+
+ @property
+ def ap_class_index(self):
+ """Returns the average precision index per class."""
+ return self.box.ap_class_index
+
+ @property
+ def results_dict(self):
+ """Returns dictionary of computed performance metrics and statistics."""
+ return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
+
+ @property
+ def curves(self):
+ """Returns a list of curves for accessing specific metrics curves."""
+ return []
+
+ @property
+ def curves_results(self):
+ """Returns a list of curves for accessing specific metrics curves."""
+ return []
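To make the layout of `OBBMetrics.results_dict` concrete, a tiny sketch with hypothetical numbers; the 0.1/0.9 fitness weighting is an assumption based on the existing `Metric.fitness()` behaviour, not something defined in this diff.

```python
# Hypothetical values only, illustrating how results_dict is assembled
keys = ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
mean_results = [0.81, 0.74, 0.79, 0.55]                  # hypothetical P, R, mAP50, mAP50-95
fitness = 0.1 * mean_results[2] + 0.9 * mean_results[3]  # assumed Metric.fitness() weighting
print(dict(zip(keys + ['fitness'], mean_results + [fitness])))
```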
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index 44e26ba2..12ac0461 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -12,6 +12,7 @@ import torch.nn.functional as F
import torchvision
from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import batch_probiou
class Profile(contextlib.ContextDecorator):
@@ -80,10 +81,10 @@ def segment2box(segment, width=640, height=640):
4, dtype=segment.dtype) # xyxy
-def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
"""
- Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
- (img1_shape) to the shape of a different image (img0_shape).
+ Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
+ specified in (img1_shape) to the shape of a different image (img0_shape).
Args:
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
@@ -93,6 +94,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
calculated based on the size difference between the two images.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
+ xywh (bool): Whether the box format is xywh; default=False.
Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
@@ -106,8 +108,11 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
pad = ratio_pad[1]
if padding:
- boxes[..., [0, 2]] -= pad[0] # x padding
- boxes[..., [1, 3]] -= pad[1] # y padding
+ boxes[..., 0] -= pad[0] # x padding
+ boxes[..., 1] -= pad[1] # y padding
+ if not xywh:
+ boxes[..., 2] -= pad[0] # x padding
+ boxes[..., 3] -= pad[1] # y padding
boxes[..., :4] /= gain
return clip_boxes(boxes, img0_shape)
@@ -128,19 +133,40 @@ def make_divisible(x, divisor):
return math.ceil(x / divisor) * divisor
+def nms_rotated(boxes, scores, threshold=0.45):
+ """
+ NMS for oriented bounding boxes (OBB), using probiou and fast-nms.
+
+ Args:
+ boxes (torch.Tensor): (N, 5), xywhr.
+ scores (torch.Tensor): (N, ).
+ threshold (float): IoU threshold.
+
+ Returns:
+ (torch.Tensor): Indices of the boxes kept after NMS.
+ """
+ if len(boxes) == 0:
+ return np.empty((0, ), dtype=np.int8)
+ sorted_idx = torch.argsort(scores, descending=True)
+ boxes = boxes[sorted_idx]
+ ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
+ pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+ return sorted_idx[pick]
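A minimal usage sketch for `nms_rotated`, assuming this diff is applied so it is importable from `ultralytics.utils.ops`:

```python
import torch
from ultralytics.utils.ops import nms_rotated

# Two heavily overlapping rotated boxes plus one distant box (xywhr, angle in radians)
boxes = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3],
                      [51.0, 50.0, 20.0, 10.0, 0.3],
                      [150.0, 150.0, 30.0, 15.0, 1.2]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms_rotated(boxes, scores, threshold=0.45)
print(keep)  # expected to keep the highest-scoring overlapping box and the distant one, e.g. tensor([0, 2])
```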
+
+
def non_max_suppression(
- prediction,
- conf_thres=0.25,
- iou_thres=0.45,
- classes=None,
- agnostic=False,
- multi_label=False,
- labels=(),
- max_det=300,
- nc=0, # number of classes (optional)
- max_time_img=0.05,
- max_nms=30000,
- max_wh=7680,
+ prediction,
+ conf_thres=0.25,
+ iou_thres=0.45,
+ classes=None,
+ agnostic=False,
+ multi_label=False,
+ labels=(),
+ max_det=300,
+ nc=0, # number of classes (optional)
+ max_time_img=0.05,
+ max_nms=30000,
+ max_wh=7680,
+ rotated=False,
):
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@@ -190,7 +216,8 @@ def non_max_suppression(
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
- prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
+ if not rotated:
+ prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
t = time.time()
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
@@ -200,7 +227,7 @@ def non_max_suppression(
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
- if labels and len(labels[xi]):
+ if labels and len(labels[xi]) and not rotated:
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
@@ -234,8 +261,13 @@ def non_max_suppression(
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
- boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
- i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ scores = x[:, 4] # scores
+ if rotated:
+ boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr; the last column is the angle
+ i = nms_rotated(boxes, scores, iou_thres)
+ else:
+ boxes = x[:, :4] + c # boxes (offset by class)
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
# # Experimental
@@ -320,7 +352,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
else:
- gain = ratio_pad[0][0]
+ # gain = ratio_pad[0][0]
pad = ratio_pad[1]
top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) # y, x
bottom, right = (int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1)))
@@ -476,7 +508,8 @@ def ltwh2xywh(x):
def xyxyxyxy2xywhr(corners):
"""
- Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
+ Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
+ returned in radians from 0 to pi/2.
Args:
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
@@ -484,61 +517,46 @@ def xyxyxyxy2xywhr(corners):
Returns:
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
"""
- is_numpy = isinstance(corners, np.ndarray)
- atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)
-
- x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
- cx = (x1 + x3) / 2
- cy = (y1 + y3) / 2
- dx21 = x2 - x1
- dy21 = y2 - y1
-
- w = sqrt(dx21 ** 2 + dy21 ** 2)
- h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
-
- rotation = atan2(-dy21, dx21)
- rotation *= 180.0 / math.pi # radians to degrees
-
- return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
+ is_torch = isinstance(corners, torch.Tensor)
+ points = corners.cpu().numpy() if is_torch else corners
+ points = points.reshape(len(corners), -1, 2)
+ rboxes = []
+ for pts in points:
+ # NOTE: Use cv2.minAreaRect to get accurate xywhr,
+ # especially some objects are cut off by augmentations in dataloader.
+ (x, y), (w, h), angle = cv2.minAreaRect(pts)
+ rboxes.append([x, y, w, h, angle / 180 * np.pi])
+ rboxes = torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) if is_torch else np.asarray(
+ rboxes, dtype=points.dtype)
+ return rboxes
def xywhr2xyxyxyxy(center):
"""
- Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
+ Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
+ be in radians from 0 to pi/2.
Args:
- center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
+ center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
Returns:
- (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
+ (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
"""
is_numpy = isinstance(center, np.ndarray)
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
- cx, cy, w, h, rotation = center.T
- rotation *= math.pi / 180.0 # degrees to radians
-
- dx = w / 2
- dy = h / 2
-
- cos_rot = cos(rotation)
- sin_rot = sin(rotation)
- dx_cos_rot = dx * cos_rot
- dx_sin_rot = dx * sin_rot
- dy_cos_rot = dy * cos_rot
- dy_sin_rot = dy * sin_rot
-
- x1 = cx - dx_cos_rot - dy_sin_rot
- y1 = cy + dx_sin_rot - dy_cos_rot
- x2 = cx + dx_cos_rot - dy_sin_rot
- y2 = cy - dx_sin_rot - dy_cos_rot
- x3 = cx + dx_cos_rot + dy_sin_rot
- y3 = cy - dx_sin_rot + dy_cos_rot
- x4 = cx - dx_cos_rot + dy_sin_rot
- y4 = cy + dx_sin_rot + dy_cos_rot
-
- return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
- (x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
+ ctr = center[..., :2]
+ w, h, angle = (center[..., i:i + 1] for i in range(2, 5))
+ cos_value, sin_value = cos(angle), sin(angle)
+ vec1 = [w / 2 * cos_value, w / 2 * sin_value]
+ vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
+ vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
+ vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
+ pt1 = ctr + vec1 + vec2
+ pt2 = ctr + vec1 - vec2
+ pt3 = ctr - vec1 - vec2
+ pt4 = ctr - vec1 + vec2
+ return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)
def ltwh2xyxy(x):
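The new corner layout of `xywhr2xyxyxyxy` is easiest to verify on an axis-aligned box; a minimal sketch, assuming the updated function from this diff:

```python
import torch
from ultralytics.utils.ops import xywhr2xyxyxyxy

boxes = torch.tensor([[50.0, 40.0, 20.0, 10.0, 0.0]])  # (n, 5) xywhr, angle in radians
corners = xywhr2xyxyxyxy(boxes)
print(corners.shape)  # torch.Size([1, 4, 2])
print(corners)        # (60, 45), (60, 35), (40, 35), (40, 45) -- the 20x10 box around (50, 40)
```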
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 0929c7bf..aaab72a1 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -100,25 +100,35 @@ class Annotator:
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
- def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
+ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""Add one xyxy box to image with label."""
if isinstance(box, torch.Tensor):
box = box.tolist()
if self.pil or not is_ascii(label):
- self.draw.rectangle(box, width=self.lw, outline=color) # box
+ if rotated:
+ p1 = box[0]
+ # NOTE: PIL-version polygon needs tuple type.
+ self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color)
+ else:
+ p1 = (box[0], box[1])
+ self.draw.rectangle(box, width=self.lw, outline=color) # box
if label:
w, h = self.font.getsize(label) # text width, height
- outside = box[1] - h >= 0 # label fits outside box
+ outside = p1[1] - h >= 0 # label fits outside box
self.draw.rectangle(
- (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
- box[1] + 1 if outside else box[1] + h + 1),
+ (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
fill=color,
)
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
- self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
+ self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
else: # cv2
- p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
- cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
+ if rotated:
+ p1 = [int(b) for b in box[0]]
+ # NOTE: cv2-version polylines needs np.asarray type.
+ cv2.polylines(self.im, [np.asarray(box, dtype=np.int32)], True, color, self.lw)
+ else:
+ p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+ cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
if label:
w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
outside = p1[1] - h >= 3
@@ -563,6 +573,7 @@ def plot_images(images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
+ confs=None,
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
@@ -618,27 +629,29 @@ def plot_images(images,
if len(cls) > 0:
idx = batch_idx == i
classes = cls[idx].astype('int')
+ labels = confs is None
if len(bboxes):
- boxes = ops.xywh2xyxy(bboxes[idx, :4]).T
- labels = bboxes.shape[1] == 4 # labels if no conf column
- conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred)
-
- if boxes.shape[1]:
- if boxes.max() <= 1.01: # if normalized with tolerance 0.01
- boxes[[0, 2]] *= w # scale to pixels
- boxes[[1, 3]] *= h
+ boxes = bboxes[idx]
+ conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
+ if len(boxes):
+ if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
+ boxes[:, [0, 2]] *= w # scale to pixels
+ boxes[:, [1, 3]] *= h
elif scale < 1: # absolute coords need scale if image scales
- boxes *= scale
- boxes[[0, 2]] += x
- boxes[[1, 3]] += y
- for j, box in enumerate(boxes.T.tolist()):
+ boxes[:, :4] *= scale
+ boxes[:, 0] += x
+ boxes[:, 1] += y
+ is_obb = boxes.shape[-1] == 5 # xywhr
+ boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
+ for j, box in enumerate(boxes.astype(np.int64).tolist()):
c = classes[j]
color = colors(c)
c = names.get(c, c) if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
- annotator.box_label(box, label, color=color)
+ annotator.box_label(box, label, color=color, rotated=is_obb)
+
elif len(classes):
for c in classes:
color = colors(c)
@@ -847,7 +860,18 @@ def output_to_target(output, max_det=300):
j = torch.full((conf.shape[0], 1), i)
targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
targets = torch.cat(targets, 0).numpy()
- return targets[:, 0], targets[:, 1], targets[:, 2:]
+ return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
+
+
+def output_to_rotated_target(output, max_det=300):
+ """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
+ targets = []
+ for i, o in enumerate(output):
+ box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
+ j = torch.full((conf.shape[0], 1), i)
+ targets.append(torch.cat((j, cls, box, angle, conf), 1))
+ targets = torch.cat(targets, 0).numpy()
+ return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
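A minimal sketch of `output_to_rotated_target` on a dummy prediction, assuming this diff is applied; the row layout [x, y, w, h, conf, cls, angle] mirrors the split used in the function above:

```python
import torch
from ultralytics.utils.plotting import output_to_rotated_target

# One image, two predictions laid out as [x, y, w, h, conf, cls, angle]
output = [torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.9, 0.0, 0.3],
                        [80.0, 60.0, 30.0, 15.0, 0.6, 2.0, 1.1]])]
batch_idx, cls, boxes, confs = output_to_rotated_target(output)
print(boxes.shape)  # (2, 5) xywhr, ready to pass to plot_images(..., confs=confs)
```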
diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
index 8a8333f1..97406cd7 100644
--- a/ultralytics/utils/tal.py
+++ b/ultralytics/utils/tal.py
@@ -4,59 +4,12 @@ import torch
import torch.nn as nn
from .checks import check_version
-from .metrics import bbox_iou
+from .metrics import bbox_iou, probiou
+from .ops import xywhr2xyxyxyxy
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
- """
- Select the positive anchor center in gt.
-
- Args:
- xy_centers (Tensor): shape(h*w, 2)
- gt_bboxes (Tensor): shape(b, n_boxes, 4)
-
- Returns:
- (Tensor): shape(b, n_boxes, h*w)
- """
- n_anchors = xy_centers.shape[0]
- bs, n_boxes, _ = gt_bboxes.shape
- lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
- bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
- # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
- return bbox_deltas.amin(3).gt_(eps)
-
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
- """
- If an anchor box is assigned to multiple gts, the one with the highest IoI will be selected.
-
- Args:
- mask_pos (Tensor): shape(b, n_max_boxes, h*w)
- overlaps (Tensor): shape(b, n_max_boxes, h*w)
-
- Returns:
- target_gt_idx (Tensor): shape(b, h*w)
- fg_mask (Tensor): shape(b, h*w)
- mask_pos (Tensor): shape(b, n_max_boxes, h*w)
- """
- # (b, n_max_boxes, h*w) -> (b, h*w)
- fg_mask = mask_pos.sum(-2)
- if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
- mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
- max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
-
- is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
- is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
- mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
- fg_mask = mask_pos.sum(-2)
- # Find each grid serve which gt(index)
- target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
- return target_gt_idx, fg_mask, mask_pos
-
-
class TaskAlignedAssigner(nn.Module):
"""
A task-aligned assigner for object detection.
@@ -115,7 +68,7 @@ class TaskAlignedAssigner(nn.Module):
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
mask_gt)
- target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
+ target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
# Assigned target
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
@@ -131,7 +84,7 @@ class TaskAlignedAssigner(nn.Module):
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
"""Get in_gts mask, (b, max_num_obj, h*w)."""
- mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+ mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
# Get anchor_align metric, (b, max_num_obj, h*w)
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
# Get topk_metric mask, (b, max_num_obj, h*w)
@@ -157,11 +110,15 @@ class TaskAlignedAssigner(nn.Module):
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
- overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+ overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
return align_metric, overlaps
+ def iou_calculation(self, gt_bboxes, pd_bboxes):
+ """Iou calculation for horizontal bounding boxes."""
+ return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
"""
Select the top-k candidates based on the given metrics.
@@ -229,7 +186,7 @@ class TaskAlignedAssigner(nn.Module):
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
- target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+ target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
# Assigned target scores
target_labels.clamp_(0)
@@ -245,6 +202,89 @@ class TaskAlignedAssigner(nn.Module):
return target_labels, target_bboxes, target_scores
+ @staticmethod
+ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+ """
+ Select the positive anchor center in gt.
+
+ Args:
+ xy_centers (Tensor): shape(h*w, 2)
+ gt_bboxes (Tensor): shape(b, n_boxes, 4)
+
+ Returns:
+ (Tensor): shape(b, n_boxes, h*w)
+ """
+ n_anchors = xy_centers.shape[0]
+ bs, n_boxes, _ = gt_bboxes.shape
+ lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
+ bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
+ # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
+ return bbox_deltas.amin(3).gt_(eps)
+
+ @staticmethod
+ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+ """
+ If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
+
+ Args:
+ mask_pos (Tensor): shape(b, n_max_boxes, h*w)
+ overlaps (Tensor): shape(b, n_max_boxes, h*w)
+
+ Returns:
+ target_gt_idx (Tensor): shape(b, h*w)
+ fg_mask (Tensor): shape(b, h*w)
+ mask_pos (Tensor): shape(b, n_max_boxes, h*w)
+ """
+ # (b, n_max_boxes, h*w) -> (b, h*w)
+ fg_mask = mask_pos.sum(-2)
+ if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
+ mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
+ max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
+
+ is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+ is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+ mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
+ fg_mask = mask_pos.sum(-2)
+ # Find each grid serve which gt(index)
+ target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
+ return target_gt_idx, fg_mask, mask_pos
+
+
+class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
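+ """A TaskAlignedAssigner variant for rotated (OBB) boxes that measures overlap with probiou."""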
+
+ def iou_calculation(self, gt_bboxes, pd_bboxes):
+ """Iou calculation for rotated bounding boxes."""
+ return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
+
+ @staticmethod
+ def select_candidates_in_gts(xy_centers, gt_bboxes):
+ """
+ Select the positive anchor center in gt for rotated bounding boxes.
+
+ Args:
+ xy_centers (Tensor): shape(h*w, 2)
+ gt_bboxes (Tensor): shape(b, n_boxes, 5)
+
+ Returns:
+ (Tensor): shape(b, n_boxes, h*w)
+ """
+ # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
+ corners = xywhr2xyxyxyxy(gt_bboxes)
+ # (b, n_boxes, 1, 2)
+ a, b, _, d = corners.split(1, dim=-2)
+ ab = b - a
+ ad = d - a
+
+ # (b, n_boxes, h*w, 2)
+ ap = xy_centers - a
+ norm_ab = (ab * ab).sum(dim=-1)
+ norm_ad = (ad * ad).sum(dim=-1)
+ ap_dot_ab = (ap * ab).sum(dim=-1)
+ ap_dot_ad = (ap * ad).sum(dim=-1)
+ is_in_box = (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)
+ return is_in_box
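A minimal sketch of the rotated in-box test above, assuming this diff is applied; the projections onto the AB and AD edges decide whether each anchor centre lies inside the rotated box:

```python
import torch
from ultralytics.utils.tal import RotatedTaskAlignedAssigner

gt = torch.tensor([[[50.0, 50.0, 20.0, 10.0, 0.7854]]])             # (b, n_boxes, 5) xywhr, ~45 degrees
centers = torch.tensor([[50.0, 50.0], [70.0, 50.0], [52.0, 48.0]])  # (h*w, 2) candidate anchor centres
mask = RotatedTaskAlignedAssigner.select_candidates_in_gts(centers, gt)
print(mask)  # expected tensor([[[ True, False,  True]]]) -- only centres inside the rotated box
```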
+
def make_anchors(feats, strides, grid_cell_offset=0.5):
"""Generate anchors from features."""
@@ -277,3 +317,23 @@ def bbox2dist(anchor_points, bbox, reg_max):
"""Transform bbox(xyxy) to dist(ltrb)."""
x1y1, x2y2 = bbox.chunk(2, -1)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
+
+
+def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
+ """
+ Decode predicted object bounding box coordinates from anchor points and distribution.
+
+ Args:
+ pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+ pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+ anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+ Returns:
+ (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
+ """
+ lt, rb = pred_dist.split(2, dim=dim)
+ cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
+ # (bs, h*w, 1)
+ xf, yf = ((rb - lt) / 2).split(1, dim=dim)
+ x, y = xf * cos - yf * sin, xf * sin + yf * cos
+ xy = torch.cat([x, y], dim=dim) + anchor_points
+ return torch.cat([xy, lt + rb], dim=dim)