[RTDETR]Fix val loss (#3280)
This commit is contained in:
parent
d8701b42ca
commit
9d1e5567de
4 changed files with 12 additions and 9 deletions
|
|
@ -89,9 +89,8 @@ class RTDETRValidator(DetectionValidator):
|
|||
|
||||
def postprocess(self, preds):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
|
||||
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) # (bs, 300, 4)
|
||||
bs = len(bboxes)
|
||||
bs, _, nd = preds[0].shape
|
||||
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
|
||||
for i, bbox in enumerate(bboxes): # (300, 4)
|
||||
bbox = ops.xywh2xyxy(bbox)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue