ultralytics 8.0.174 Tuner plots and improvements (#4799)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
dfe6dfb1d2
commit
16ce193d6e
12 changed files with 248 additions and 88 deletions
|
|
@ -635,7 +635,33 @@ SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
|
|||
|
||||
|
||||
def colorstr(*input):
|
||||
"""Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
|
||||
"""
|
||||
Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
|
||||
See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
|
||||
|
||||
This function can be called in two ways:
|
||||
- colorstr('color', 'style', 'your string')
|
||||
- colorstr('your string')
|
||||
|
||||
In the second form, 'blue' and 'bold' will be applied by default.
|
||||
|
||||
Args:
|
||||
*input (str): A sequence of strings where the first n-1 strings are color and style arguments,
|
||||
and the last string is the one to be colored.
|
||||
|
||||
Supported Colors and Styles:
|
||||
Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
|
||||
Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
|
||||
'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
|
||||
Misc: 'end', 'bold', 'underline'
|
||||
|
||||
Returns:
|
||||
(str): The input string wrapped with ANSI escape codes for the specified color and style.
|
||||
|
||||
Examples:
|
||||
>>> colorstr('blue', 'bold', 'hello world')
|
||||
>>> '\033[34m\033[1mhello world\033[0m'
|
||||
"""
|
||||
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
||||
colors = {
|
||||
'black': '\033[30m', # basic colors
|
||||
|
|
@ -660,6 +686,24 @@ def colorstr(*input):
|
|||
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
||||
|
||||
|
||||
def remove_colorstr(input_string):
|
||||
"""
|
||||
Removes ANSI escape codes from a string, effectively un-coloring it.
|
||||
|
||||
Args:
|
||||
input_string (str): The string to remove color and style from.
|
||||
|
||||
Returns:
|
||||
(str): A new string with all ANSI escape codes removed.
|
||||
|
||||
Examples:
|
||||
>>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
|
||||
>>> 'hello world'
|
||||
"""
|
||||
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
|
||||
return ansi_escape.sub('', input_string)
|
||||
|
||||
|
||||
class TryExcept(contextlib.ContextDecorator):
|
||||
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
|
||||
|
||||
|
|
|
|||
|
|
@ -498,13 +498,23 @@ def plot_images(images,
|
|||
@plt_settings()
|
||||
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
|
||||
"""
|
||||
Plot training results from results CSV file.
|
||||
Plot training results from a results CSV file. The function supports various types of data including segmentation,
|
||||
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
|
||||
|
||||
Args:
|
||||
file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
|
||||
dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
|
||||
segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
|
||||
pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
|
||||
classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
|
||||
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
|
||||
Defaults to None.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils.plotting import plot_results
|
||||
|
||||
plot_results('path/to/results.csv')
|
||||
plot_results('path/to/results.csv', segment=True)
|
||||
```
|
||||
"""
|
||||
import pandas as pd
|
||||
|
|
@ -548,6 +558,92 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
|
|||
on_plot(fname)
|
||||
|
||||
|
||||
def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
|
||||
"""
|
||||
Plots a scatter plot with points colored based on a 2D histogram.
|
||||
|
||||
Args:
|
||||
v (array-like): Values for the x-axis.
|
||||
f (array-like): Values for the y-axis.
|
||||
bins (int, optional): Number of bins for the histogram. Defaults to 20.
|
||||
cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
|
||||
alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
|
||||
edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
|
||||
|
||||
Examples:
|
||||
>>> v = np.random.rand(100)
|
||||
>>> f = np.random.rand(100)
|
||||
>>> plt_color_scatter(v, f)
|
||||
"""
|
||||
|
||||
# Calculate 2D histogram and corresponding colors
|
||||
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
|
||||
colors = [
|
||||
hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
|
||||
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
|
||||
|
||||
# Scatter plot
|
||||
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
|
||||
|
||||
|
||||
def plot_tune_results(csv_file='tune_results.csv'):
|
||||
"""
|
||||
Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
|
||||
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
|
||||
|
||||
Args:
|
||||
csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
|
||||
|
||||
Examples:
|
||||
>>> plot_tune_results('path/to/tune_results.csv')
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
# Scatter plots for each hyperparameter
|
||||
csv_file = Path(csv_file)
|
||||
data = pd.read_csv(csv_file)
|
||||
num_metrics_columns = 1
|
||||
keys = [x.strip() for x in data.columns][num_metrics_columns:]
|
||||
x = data.values
|
||||
fitness = x[:, 0] # fitness
|
||||
j = np.argmax(fitness) # max fitness index
|
||||
n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
|
||||
plt.figure(figsize=(10, 10), tight_layout=True)
|
||||
for i, k in enumerate(keys):
|
||||
v = x[:, i + num_metrics_columns]
|
||||
mu = v[j] # best single result
|
||||
plt.subplot(n, n, i + 1)
|
||||
plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
|
||||
plt.plot(mu, fitness.max(), 'k+', markersize=15)
|
||||
plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
|
||||
plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8
|
||||
if i % n != 0:
|
||||
plt.yticks([])
|
||||
|
||||
file = csv_file.with_name('tune_scatter_plots.png') # filename
|
||||
plt.savefig(file, dpi=200)
|
||||
plt.close()
|
||||
LOGGER.info(f'Saved {file}')
|
||||
|
||||
# Fitness vs iteration
|
||||
x = range(1, len(fitness) + 1)
|
||||
plt.figure(figsize=(10, 6), tight_layout=True)
|
||||
plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
|
||||
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line
|
||||
plt.title('Fitness vs Iteration')
|
||||
plt.xlabel('Iteration')
|
||||
plt.ylabel('Fitness')
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
file = csv_file.with_name('tune_fitness.png') # filename
|
||||
plt.savefig(file, dpi=200)
|
||||
plt.close()
|
||||
LOGGER.info(f'Saved {file}')
|
||||
|
||||
|
||||
def output_to_target(output, max_det=300):
|
||||
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
||||
targets = []
|
||||
|
|
|
|||
|
|
@ -395,12 +395,6 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
|
|||
strip_optimizer(f)
|
||||
```
|
||||
"""
|
||||
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
|
||||
try:
|
||||
import dill as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
if 'model' not in x:
|
||||
LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
|
||||
|
|
@ -419,8 +413,8 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
|
|||
p.requires_grad = False
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||
# x['model'].args = x['train_args']
|
||||
torch.save(x, s or f, pickle_module=pickle)
|
||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||
torch.save(x, s or f)
|
||||
mb = os.path.getsize(s or f) / 1E6 # file size
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue