ultralytics 8.0.75 fixes and updates (#1967)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Jonathan Rayner <jonathan.j.rayner@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-13 02:04:10 +02:00 committed by GitHub
parent e5cb35edfc
commit 48c4483795
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 128 additions and 98 deletions

View file

@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
@ -234,6 +234,7 @@ class ConfusionMatrix:
return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn
@ -277,6 +278,7 @@ def smooth(y, f=0.05):
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -299,6 +301,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
plt.close(fig)
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)