Simplify metrics calculation (#9338)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Kayzwer 2024-03-28 03:50:57 +08:00 committed by GitHub
parent 1325889305
commit 978a3ca61c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -167,7 +167,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)
sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
kpt_mask = kpt1[..., 2] != 0 # (N, 17)
e = d / (2 * sigma).pow(2) / (area[:, None, None] + eps) / 2 # from cocoeval
e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval
# e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula
return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
@ -402,7 +402,7 @@ class ConfusionMatrix:
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
sn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (list(names) + ["background"]) if labels else "auto"
with warnings.catch_warnings():