ultralytics 8.0.58 new SimpleClass, fixes and updates (#1636)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-03-26 22:16:38 +02:00 committed by GitHub
parent ef03e6732a
commit ec10002a4a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 351 additions and 314 deletions

View file

@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, TryExcept
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept
# boxes
@ -425,7 +425,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
class Metric:
class Metric(SimpleClass):
"""
Class for computing evaluation metrics for YOLOv8 model.
@ -461,10 +461,6 @@ class Metric:
self.ap_class_index = [] # (nc, )
self.nc = 0
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def ap50(self):
"""AP@0.5 of all classes.
@ -550,7 +546,7 @@ class Metric:
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
class DetMetrics:
class DetMetrics(SimpleClass):
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
@ -585,10 +581,6 @@ class DetMetrics:
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp, conf, pred_cls, target_cls):
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
names=self.names)[2:]
@ -622,7 +614,7 @@ class DetMetrics:
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class SegmentMetrics:
class SegmentMetrics(SimpleClass):
"""
Calculates and aggregates detection and segmentation metrics over a given set of classes.
@ -657,10 +649,6 @@ class SegmentMetrics:
self.seg = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
"""
Processes the detection and segmentation metrics over the given set of predictions.
@ -724,7 +712,7 @@ class SegmentMetrics:
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class ClassifyMetrics:
class ClassifyMetrics(SimpleClass):
"""
Class for computing classification metrics including top-1 and top-5 accuracy.
@ -747,10 +735,6 @@ class ClassifyMetrics:
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, targets, pred):
# target classes and predicted classes
pred, targets = torch.cat(pred), torch.cat(targets)