Add docformatter to pre-commit (#5279)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-09 02:25:22 +02:00 committed by GitHub
parent c7aa83da31
commit 7517667a33
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
90 changed files with 1396 additions and 497 deletions

View file

@ -117,6 +117,7 @@ class TQDM(tqdm_original):
"""
def __init__(self, *args, **kwargs):
"""Initialize custom Ultralytics tqdm class with different default arguments."""
# Set new default values (these can still be overridden when calling TQDM)
kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
@ -124,8 +125,7 @@ class TQDM(tqdm_original):
class SimpleClass:
"""
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
"""Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
access methods for easier debugging and usage.
"""
@ -154,8 +154,7 @@ class SimpleClass:
class IterableSimpleNamespace(SimpleNamespace):
"""
Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
"""Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
enables usage with dict() and for loops.
"""
@ -256,8 +255,8 @@ class EmojiFilter(logging.Filter):
"""
A custom logging filter class for removing emojis in log messages.
This filter is particularly useful for ensuring compatibility with Windows terminals
that may not support the display of emojis in log messages.
This filter is particularly useful for ensuring compatibility with Windows terminals that may not support the
display of emojis in log messages.
"""
def filter(self, record):
@ -275,9 +274,9 @@ if WINDOWS: # emoji-safe logging
class ThreadingLocked:
"""
A decorator class for ensuring thread-safe execution of a function or method.
This class can be used as a decorator to make sure that if the decorated function
is called from multiple threads, only one thread at a time will be able to execute the function.
A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
to execute the function.
Attributes:
lock (threading.Lock): A lock object used to manage access to the decorated function.
@ -294,13 +293,16 @@ class ThreadingLocked:
"""
def __init__(self):
"""Initializes the decorator class for thread-safe execution of a function or method."""
self.lock = threading.Lock()
def __call__(self, f):
"""Run thread-safe execution of function or method."""
from functools import wraps
@wraps(f)
def decorated(*args, **kwargs):
"""Applies thread-safety to the decorated function or method."""
with self.lock:
return f(*args, **kwargs)
@ -424,8 +426,7 @@ def is_kaggle():
def is_jupyter():
"""
Check if the current script is running inside a Jupyter Notebook.
Verified on Colab, Jupyterlab, Kaggle, Paperspace.
Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.
Returns:
(bool): True if running inside a Jupyter Notebook, False otherwise.
@ -529,8 +530,8 @@ def is_github_actions_ci() -> bool:
def is_git_dir():
"""
Determines whether the current file is part of a git repository.
If the current file is not part of a git repository, returns None.
Determines whether the current file is part of a git repository. If the current file is not part of a git
repository, returns None.
Returns:
(bool): True if current file is part of a git repository.
@ -540,8 +541,8 @@ def is_git_dir():
def get_git_dir():
"""
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
If the current file is not part of a git repository, returns None.
Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
the current file is not part of a git repository, returns None.
Returns:
(Path | None): Git root directory if found or None if not found.
@ -578,7 +579,8 @@ def get_git_branch():
def get_default_args(func):
"""Returns a dictionary of default arguments for a function.
"""
Returns a dictionary of default arguments for a function.
Args:
func (callable): The function to inspect.
@ -710,7 +712,11 @@ def remove_colorstr(input_string):
class TryExcept(contextlib.ContextDecorator):
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
"""
YOLOv8 TryExcept class.
Use as @TryExcept() decorator or 'with TryExcept():' context manager.
"""
def __init__(self, msg='', verbose=True):
"""Initialize TryExcept class with optional message and verbosity settings."""
@ -729,7 +735,11 @@ class TryExcept(contextlib.ContextDecorator):
def threaded(func):
"""Multi-threads a target function and returns thread. Usage: @threaded decorator."""
"""
Multi-threads a target function and returns thread.
Use as @threaded decorator.
"""
def wrapper(*args, **kwargs):
"""Multi-threads a given function and returns the thread."""
@ -824,6 +834,9 @@ class SettingsManager(dict):
"""
def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
file.
"""
import copy
import hashlib

View file

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
"""
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
from copy import deepcopy

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Benchmark a YOLO model formats for speed and accuracy
Benchmark a YOLO model formats for speed and accuracy.
Usage:
from ultralytics.utils.benchmarks import ProfileModels, benchmark
@ -194,6 +194,7 @@ class ProfileModels:
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
def profile(self):
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
files = self.get_files()
if not files:
@ -235,6 +236,7 @@ class ProfileModels:
return output
def get_files(self):
"""Returns a list of paths for all relevant model files given by the user."""
files = []
for path in self.paths:
path = Path(path)
@ -250,10 +252,14 @@ class ProfileModels:
return [Path(file) for file in sorted(files)]
def get_onnx_model_info(self, onnx_file: str):
"""Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
file.
"""
# return (num_layers, num_params, num_gradients, num_flops)
return 0.0, 0.0, 0.0, 0.0
def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
data = np.array(data)
for _ in range(max_iters):
mean, std = np.mean(data), np.std(data)
@ -264,6 +270,7 @@ class ProfileModels:
return data
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
"""Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
if not self.trt or not Path(engine_file).is_file():
return 0.0, 0.0
@ -292,6 +299,9 @@ class ProfileModels:
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
check_requirements('onnxruntime')
import onnxruntime as ort
@ -344,10 +354,12 @@ class ProfileModels:
return np.mean(run_times), np.std(run_times)
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info
return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info
return {
'model/name': model_name,
@ -357,6 +369,7 @@ class ProfileModels:
'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
def print_table(self, table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data."""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'

View file

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Base callbacks
"""
"""Base callbacks."""
from collections import defaultdict
from copy import deepcopy

View file

@ -26,31 +26,38 @@ except (ImportError, AssertionError):
def _get_comet_mode():
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv('COMET_MODE', 'online')
def _get_comet_model_name():
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
def _get_eval_batch_logging_interval():
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
def _get_max_image_predictions_to_log():
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
def _scale_confidence_score(score):
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
return score * scale
def _should_log_confusion_matrix():
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
def _should_log_image_predictions():
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
@ -104,9 +111,10 @@ def _fetch_trainer_metadata(trainer):
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
"""YOLOv8 resizes images during training and the label values
are normalized based on this resized shape. This function rescales the
bounding box labels to the original image shape.
"""
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
This function rescales the bounding box labels to the original image shape.
"""
resized_image_height, resized_image_width = resized_image_shape

View file

@ -25,6 +25,7 @@ except (ImportError, AssertionError, TypeError):
def _log_images(path, prefix=''):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name
@ -38,6 +39,7 @@ def _log_images(path, prefix=''):
def _log_plots(plots, prefix=''):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params['timestamp']
if _processed_plots.get(name) != timestamp:
@ -46,6 +48,7 @@ def _log_plots(plots, prefix=''):
def _log_confusion_matrix(validator):
"""Logs the confusion matrix for the given validator using DVCLive."""
targets = []
preds = []
matrix = validator.confusion_matrix.matrix
@ -62,6 +65,7 @@ def _log_confusion_matrix(validator):
def on_pretrain_routine_start(trainer):
"""Initializes DVCLive logger for training metadata during pre-training routine."""
try:
global live
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
@ -71,20 +75,24 @@ def on_pretrain_routine_start(trainer):
def on_pretrain_routine_end(trainer):
"""Logs plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, 'train')
def on_train_start(trainer):
"""Logs the training parameters if DVCLive logging is active."""
if live:
live.log_params(trainer.args)
def on_train_epoch_start(trainer):
"""Sets the global variable _training_epoch value to True at the start of training each epoch."""
global _training_epoch
_training_epoch = True
def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
@ -104,6 +112,7 @@ def on_fit_epoch_end(trainer):
def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}

View file

@ -31,14 +31,13 @@ def _log_images(imgs_dict, group=''):
def _log_plot(title, plot_path):
"""Log plots to the NeptuneAI experiment logger."""
"""
Log image as plot in the plot section of NeptuneAI
Log plots to the NeptuneAI experiment logger.
arguments:
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
Args:
title (str): Title of the plot.
plot_path (PosixPath | str): Path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

View file

@ -17,6 +17,7 @@ except (ImportError, AssertionError):
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
timestamp = params['timestamp']
if _processed_plots.get(name) != timestamp:

View file

@ -64,8 +64,8 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
def parse_version(version='0.0.0') -> tuple:
"""
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
This function replaces deprecated 'pkg_resources.parse_version(v)'
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
function replaces deprecated 'pkg_resources.parse_version(v)'.
Args:
version (str): Version string, i.e. '2.0.1+cpu'
@ -372,8 +372,10 @@ def check_torchvision():
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
to the provided compatibility table based on:
https://github.com/pytorch/vision#installation.
The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
@ -527,9 +529,9 @@ def collect_system_info():
def check_amp(model):
"""
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
results, so AMP will be disabled during training.
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
be disabled during training.
Args:
model (nn.Module): A YOLOv8 model instance.
@ -606,7 +608,8 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
def cuda_device_count() -> int:
"""Get the number of NVIDIA GPUs available in the environment.
"""
Get the number of NVIDIA GPUs available in the environment.
Returns:
(int): The number of NVIDIA GPUs available.
@ -626,7 +629,8 @@ def cuda_device_count() -> int:
def cuda_is_available() -> bool:
"""Check if CUDA is available in the environment.
"""
Check if CUDA is available in the environment.
Returns:
(bool): True if one or more NVIDIA GPUs are available, False otherwise.

View file

@ -13,7 +13,8 @@ from .torch_utils import TORCH_1_9
def find_free_network_port() -> int:
"""Finds a free port on localhost.
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.

View file

@ -69,8 +69,8 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
"""
Zips the contents of a directory, excluding files containing strings in the exclude list.
The resulting zip file is named after the directory and placed alongside it.
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
named after the directory and placed alongside it.
Args:
directory (str | Path): The path to the directory to be zipped.
@ -341,7 +341,11 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
"""
Attempt file download from GitHub release assets if not found locally.
release = 'latest', 'v6.2', etc.
"""
from ultralytics.utils import SETTINGS # scoped for circular import
# YOLOv3/5u updates

View file

@ -30,9 +30,9 @@ class WorkingDirectory(contextlib.ContextDecorator):
@contextmanager
def spaces_in_path(path):
"""
Context manager to handle paths with spaces in their names.
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path,
executes the context code block, then copies the file/directory back to its original location.
Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
underscores, copies the file/directory to the new path, executes the context code block, then copies the
file/directory back to its original location.
Args:
path (str | Path): The original path.

View file

@ -32,9 +32,14 @@ __all__ = 'Bboxes', # tuple or list
class Bboxes:
"""Bounding Boxes class. Only numpy variables are supported."""
"""
Bounding Boxes class.
Only numpy variables are supported.
"""
def __init__(self, bboxes, format='xyxy') -> None:
"""Initializes the Bboxes class with bounding box data in a specified format."""
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
@ -194,7 +199,7 @@ class Instances:
return self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False):
"""this might be similar with denormalize func but without normalized sign."""
"""This might be similar with denormalize func but without normalized sign."""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
@ -307,7 +312,11 @@ class Instances:
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
def remove_zero_area_boxes(self):
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them."""
"""
Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
This removes them.
"""
good = self.bbox_areas > 0
if not all(good):
self._bboxes = self._bboxes[good]

View file

@ -13,7 +13,11 @@ from .tal import bbox2dist
class VarifocalLoss(nn.Module):
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
"""
Varifocal loss by Zhang et al.
https://arxiv.org/abs/2008.13367.
"""
def __init__(self):
"""Initialize the VarifocalLoss class."""
@ -33,6 +37,7 @@ class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
"""Initializer for FocalLoss class with no parameters."""
super().__init__()
@staticmethod
@ -93,6 +98,7 @@ class KeypointLoss(nn.Module):
"""Criterion class for computing training losses."""
def __init__(self, sigmas) -> None:
"""Initialize the KeypointLoss class."""
super().__init__()
self.sigmas = sigmas

View file

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Model validation metrics
"""
"""Model validation metrics."""
import math
import warnings
@ -195,7 +193,7 @@ class ConfusionMatrix:
def process_cls_preds(self, preds, targets):
"""
Update confusion matrix for classification task
Update confusion matrix for classification task.
Args:
preds (Array[N, min(nc,5)]): Predicted class labels.
@ -308,9 +306,7 @@ class ConfusionMatrix:
on_plot(plot_fname)
def print(self):
"""
Print the confusion matrix to the console.
"""
"""Print the confusion matrix to the console."""
for i in range(self.nc + 1):
LOGGER.info(' '.join(map(str, self.matrix[i])))
@ -440,7 +436,6 @@ def ap_per_class(tp,
f1 (np.ndarray): F1-score values at each confidence threshold.
ap (np.ndarray): Average precision for each class at different IoU thresholds.
unique_classes (np.ndarray): An array of unique classes that have data.
"""
# Sort by objectness
@ -498,32 +493,33 @@ def ap_per_class(tp,
class Metric(SimpleClass):
"""
Class for computing evaluation metrics for YOLOv8 model.
Class for computing evaluation metrics for YOLOv8 model.
Attributes:
p (list): Precision for each class. Shape: (nc,).
r (list): Recall for each class. Shape: (nc,).
f1 (list): F1 score for each class. Shape: (nc,).
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
nc (int): Number of classes.
Attributes:
p (list): Precision for each class. Shape: (nc,).
r (list): Recall for each class. Shape: (nc,).
f1 (list): F1 score for each class. Shape: (nc,).
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
nc (int): Number of classes.
Methods:
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
mp(): Mean precision of all classes. Returns: Float.
mr(): Mean recall of all classes. Returns: Float.
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
mean_results(): Mean of results, returns mp, mr, map50, map.
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
update(results): Update metric attributes with new evaluation results.
"""
Methods:
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
mp(): Mean precision of all classes. Returns: Float.
mr(): Mean recall of all classes. Returns: Float.
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
mean_results(): Mean of results, returns mp, mr, map50, map.
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
update(results): Update metric attributes with new evaluation results.
"""
def __init__(self) -> None:
"""Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model."""
self.p = [] # (nc, )
self.r = [] # (nc, )
self.f1 = [] # (nc, )
@ -606,12 +602,12 @@ class Metric(SimpleClass):
return [self.mp, self.mr, self.map50, self.map]
def class_result(self, i):
"""class-aware result, return p[i], r[i], ap50[i], ap[i]."""
"""Class-aware result, return p[i], r[i], ap50[i], ap[i]."""
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
@property
def maps(self):
"""mAP of each class."""
"""MAP of each class."""
maps = np.zeros(self.nc) + self.map
for i, c in enumerate(self.ap_class_index):
maps[c] = self.ap[i]
@ -672,6 +668,7 @@ class DetMetrics(SimpleClass):
"""
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
@ -756,6 +753,7 @@ class SegmentMetrics(SimpleClass):
"""
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
@ -865,6 +863,7 @@ class PoseMetrics(SegmentMetrics):
"""
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
super().__init__(save_dir, plot, names)
self.save_dir = save_dir
self.plot = plot
@ -954,6 +953,7 @@ class ClassifyMetrics(SimpleClass):
"""
def __init__(self) -> None:
"""Initialize a ClassifyMetrics instance."""
self.top1 = 0
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}

View file

@ -50,6 +50,7 @@ class Profile(contextlib.ContextDecorator):
self.t += self.dt # accumulate dt
def __str__(self):
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
return f'Elapsed time is {self.t} s'
def time(self):
@ -303,7 +304,7 @@ def clip_coords(coords, shape):
def scale_image(masks, im0_shape, ratio_pad=None):
"""
Takes a mask, and resizes it to the original image size
Takes a mask, and resizes it to the original image size.
Args:
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
@ -403,8 +404,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
x, y, width and height are normalized to image dimensions
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
width and height are normalized to image dimensions.
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
@ -445,7 +446,7 @@ def xywh2ltwh(x):
def xyxy2ltwh(x):
"""
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
@ -461,7 +462,7 @@ def xyxy2ltwh(x):
def ltwh2xywh(x):
"""
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
Args:
x (torch.Tensor): the input tensor
@ -544,7 +545,7 @@ def xywhr2xyxyxyxy(center):
def ltwh2xyxy(x):
"""
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): the input image
@ -616,8 +617,8 @@ def crop_mask(masks, boxes):
def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
quality but is slower.
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
but is slower.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
@ -713,7 +714,7 @@ def scale_masks(masks, shape, padding=True):
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
"""
Rescale segment coordinates (xy) from img1_shape to img0_shape
Rescale segment coordinates (xy) from img1_shape to img0_shape.
Args:
img1_shape (tuple): The shape of the image that the coords are from.

View file

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Monkey patches to update/extend functionality of existing functions
"""
"""Monkey patches to update/extend functionality of existing functions."""
from pathlib import Path
@ -14,7 +12,8 @@ _imshow = cv2.imshow # copy to avoid recursion errors
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
"""Read an image from a file.
"""
Read an image from a file.
Args:
filename (str): Path to the file to read.
@ -27,7 +26,8 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
def imwrite(filename: str, img: np.ndarray, params=None):
"""Write an image to a file.
"""
Write an image to a file.
Args:
filename (str): Path to the file to write.
@ -45,7 +45,8 @@ def imwrite(filename: str, img: np.ndarray, params=None):
def imshow(winname: str, mat: np.ndarray):
"""Displays an image in the specified window.
"""
Displays an image in the specified window.
Args:
winname (str): Name of the window.
@ -59,7 +60,8 @@ _torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this.
"""
Use dill (if exists) to serialize the lambda functions where pickle does not do this.
Args:
*args (tuple): Positional arguments to pass to torch.save.

View file

@ -316,7 +316,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
"""
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
This function takes a bounding box and an image, and then saves a cropped portion of the image according
to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding

View file

@ -205,7 +205,11 @@ def fuse_deconv_and_bn(deconv, bn):
def model_info(model, detailed=False, verbose=True, imgsz=640):
"""Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
"""
Model information.
imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
"""
if not verbose:
return
n_p = get_num_params(model) # number of parameters
@ -517,13 +521,11 @@ def profile(input, ops, n=10, device=None):
class EarlyStopping:
"""
Early stopping class that stops training when a specified number of epochs have passed without improvement.
"""
"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""
def __init__(self, patience=50):
"""
Initialize early stopping object
Initialize early stopping object.
Args:
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
@ -535,7 +537,7 @@ class EarlyStopping:
def __call__(self, epoch, fitness):
"""
Check whether to stop training
Check whether to stop training.
Args:
epoch (int): Current epoch of training

View file

@ -7,7 +7,8 @@ import numpy as np
class TritonRemoteModel:
"""Client for interacting with a remote Triton Inference Server model.
"""
Client for interacting with a remote Triton Inference Server model.
Attributes:
endpoint (str): The name of the model on the Triton server.