Cleanup argument handling in Model class (#4614)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastian Stapf <42514241+Wiqzard@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-28 23:12:32 +02:00 committed by GitHub
parent 53b4f8c713
commit 7e99804263
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 93 deletions

View file

@ -355,15 +355,15 @@ class DeformableTransformerDecoder(nn.Module):
for i, layer in enumerate(self.layers):
output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
# refine bboxes, (bs, num_queries+num_denoising, 4)
refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
bbox = bbox_head[i](output)
refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))
if self.training:
dec_cls.append(score_head[i](output))
if i == 0:
dec_bboxes.append(refined_bbox)
else:
dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
elif i == self.eval_idx:
dec_cls.append(score_head[i](output))
dec_bboxes.append(refined_bbox)