ultralytics 8.0.174 Tuner plots and improvements (#4799)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-09-10 03:27:23 +02:00 committed by GitHub
parent dfe6dfb1d2
commit 16ce193d6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 248 additions and 88 deletions

View file

@ -498,13 +498,23 @@ def plot_images(images,
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
"""
Plot training results from results CSV file.
Plot training results from a results CSV file. The function supports various types of data including segmentation,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
Args:
file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
Defaults to None.
Example:
```python
from ultralytics.utils.plotting import plot_results
plot_results('path/to/results.csv')
plot_results('path/to/results.csv', segment=True)
```
"""
import pandas as pd
@ -548,6 +558,92 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
on_plot(fname)
def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
    """
    Plots a scatter plot with points colored based on a 2D histogram.

    Every point is colored by the count of the 2D-histogram bin it falls into, so points in densely populated
    regions share a high color value and isolated points get a low one.

    Args:
        v (array-like): Values for the x-axis.
        f (array-like): Values for the y-axis.
        bins (int, optional): Number of bins for the histogram. Defaults to 20.
        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.

    Examples:
        >>> v = np.random.rand(100)
        >>> f = np.random.rand(100)
        >>> plt_color_scatter(v, f)
    """
    v = np.asarray(v)
    f = np.asarray(f)

    # Calculate 2D histogram and the bin index of every point
    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)

    # np.digitize with right=False uses half-open bins [edge[i-1], edge[i]), the same convention as
    # np.histogram2d, so the minimum value maps to bin 0 (not -1, which would wrap around to the last bin).
    # Points lying exactly on the last edge are clipped back into the final bin, again matching histogram2d.
    xi = np.clip(np.digitize(v, xedges) - 1, 0, hist.shape[0] - 1)
    yi = np.clip(np.digitize(f, yedges) - 1, 0, hist.shape[1] - 1)
    colors = hist[xi, yi]

    # Scatter plot
    plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)


def plot_tune_results(csv_file='tune_results.csv'):
    """
    Plot the evolution results stored in a 'tune_results.csv' file.

    The first CSV column is treated as the fitness score and every remaining column as one tuned key. Two figures
    are written next to the CSV file:
        - 'tune_scatter_plots.png': one scatter subplot per key (value vs fitness), with the value from the
          highest-fitness row marked by a black '+'.
        - 'tune_fitness.png': fitness per iteration together with a Gaussian-smoothed trend line.

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.

    Examples:
        >>> plot_tune_results('path/to/tune_results.csv')
    """
    import pandas as pd
    from scipy.ndimage import gaussian_filter1d

    def _save_current_figure(name):
        """Save the active figure beside the CSV under `name`, close it and log the output path."""
        file = csv_file.with_name(name)
        plt.savefig(file, dpi=200)
        plt.close()
        LOGGER.info(f'Saved {file}')

    # Load the tuning history
    csv_file = Path(csv_file)
    frame = pd.read_csv(csv_file)
    num_metrics_columns = 1  # leading columns that hold metrics rather than tuned keys
    names = [column.strip() for column in frame.columns][num_metrics_columns:]
    table = frame.values
    fitness = table[:, 0]
    best_row = np.argmax(fitness)  # row index of the highest fitness
    grid = math.ceil(len(names) ** 0.5)  # square grid large enough for all keys

    # Scatter plots for each hyperparameter
    plt.figure(figsize=(10, 10), tight_layout=True)
    for idx, name in enumerate(names):
        values = table[:, idx + num_metrics_columns]
        best_value = values[best_row]
        plt.subplot(grid, grid, idx + 1)
        plt_color_scatter(values, fitness, cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(best_value, fitness.max(), 'k+', markersize=15)  # mark the best single result
        plt.title(f'{name} = {best_value:.3g}', fontdict={'size': 9})
        plt.tick_params(axis='both', labelsize=8)
        if idx % grid:  # keep y tick labels only on the first column of the grid
            plt.yticks([])
    _save_current_figure('tune_scatter_plots.png')

    # Fitness vs iteration
    iterations = range(1, len(fitness) + 1)
    smoothed = gaussian_filter1d(fitness, sigma=3)
    plt.figure(figsize=(10, 6), tight_layout=True)
    plt.plot(iterations, fitness, marker='o', linestyle='none', label='fitness')
    plt.plot(iterations, smoothed, ':', label='smoothed', linewidth=2)
    plt.title('Fitness vs Iteration')
    plt.xlabel('Iteration')
    plt.ylabel('Fitness')
    plt.grid(True)
    plt.legend()
    _save_current_figure('tune_fitness.png')


def output_to_target(output, max_det=300):
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
targets = []