ultralytics 8.0.105 classification hyp fix and new onplot callbacks (#2684)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
Glenn Jocher 2023-05-17 19:10:20 +02:00 committed by GitHub
parent b1119d512e
commit 23fc50641c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
92 changed files with 378 additions and 206 deletions

View file

@ -228,7 +228,7 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
"""Save and plot image with no axis or spines."""
import pandas as pd
import seaborn as sn
@ -271,8 +271,11 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels.jpg', dpi=200)
fname = save_dir / 'labels.jpg'
plt.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
@ -301,7 +304,8 @@ def plot_images(images,
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname='images.jpg',
names=None):
names=None,
on_plot=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -419,10 +423,12 @@ def plot_images(images,
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False):
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
"""Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
import pandas as pd
save_dir = Path(file).parent if file else Path(dir)
@ -456,8 +462,11 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
except Exception as e:
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
ax[1].legend()
fig.savefig(save_dir / 'results.png', dpi=200)
fname = save_dir / 'results.png'
fig.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def output_to_target(output, max_det=300):