Fix model names docstring type to dict (#15726)
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
98a39f594c
commit
5f93df6fca
2 changed files with 9 additions and 9 deletions
|
|
@ -900,7 +900,7 @@ class Model(nn.Module):
|
||||||
initialized, it sets it up before retrieving the names.
|
initialized, it sets it up before retrieving the names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(List[str]): A list of class names associated with the model.
|
(Dict[int, str]): A dict of class names associated with the model.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AttributeError: If the model or predictor does not have a 'names' attribute.
|
AttributeError: If the model or predictor does not have a 'names' attribute.
|
||||||
|
|
@ -908,7 +908,7 @@ class Model(nn.Module):
|
||||||
Examples:
|
Examples:
|
||||||
>>> model = YOLO('yolov8n.pt')
|
>>> model = YOLO('yolov8n.pt')
|
||||||
>>> print(model.names)
|
>>> print(model.names)
|
||||||
['person', 'bicycle', 'car', ...]
|
{0: 'person', 1: 'bicycle', 2: 'car', ...}
|
||||||
"""
|
"""
|
||||||
from ultralytics.nn.autobackend import check_class_names
|
from ultralytics.nn.autobackend import check_class_names
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -445,7 +445,7 @@ def smooth(y, f=0.05):
|
||||||
|
|
||||||
|
|
||||||
@plt_settings()
|
@plt_settings()
|
||||||
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
|
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
|
||||||
"""Plots a precision-recall curve."""
|
"""Plots a precision-recall curve."""
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||||
py = np.stack(py, axis=1)
|
py = np.stack(py, axis=1)
|
||||||
|
|
@ -470,7 +470,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=N
|
||||||
|
|
||||||
|
|
||||||
@plt_settings()
|
@plt_settings()
|
||||||
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
|
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
|
||||||
"""Plots a metric-confidence curve."""
|
"""Plots a metric-confidence curve."""
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||||
|
|
||||||
|
|
@ -528,7 +528,7 @@ def compute_ap(recall, precision):
|
||||||
|
|
||||||
|
|
||||||
def ap_per_class(
|
def ap_per_class(
|
||||||
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
|
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Computes the average precision per class for object detection evaluation.
|
Computes the average precision per class for object detection evaluation.
|
||||||
|
|
@ -541,7 +541,7 @@ def ap_per_class(
|
||||||
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
|
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
|
||||||
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
|
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
|
||||||
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
|
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
|
||||||
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
|
names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty tuple.
|
||||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
|
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
|
||||||
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
|
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
|
||||||
|
|
||||||
|
|
@ -799,13 +799,13 @@ class DetMetrics(SimpleClass):
|
||||||
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
|
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
|
||||||
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
|
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
|
||||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||||
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
|
names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
save_dir (Path): A path to the directory where the output plots will be saved.
|
save_dir (Path): A path to the directory where the output plots will be saved.
|
||||||
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
|
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
|
||||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||||
names (tuple of str): A tuple of strings that represents the names of the classes.
|
names (dict of str): A dict of strings that represents the names of the classes.
|
||||||
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
|
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
|
||||||
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
|
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
|
||||||
|
|
||||||
|
|
@ -822,7 +822,7 @@ class DetMetrics(SimpleClass):
|
||||||
curves_results: TODO
|
curves_results: TODO
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None:
|
||||||
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
|
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue