ultralytics 8.3.30 run TAL on CPU if torch.OutOfMemoryError (#17515)
Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
1a5c35366e
commit
f43c211ab4
2 changed files with 35 additions and 7 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.3.29"
|
||||
__version__ = "8.3.30"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -58,17 +58,45 @@ class TaskAlignedAssigner(nn.Module):
|
|||
"""
|
||||
self.bs = pd_scores.shape[0]
|
||||
self.n_max_boxes = gt_bboxes.shape[1]
|
||||
device = gt_bboxes.device
|
||||
|
||||
if self.n_max_boxes == 0:
|
||||
device = gt_bboxes.device
|
||||
return (
|
||||
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
|
||||
torch.zeros_like(pd_bboxes).to(device),
|
||||
torch.zeros_like(pd_scores).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
torch.full_like(pd_scores[..., 0], self.bg_idx),
|
||||
torch.zeros_like(pd_bboxes),
|
||||
torch.zeros_like(pd_scores),
|
||||
torch.zeros_like(pd_scores[..., 0]),
|
||||
torch.zeros_like(pd_scores[..., 0]),
|
||||
)
|
||||
|
||||
try:
|
||||
return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
|
||||
except torch.OutOfMemoryError:
|
||||
# Move tensors to CPU, compute, then move back to original device
|
||||
cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
|
||||
result = self._forward(*cpu_tensors)
|
||||
return tuple(t.to(device) for t in result)
|
||||
|
||||
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
||||
"""
|
||||
Compute the task-aligned assignment. Reference code is available at
|
||||
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
|
||||
|
||||
Args:
|
||||
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
||||
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
||||
anc_points (Tensor): shape(num_total_anchors, 2)
|
||||
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
||||
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
||||
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
||||
|
||||
Returns:
|
||||
target_labels (Tensor): shape(bs, num_total_anchors)
|
||||
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
||||
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
||||
fg_mask (Tensor): shape(bs, num_total_anchors)
|
||||
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
||||
"""
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(
|
||||
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue