Fix _process_batch() docstrings (#14454)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-07-16 01:01:18 +02:00 committed by GitHub
parent e094f9c371
commit c2f9a12cb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 77 additions and 25 deletions

View file

@ -164,14 +164,34 @@ class SegmentationValidator(DetectionValidator):
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix.
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
Args:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
Each row is of the format [x1, y1, x2, y2].
gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
match the ground truth masks.
gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
overlap (bool): Flag indicating if overlapping masks should be considered.
masks (bool): Flag indicating if the batch contains mask data.
Returns:
correct (array[N, 10]), for 10 IoU levels
(torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
Note:
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
Example:
```python
detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
gt_cls = torch.tensor([1, 0])
correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
```
"""
if masks:
if overlap: