Fix _process_batch() docstrings (#14454)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
e094f9c371
commit
c2f9a12cb4
4 changed files with 77 additions and 25 deletions
|
|
@ -202,13 +202,18 @@ class DetectionValidator(BaseValidator):
|
||||||
Return correct prediction matrix.
|
Return correct prediction matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
|
detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
|
||||||
Each detection is of the format: x1, y1, x2, y2, conf, class.
|
(x1, y1, x2, y2, conf, class).
|
||||||
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
|
gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
|
||||||
Each label is of the format: class, x1, y1, x2, y2.
|
bounding box is of the format: (x1, y1, x2, y2).
|
||||||
|
gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
|
(torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The returned matrix is not a final metric by itself. Instead, it is an intermediate representation that is
|
||||||
|
used for evaluating predictions against ground truth.
|
||||||
"""
|
"""
|
||||||
iou = box_iou(gt_bboxes, detections[:, :4])
|
iou = box_iou(gt_bboxes, detections[:, :4])
|
||||||
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
||||||
|
|
|
||||||
|
|
@ -52,17 +52,29 @@ class OBBValidator(DetectionValidator):
|
||||||
|
|
||||||
def _process_batch(self, detections, gt_bboxes, gt_cls):
|
def _process_batch(self, detections, gt_bboxes, gt_cls):
|
||||||
"""
|
"""
|
||||||
Return correct prediction matrix.
|
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detections (torch.Tensor): Tensor of shape [N, 7] representing detections.
|
detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
|
||||||
Each detection is of the format: x1, y1, x2, y2, conf, class, angle.
|
data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
|
||||||
gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes.
|
gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
|
||||||
Each box is of the format: x1, y1, x2, y2, angle.
|
represented as (x1, y1, x2, y2, angle).
|
||||||
labels (torch.Tensor): Tensor of shape [M] representing labels.
|
gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
|
(torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
|
||||||
|
Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
detections = torch.rand(100, 7) # 100 sample detections
|
||||||
|
gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes
|
||||||
|
gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels
|
||||||
|
correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)  # validator is an OBBValidator instance
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
|
||||||
"""
|
"""
|
||||||
iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
|
iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
|
||||||
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
||||||
|
|
|
||||||
|
|
@ -152,19 +152,34 @@ class PoseValidator(DetectionValidator):
|
||||||
|
|
||||||
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
|
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
|
||||||
"""
|
"""
|
||||||
Return correct prediction matrix.
|
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
|
detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
|
||||||
Each detection is of the format: x1, y1, x2, y2, conf, class.
|
detection is of the format (x1, y1, x2, y2, conf, class).
|
||||||
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
|
gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
|
||||||
Each label is of the format: class, x1, y1, x2, y2.
|
box is of the format (x1, y1, x2, y2).
|
||||||
pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints.
|
gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
|
||||||
51 corresponds to 17 keypoints each with 3 values.
|
pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
|
||||||
gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints.
|
51 corresponds to 17 keypoints each having 3 values.
|
||||||
|
gt_kpts (torch.Tensor | None): Optional tensor with shape (M, 51) representing ground truth keypoints.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
|
torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
|
||||||
|
where N is the number of detections.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
detections = torch.rand(100, 6) # 100 predictions: (x1, y1, x2, y2, conf, class)
|
||||||
|
gt_bboxes = torch.rand(50, 4) # 50 ground truth boxes: (x1, y1, x2, y2)
|
||||||
|
gt_cls = torch.randint(0, 2, (50,)) # 50 ground truth class indices
|
||||||
|
pred_kpts = torch.rand(100, 51) # 100 predicted keypoints
|
||||||
|
gt_kpts = torch.rand(50, 51) # 50 ground truth keypoints
|
||||||
|
correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)  # validator is a PoseValidator instance
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
`0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
|
||||||
"""
|
"""
|
||||||
if pred_kpts is not None and gt_kpts is not None:
|
if pred_kpts is not None and gt_kpts is not None:
|
||||||
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
||||||
|
|
|
||||||
|
|
@ -164,14 +164,34 @@ class SegmentationValidator(DetectionValidator):
|
||||||
|
|
||||||
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||||
"""
|
"""
|
||||||
Return correct prediction matrix.
|
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
|
||||||
labels (array[M, 5]), class, x1, y1, x2, y2
|
associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
|
||||||
|
gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
|
||||||
|
Each row is of the format [x1, y1, x2, y2].
|
||||||
|
gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
|
||||||
|
pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
|
||||||
|
match the ground truth masks.
|
||||||
|
gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
|
||||||
|
overlap (bool): Flag indicating if overlapping masks should be considered.
|
||||||
|
masks (bool): Flag indicating if the batch contains mask data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
correct (array[N, 10]), for 10 IoU levels
|
(torch.Tensor): A correct prediction matrix of shape (N, 10), with one column per IoU level.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
|
||||||
|
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
|
||||||
|
gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
|
||||||
|
gt_cls = torch.tensor([1, 0])
|
||||||
|
correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if masks:
|
if masks:
|
||||||
if overlap:
|
if overlap:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue