Tests improvements (#4616)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-29 15:38:43 +02:00 committed by GitHub
parent 4bd62a299c
commit 896da0c0a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 19 additions and 15 deletions

View file

@ -122,8 +122,8 @@ class TaskAlignedAssigner(nn.Module):
# Normalize
align_metric *= mask_pos
pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj
pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj
pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj
pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
target_scores = target_scores * norm_align_metric