Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -117,6 +117,7 @@ class TQDM(tqdm_original):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize custom Ultralytics tqdm class with different default arguments."""
|
||||
# Set new default values (these can still be overridden when calling TQDM)
|
||||
kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
|
||||
kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
|
||||
|
|
@ -124,8 +125,7 @@ class TQDM(tqdm_original):
|
|||
|
||||
|
||||
class SimpleClass:
|
||||
"""
|
||||
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
|
||||
"""Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
|
||||
access methods for easier debugging and usage.
|
||||
"""
|
||||
|
||||
|
|
@ -154,8 +154,7 @@ class SimpleClass:
|
|||
|
||||
|
||||
class IterableSimpleNamespace(SimpleNamespace):
|
||||
"""
|
||||
Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
|
||||
"""Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
|
||||
enables usage with dict() and for loops.
|
||||
"""
|
||||
|
||||
|
|
@ -256,8 +255,8 @@ class EmojiFilter(logging.Filter):
|
|||
"""
|
||||
A custom logging filter class for removing emojis in log messages.
|
||||
|
||||
This filter is particularly useful for ensuring compatibility with Windows terminals
|
||||
that may not support the display of emojis in log messages.
|
||||
This filter is particularly useful for ensuring compatibility with Windows terminals that may not support the
|
||||
display of emojis in log messages.
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
|
|
@ -275,9 +274,9 @@ if WINDOWS: # emoji-safe logging
|
|||
|
||||
class ThreadingLocked:
|
||||
"""
|
||||
A decorator class for ensuring thread-safe execution of a function or method.
|
||||
This class can be used as a decorator to make sure that if the decorated function
|
||||
is called from multiple threads, only one thread at a time will be able to execute the function.
|
||||
A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
|
||||
to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
|
||||
to execute the function.
|
||||
|
||||
Attributes:
|
||||
lock (threading.Lock): A lock object used to manage access to the decorated function.
|
||||
|
|
@ -294,13 +293,16 @@ class ThreadingLocked:
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the decorator class for thread-safe execution of a function or method."""
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __call__(self, f):
|
||||
"""Run thread-safe execution of function or method."""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
"""Applies thread-safety to the decorated function or method."""
|
||||
with self.lock:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
|
|
@ -424,8 +426,7 @@ def is_kaggle():
|
|||
|
||||
def is_jupyter():
|
||||
"""
|
||||
Check if the current script is running inside a Jupyter Notebook.
|
||||
Verified on Colab, Jupyterlab, Kaggle, Paperspace.
|
||||
Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.
|
||||
|
||||
Returns:
|
||||
(bool): True if running inside a Jupyter Notebook, False otherwise.
|
||||
|
|
@ -529,8 +530,8 @@ def is_github_actions_ci() -> bool:
|
|||
|
||||
def is_git_dir():
|
||||
"""
|
||||
Determines whether the current file is part of a git repository.
|
||||
If the current file is not part of a git repository, returns None.
|
||||
Determines whether the current file is part of a git repository. If the current file is not part of a git
|
||||
repository, returns None.
|
||||
|
||||
Returns:
|
||||
(bool): True if current file is part of a git repository.
|
||||
|
|
@ -540,8 +541,8 @@ def is_git_dir():
|
|||
|
||||
def get_git_dir():
|
||||
"""
|
||||
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
|
||||
If the current file is not part of a git repository, returns None.
|
||||
Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
|
||||
the current file is not part of a git repository, returns None.
|
||||
|
||||
Returns:
|
||||
(Path | None): Git root directory if found or None if not found.
|
||||
|
|
@ -578,7 +579,8 @@ def get_git_branch():
|
|||
|
||||
|
||||
def get_default_args(func):
|
||||
"""Returns a dictionary of default arguments for a function.
|
||||
"""
|
||||
Returns a dictionary of default arguments for a function.
|
||||
|
||||
Args:
|
||||
func (callable): The function to inspect.
|
||||
|
|
@ -710,7 +712,11 @@ def remove_colorstr(input_string):
|
|||
|
||||
|
||||
class TryExcept(contextlib.ContextDecorator):
|
||||
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
|
||||
"""
|
||||
YOLOv8 TryExcept class.
|
||||
|
||||
Use as @TryExcept() decorator or 'with TryExcept():' context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, msg='', verbose=True):
|
||||
"""Initialize TryExcept class with optional message and verbosity settings."""
|
||||
|
|
@ -729,7 +735,11 @@ class TryExcept(contextlib.ContextDecorator):
|
|||
|
||||
|
||||
def threaded(func):
|
||||
"""Multi-threads a target function and returns thread. Usage: @threaded decorator."""
|
||||
"""
|
||||
Multi-threads a target function and returns thread.
|
||||
|
||||
Use as @threaded decorator.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Multi-threads a given function and returns the thread."""
|
||||
|
|
@ -824,6 +834,9 @@ class SettingsManager(dict):
|
|||
"""
|
||||
|
||||
def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
|
||||
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
|
||||
file.
|
||||
"""
|
||||
import copy
|
||||
import hashlib
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
|
||||
"""
|
||||
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Benchmark a YOLO model formats for speed and accuracy
|
||||
Benchmark a YOLO model formats for speed and accuracy.
|
||||
|
||||
Usage:
|
||||
from ultralytics.utils.benchmarks import ProfileModels, benchmark
|
||||
|
|
@ -194,6 +194,7 @@ class ProfileModels:
|
|||
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def profile(self):
|
||||
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
|
||||
files = self.get_files()
|
||||
|
||||
if not files:
|
||||
|
|
@ -235,6 +236,7 @@ class ProfileModels:
|
|||
return output
|
||||
|
||||
def get_files(self):
|
||||
"""Returns a list of paths for all relevant model files given by the user."""
|
||||
files = []
|
||||
for path in self.paths:
|
||||
path = Path(path)
|
||||
|
|
@ -250,10 +252,14 @@ class ProfileModels:
|
|||
return [Path(file) for file in sorted(files)]
|
||||
|
||||
def get_onnx_model_info(self, onnx_file: str):
|
||||
"""Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
|
||||
file.
|
||||
"""
|
||||
# return (num_layers, num_params, num_gradients, num_flops)
|
||||
return 0.0, 0.0, 0.0, 0.0
|
||||
|
||||
def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
|
||||
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
|
||||
data = np.array(data)
|
||||
for _ in range(max_iters):
|
||||
mean, std = np.mean(data), np.std(data)
|
||||
|
|
@ -264,6 +270,7 @@ class ProfileModels:
|
|||
return data
|
||||
|
||||
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
|
||||
"""Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
|
||||
if not self.trt or not Path(engine_file).is_file():
|
||||
return 0.0, 0.0
|
||||
|
||||
|
|
@ -292,6 +299,9 @@ class ProfileModels:
|
|||
return np.mean(run_times), np.std(run_times)
|
||||
|
||||
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
|
||||
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
|
||||
times.
|
||||
"""
|
||||
check_requirements('onnxruntime')
|
||||
import onnxruntime as ort
|
||||
|
||||
|
|
@ -344,10 +354,12 @@ class ProfileModels:
|
|||
return np.mean(run_times), np.std(run_times)
|
||||
|
||||
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
|
||||
"""Generates a formatted string for a table row that includes model performance and metric details."""
|
||||
layers, params, gradients, flops = model_info
|
||||
return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
|
||||
|
||||
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
|
||||
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
|
||||
layers, params, gradients, flops = model_info
|
||||
return {
|
||||
'model/name': model_name,
|
||||
|
|
@ -357,6 +369,7 @@ class ProfileModels:
|
|||
'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
|
||||
|
||||
def print_table(self, table_rows):
|
||||
"""Formats and prints a comparison table for different models with given statistics and performance data."""
|
||||
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
|
||||
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
|
||||
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Base callbacks
|
||||
"""
|
||||
"""Base callbacks."""
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
|
|
|
|||
|
|
@ -26,31 +26,38 @@ except (ImportError, AssertionError):
|
|||
|
||||
|
||||
def _get_comet_mode():
|
||||
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
|
||||
return os.getenv('COMET_MODE', 'online')
|
||||
|
||||
|
||||
def _get_comet_model_name():
|
||||
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
|
||||
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
|
||||
|
||||
|
||||
def _get_eval_batch_logging_interval():
|
||||
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
|
||||
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
|
||||
|
||||
|
||||
def _get_max_image_predictions_to_log():
|
||||
"""Get the maximum number of image predictions to log from the environment variables."""
|
||||
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
|
||||
|
||||
|
||||
def _scale_confidence_score(score):
|
||||
"""Scales the given confidence score by a factor specified in an environment variable."""
|
||||
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
|
||||
return score * scale
|
||||
|
||||
|
||||
def _should_log_confusion_matrix():
|
||||
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
|
||||
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
|
||||
|
||||
|
||||
def _should_log_image_predictions():
|
||||
"""Determines whether to log image predictions based on a specified environment variable."""
|
||||
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
|
||||
|
||||
|
||||
|
|
@ -104,9 +111,10 @@ def _fetch_trainer_metadata(trainer):
|
|||
|
||||
|
||||
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
|
||||
"""YOLOv8 resizes images during training and the label values
|
||||
are normalized based on this resized shape. This function rescales the
|
||||
bounding box labels to the original image shape.
|
||||
"""
|
||||
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
|
||||
|
||||
This function rescales the bounding box labels to the original image shape.
|
||||
"""
|
||||
|
||||
resized_image_height, resized_image_width = resized_image_shape
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ except (ImportError, AssertionError, TypeError):
|
|||
|
||||
|
||||
def _log_images(path, prefix=''):
|
||||
"""Logs images at specified path with an optional prefix using DVCLive."""
|
||||
if live:
|
||||
name = path.name
|
||||
|
||||
|
|
@ -38,6 +39,7 @@ def _log_images(path, prefix=''):
|
|||
|
||||
|
||||
def _log_plots(plots, prefix=''):
|
||||
"""Logs plot images for training progress if they have not been previously processed."""
|
||||
for name, params in plots.items():
|
||||
timestamp = params['timestamp']
|
||||
if _processed_plots.get(name) != timestamp:
|
||||
|
|
@ -46,6 +48,7 @@ def _log_plots(plots, prefix=''):
|
|||
|
||||
|
||||
def _log_confusion_matrix(validator):
|
||||
"""Logs the confusion matrix for the given validator using DVCLive."""
|
||||
targets = []
|
||||
preds = []
|
||||
matrix = validator.confusion_matrix.matrix
|
||||
|
|
@ -62,6 +65,7 @@ def _log_confusion_matrix(validator):
|
|||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Initializes DVCLive logger for training metadata during pre-training routine."""
|
||||
try:
|
||||
global live
|
||||
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
|
||||
|
|
@ -71,20 +75,24 @@ def on_pretrain_routine_start(trainer):
|
|||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs plots related to the training process at the end of the pretraining routine."""
|
||||
_log_plots(trainer.plots, 'train')
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
"""Logs the training parameters if DVCLive logging is active."""
|
||||
if live:
|
||||
live.log_params(trainer.args)
|
||||
|
||||
|
||||
def on_train_epoch_start(trainer):
|
||||
"""Sets the global variable _training_epoch value to True at the start of training each epoch."""
|
||||
global _training_epoch
|
||||
_training_epoch = True
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
|
||||
global _training_epoch
|
||||
if live and _training_epoch:
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
|
||||
|
|
@ -104,6 +112,7 @@ def on_fit_epoch_end(trainer):
|
|||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
|
||||
if live:
|
||||
# At the end log the best metrics. It runs validator on the best model internally.
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
|
||||
|
|
|
|||
|
|
@ -31,14 +31,13 @@ def _log_images(imgs_dict, group=''):
|
|||
|
||||
|
||||
def _log_plot(title, plot_path):
|
||||
"""Log plots to the NeptuneAI experiment logger."""
|
||||
"""
|
||||
Log image as plot in the plot section of NeptuneAI
|
||||
Log plots to the NeptuneAI experiment logger.
|
||||
|
||||
arguments:
|
||||
title (str) Title of the plot
|
||||
plot_path (PosixPath or str) Path to the saved image file
|
||||
"""
|
||||
Args:
|
||||
title (str): Title of the plot.
|
||||
plot_path (PosixPath | str): Path to the saved image file.
|
||||
"""
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ except (ImportError, AssertionError):
|
|||
|
||||
|
||||
def _log_plots(plots, step):
|
||||
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
|
||||
for name, params in plots.items():
|
||||
timestamp = params['timestamp']
|
||||
if _processed_plots.get(name) != timestamp:
|
||||
|
|
|
|||
|
|
@ -64,8 +64,8 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
|
|||
|
||||
def parse_version(version='0.0.0') -> tuple:
|
||||
"""
|
||||
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
|
||||
This function replaces deprecated 'pkg_resources.parse_version(v)'
|
||||
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
|
||||
function replaces deprecated 'pkg_resources.parse_version(v)'.
|
||||
|
||||
Args:
|
||||
version (str): Version string, i.e. '2.0.1+cpu'
|
||||
|
|
@ -372,8 +372,10 @@ def check_torchvision():
|
|||
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
|
||||
|
||||
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
|
||||
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
|
||||
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
|
||||
to the provided compatibility table based on:
|
||||
https://github.com/pytorch/vision#installation.
|
||||
|
||||
The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
|
||||
Torchvision versions.
|
||||
"""
|
||||
|
||||
|
|
@ -527,9 +529,9 @@ def collect_system_info():
|
|||
|
||||
def check_amp(model):
|
||||
"""
|
||||
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
|
||||
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
|
||||
results, so AMP will be disabled during training.
|
||||
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
|
||||
fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
|
||||
be disabled during training.
|
||||
|
||||
Args:
|
||||
model (nn.Module): A YOLOv8 model instance.
|
||||
|
|
@ -606,7 +608,8 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
|||
|
||||
|
||||
def cuda_device_count() -> int:
|
||||
"""Get the number of NVIDIA GPUs available in the environment.
|
||||
"""
|
||||
Get the number of NVIDIA GPUs available in the environment.
|
||||
|
||||
Returns:
|
||||
(int): The number of NVIDIA GPUs available.
|
||||
|
|
@ -626,7 +629,8 @@ def cuda_device_count() -> int:
|
|||
|
||||
|
||||
def cuda_is_available() -> bool:
|
||||
"""Check if CUDA is available in the environment.
|
||||
"""
|
||||
Check if CUDA is available in the environment.
|
||||
|
||||
Returns:
|
||||
(bool): True if one or more NVIDIA GPUs are available, False otherwise.
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ from .torch_utils import TORCH_1_9
|
|||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
"""Finds a free port on localhost.
|
||||
"""
|
||||
Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
|
|||
|
||||
def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
|
||||
"""
|
||||
Zips the contents of a directory, excluding files containing strings in the exclude list.
|
||||
The resulting zip file is named after the directory and placed alongside it.
|
||||
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
|
||||
named after the directory and placed alongside it.
|
||||
|
||||
Args:
|
||||
directory (str | Path): The path to the directory to be zipped.
|
||||
|
|
@ -341,7 +341,11 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
|
|||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
|
||||
"""
|
||||
Attempt file download from GitHub release assets if not found locally.
|
||||
|
||||
release = 'latest', 'v6.2', etc.
|
||||
"""
|
||||
from ultralytics.utils import SETTINGS # scoped for circular import
|
||||
|
||||
# YOLOv3/5u updates
|
||||
|
|
|
|||
|
|
@ -30,9 +30,9 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
|||
@contextmanager
|
||||
def spaces_in_path(path):
|
||||
"""
|
||||
Context manager to handle paths with spaces in their names.
|
||||
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path,
|
||||
executes the context code block, then copies the file/directory back to its original location.
|
||||
Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
|
||||
underscores, copies the file/directory to the new path, executes the context code block, then copies the
|
||||
file/directory back to its original location.
|
||||
|
||||
Args:
|
||||
path (str | Path): The original path.
|
||||
|
|
|
|||
|
|
@ -32,9 +32,14 @@ __all__ = 'Bboxes', # tuple or list
|
|||
|
||||
|
||||
class Bboxes:
|
||||
"""Bounding Boxes class. Only numpy variables are supported."""
|
||||
"""
|
||||
Bounding Boxes class.
|
||||
|
||||
Only numpy variables are supported.
|
||||
"""
|
||||
|
||||
def __init__(self, bboxes, format='xyxy') -> None:
|
||||
"""Initializes the Bboxes class with bounding box data in a specified format."""
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
||||
assert bboxes.ndim == 2
|
||||
|
|
@ -194,7 +199,7 @@ class Instances:
|
|||
return self._bboxes.areas()
|
||||
|
||||
def scale(self, scale_w, scale_h, bbox_only=False):
|
||||
"""this might be similar with denormalize func but without normalized sign."""
|
||||
"""This might be similar with denormalize func but without normalized sign."""
|
||||
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
|
||||
if bbox_only:
|
||||
return
|
||||
|
|
@ -307,7 +312,11 @@ class Instances:
|
|||
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
||||
|
||||
def remove_zero_area_boxes(self):
|
||||
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them."""
|
||||
"""
|
||||
Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
|
||||
|
||||
This removes them.
|
||||
"""
|
||||
good = self.bbox_areas > 0
|
||||
if not all(good):
|
||||
self._bboxes = self._bboxes[good]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,11 @@ from .tal import bbox2dist
|
|||
|
||||
|
||||
class VarifocalLoss(nn.Module):
|
||||
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
|
||||
"""
|
||||
Varifocal loss by Zhang et al.
|
||||
|
||||
https://arxiv.org/abs/2008.13367.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the VarifocalLoss class."""
|
||||
|
|
@ -33,6 +37,7 @@ class FocalLoss(nn.Module):
|
|||
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
|
||||
|
||||
def __init__(self, ):
|
||||
"""Initializer for FocalLoss class with no parameters."""
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -93,6 +98,7 @@ class KeypointLoss(nn.Module):
|
|||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, sigmas) -> None:
|
||||
"""Initialize the KeypointLoss class."""
|
||||
super().__init__()
|
||||
self.sigmas = sigmas
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Model validation metrics
|
||||
"""
|
||||
"""Model validation metrics."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
|
|
@ -195,7 +193,7 @@ class ConfusionMatrix:
|
|||
|
||||
def process_cls_preds(self, preds, targets):
|
||||
"""
|
||||
Update confusion matrix for classification task
|
||||
Update confusion matrix for classification task.
|
||||
|
||||
Args:
|
||||
preds (Array[N, min(nc,5)]): Predicted class labels.
|
||||
|
|
@ -308,9 +306,7 @@ class ConfusionMatrix:
|
|||
on_plot(plot_fname)
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
Print the confusion matrix to the console.
|
||||
"""
|
||||
"""Print the confusion matrix to the console."""
|
||||
for i in range(self.nc + 1):
|
||||
LOGGER.info(' '.join(map(str, self.matrix[i])))
|
||||
|
||||
|
|
@ -440,7 +436,6 @@ def ap_per_class(tp,
|
|||
f1 (np.ndarray): F1-score values at each confidence threshold.
|
||||
ap (np.ndarray): Average precision for each class at different IoU thresholds.
|
||||
unique_classes (np.ndarray): An array of unique classes that have data.
|
||||
|
||||
"""
|
||||
|
||||
# Sort by objectness
|
||||
|
|
@ -498,32 +493,33 @@ def ap_per_class(tp,
|
|||
|
||||
class Metric(SimpleClass):
|
||||
"""
|
||||
Class for computing evaluation metrics for YOLOv8 model.
|
||||
Class for computing evaluation metrics for YOLOv8 model.
|
||||
|
||||
Attributes:
|
||||
p (list): Precision for each class. Shape: (nc,).
|
||||
r (list): Recall for each class. Shape: (nc,).
|
||||
f1 (list): F1 score for each class. Shape: (nc,).
|
||||
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
|
||||
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
|
||||
nc (int): Number of classes.
|
||||
Attributes:
|
||||
p (list): Precision for each class. Shape: (nc,).
|
||||
r (list): Recall for each class. Shape: (nc,).
|
||||
f1 (list): F1 score for each class. Shape: (nc,).
|
||||
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
|
||||
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
|
||||
nc (int): Number of classes.
|
||||
|
||||
Methods:
|
||||
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
|
||||
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
|
||||
mp(): Mean precision of all classes. Returns: Float.
|
||||
mr(): Mean recall of all classes. Returns: Float.
|
||||
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
|
||||
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
|
||||
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
|
||||
mean_results(): Mean of results, returns mp, mr, map50, map.
|
||||
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
|
||||
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
|
||||
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
|
||||
update(results): Update metric attributes with new evaluation results.
|
||||
"""
|
||||
Methods:
|
||||
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
|
||||
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
|
||||
mp(): Mean precision of all classes. Returns: Float.
|
||||
mr(): Mean recall of all classes. Returns: Float.
|
||||
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
|
||||
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
|
||||
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
|
||||
mean_results(): Mean of results, returns mp, mr, map50, map.
|
||||
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
|
||||
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
|
||||
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
|
||||
update(results): Update metric attributes with new evaluation results.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model."""
|
||||
self.p = [] # (nc, )
|
||||
self.r = [] # (nc, )
|
||||
self.f1 = [] # (nc, )
|
||||
|
|
@ -606,12 +602,12 @@ class Metric(SimpleClass):
|
|||
return [self.mp, self.mr, self.map50, self.map]
|
||||
|
||||
def class_result(self, i):
|
||||
"""class-aware result, return p[i], r[i], ap50[i], ap[i]."""
|
||||
"""Class-aware result, return p[i], r[i], ap50[i], ap[i]."""
|
||||
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
|
||||
|
||||
@property
|
||||
def maps(self):
|
||||
"""mAP of each class."""
|
||||
"""MAP of each class."""
|
||||
maps = np.zeros(self.nc) + self.map
|
||||
for i, c in enumerate(self.ap_class_index):
|
||||
maps[c] = self.ap[i]
|
||||
|
|
@ -672,6 +668,7 @@ class DetMetrics(SimpleClass):
|
|||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
|
|
@ -756,6 +753,7 @@ class SegmentMetrics(SimpleClass):
|
|||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
|
|
@ -865,6 +863,7 @@ class PoseMetrics(SegmentMetrics):
|
|||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
|
||||
super().__init__(save_dir, plot, names)
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
|
|
@ -954,6 +953,7 @@ class ClassifyMetrics(SimpleClass):
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a ClassifyMetrics instance."""
|
||||
self.top1 = 0
|
||||
self.top5 = 0
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ class Profile(contextlib.ContextDecorator):
|
|||
self.t += self.dt # accumulate dt
|
||||
|
||||
def __str__(self):
|
||||
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
|
||||
return f'Elapsed time is {self.t} s'
|
||||
|
||||
def time(self):
|
||||
|
|
@ -303,7 +304,7 @@ def clip_coords(coords, shape):
|
|||
|
||||
def scale_image(masks, im0_shape, ratio_pad=None):
|
||||
"""
|
||||
Takes a mask, and resizes it to the original image size
|
||||
Takes a mask, and resizes it to the original image size.
|
||||
|
||||
Args:
|
||||
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
||||
|
|
@ -403,8 +404,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
|||
|
||||
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
|
||||
x, y, width and height are normalized to image dimensions
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
|
||||
width and height are normalized to image dimensions.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
|
|
@ -445,7 +446,7 @@ def xywh2ltwh(x):
|
|||
|
||||
def xyxy2ltwh(x):
|
||||
"""
|
||||
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
|
||||
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
|
||||
|
|
@ -461,7 +462,7 @@ def xyxy2ltwh(x):
|
|||
|
||||
def ltwh2xywh(x):
|
||||
"""
|
||||
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
|
||||
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): the input tensor
|
||||
|
|
@ -544,7 +545,7 @@ def xywhr2xyxyxyxy(center):
|
|||
|
||||
def ltwh2xyxy(x):
|
||||
"""
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): the input image
|
||||
|
|
@ -616,8 +617,8 @@ def crop_mask(masks, boxes):
|
|||
|
||||
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
|
||||
quality but is slower.
|
||||
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
|
||||
but is slower.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||
|
|
@ -713,7 +714,7 @@ def scale_masks(masks, shape, padding=True):
|
|||
|
||||
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
|
||||
"""
|
||||
Rescale segment coordinates (xy) from img1_shape to img0_shape
|
||||
Rescale segment coordinates (xy) from img1_shape to img0_shape.
|
||||
|
||||
Args:
|
||||
img1_shape (tuple): The shape of the image that the coords are from.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Monkey patches to update/extend functionality of existing functions
|
||||
"""
|
||||
"""Monkey patches to update/extend functionality of existing functions."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -14,7 +12,8 @@ _imshow = cv2.imshow # copy to avoid recursion errors
|
|||
|
||||
|
||||
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
||||
"""Read an image from a file.
|
||||
"""
|
||||
Read an image from a file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the file to read.
|
||||
|
|
@ -27,7 +26,8 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
|||
|
||||
|
||||
def imwrite(filename: str, img: np.ndarray, params=None):
|
||||
"""Write an image to a file.
|
||||
"""
|
||||
Write an image to a file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the file to write.
|
||||
|
|
@ -45,7 +45,8 @@ def imwrite(filename: str, img: np.ndarray, params=None):
|
|||
|
||||
|
||||
def imshow(winname: str, mat: np.ndarray):
|
||||
"""Displays an image in the specified window.
|
||||
"""
|
||||
Displays an image in the specified window.
|
||||
|
||||
Args:
|
||||
winname (str): Name of the window.
|
||||
|
|
@ -59,7 +60,8 @@ _torch_save = torch.save # copy to avoid recursion errors
|
|||
|
||||
|
||||
def torch_save(*args, **kwargs):
|
||||
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this.
|
||||
"""
|
||||
Use dill (if exists) to serialize the lambda functions where pickle does not do this.
|
||||
|
||||
Args:
|
||||
*args (tuple): Positional arguments to pass to torch.save.
|
||||
|
|
|
|||
|
|
@ -316,7 +316,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
|
|||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
||||
"""
|
||||
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
||||
|
||||
This function takes a bounding box and an image, and then saves a cropped portion of the image according
|
||||
to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
|
||||
|
|
|
|||
|
|
@ -205,7 +205,11 @@ def fuse_deconv_and_bn(deconv, bn):
|
|||
|
||||
|
||||
def model_info(model, detailed=False, verbose=True, imgsz=640):
|
||||
"""Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
|
||||
"""
|
||||
Model information.
|
||||
|
||||
imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
|
||||
"""
|
||||
if not verbose:
|
||||
return
|
||||
n_p = get_num_params(model) # number of parameters
|
||||
|
|
@ -517,13 +521,11 @@ def profile(input, ops, n=10, device=None):
|
|||
|
||||
|
||||
class EarlyStopping:
|
||||
"""
|
||||
Early stopping class that stops training when a specified number of epochs have passed without improvement.
|
||||
"""
|
||||
"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""
|
||||
|
||||
def __init__(self, patience=50):
|
||||
"""
|
||||
Initialize early stopping object
|
||||
Initialize early stopping object.
|
||||
|
||||
Args:
|
||||
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
|
||||
|
|
@ -535,7 +537,7 @@ class EarlyStopping:
|
|||
|
||||
def __call__(self, epoch, fitness):
|
||||
"""
|
||||
Check whether to stop training
|
||||
Check whether to stop training.
|
||||
|
||||
Args:
|
||||
epoch (int): Current epoch of training
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ import numpy as np
|
|||
|
||||
|
||||
class TritonRemoteModel:
|
||||
"""Client for interacting with a remote Triton Inference Server model.
|
||||
"""
|
||||
Client for interacting with a remote Triton Inference Server model.
|
||||
|
||||
Attributes:
|
||||
endpoint (str): The name of the model on the Triton server.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue