Update ViT ops.py to torch.long (#3508)
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
8a11eda4a9
commit
31b46bf2b4
2 changed files with 4 additions and 4 deletions
|
|
@ -284,11 +284,11 @@ class RTDETRDetectionLoss(DETRLoss):
|
|||
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
for i, num_gt in enumerate(gt_groups):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
|
||||
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
|
||||
gt_idx = gt_idx.repeat(dn_num_group)
|
||||
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
|
||||
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
|
||||
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
|
||||
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
|
||||
return dn_match_indices
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue