diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index ad29559a..7e215319 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -900,7 +900,7 @@ class Model(nn.Module): initialized, it sets it up before retrieving the names. Returns: - (List[str]): A list of class names associated with the model. + (Dict[int, str]): A dict of class names associated with the model. Raises: AttributeError: If the model or predictor does not have a 'names' attribute. @@ -908,7 +908,7 @@ class Model(nn.Module): Examples: >>> model = YOLO('yolov8n.pt') >>> print(model.names) - ['person', 'bicycle', 'car', ...] + {0: 'person', 1: 'bicycle', 2: 'car', ...} """ from ultralytics.nn.autobackend import check_class_names diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index 80c2f660..78c77c72 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -445,7 +445,7 @@ def smooth(y, f=0.05): @plt_settings() -def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None): +def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None): """Plots a precision-recall curve.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) py = np.stack(py, axis=1) @@ -470,7 +470,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=N @plt_settings() -def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None): +def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None): """Plots a metric-confidence curve.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) @@ -528,7 +528,7 @@ def compute_ap(recall, precision): def ap_per_class( - tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix="" + tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix="" ): """ Computes the average precision per class for object detection evaluation. @@ -541,7 +541,7 @@ def ap_per_class( plot (bool, optional): Whether to plot PR curves or not. Defaults to False. on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None. save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path. - names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple. + names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty tuple. eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16. prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string. @@ -799,13 +799,13 @@ class DetMetrics(SimpleClass): save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory. plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False. on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None. - names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple. + names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple. Attributes: save_dir (Path): A path to the directory where the output plots will be saved. plot (bool): A flag that indicates whether to plot the precision-recall curves for each class. on_plot (func): An optional callback to pass plots path and data when they are rendered. - names (tuple of str): A tuple of strings that represents the names of the classes. + names (dict of str): A dict of strings that represents the names of the classes. box (Metric): An instance of the Metric class for storing the results of the detection metrics. speed (dict): A dictionary for storing the execution time of different parts of the detection process. @@ -822,7 +822,7 @@ class DetMetrics(SimpleClass): curves_results: TODO """ - def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None: """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names.""" self.save_dir = save_dir self.plot = plot