Fix _process_batch() docstrings (#14454)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-07-16 01:01:18 +02:00 committed by GitHub
parent e094f9c371
commit c2f9a12cb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 77 additions and 25 deletions

View file

@ -52,17 +52,29 @@ class OBBValidator(DetectionValidator):
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
Return correct prediction matrix.
Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
Args:
detections (torch.Tensor): Tensor of shape [N, 7] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class, angle.
gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes.
Each box is of the format: x1, y1, x2, y2, angle.
labels (torch.Tensor): Tensor of shape [M] representing labels.
detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
represented as (x1, y1, x2, y2, angle).
gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
(torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
Example:
```python
detections = torch.rand(100, 7) # 100 sample detections
gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes
gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels
correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
```
Note:
This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
"""
iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
return self.match_predictions(detections[:, 5], gt_cls, iou)