ultralytics 8.0.58 new SimpleClass, fixes and updates (#1636)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
ef03e6732a
commit
ec10002a4a
30 changed files with 351 additions and 314 deletions
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TryExcept
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept
|
||||
|
||||
|
||||
# boxes
|
||||
|
|
@ -425,7 +425,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
|
|||
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
|
||||
|
||||
|
||||
class Metric:
|
||||
class Metric(SimpleClass):
|
||||
"""
|
||||
Class for computing evaluation metrics for YOLOv8 model.
|
||||
|
||||
|
|
@ -461,10 +461,6 @@ class Metric:
|
|||
self.ap_class_index = [] # (nc, )
|
||||
self.nc = 0
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@property
|
||||
def ap50(self):
|
||||
"""AP@0.5 of all classes.
|
||||
|
|
@ -550,7 +546,7 @@ class Metric:
|
|||
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
|
||||
|
||||
|
||||
class DetMetrics:
|
||||
class DetMetrics(SimpleClass):
|
||||
"""
|
||||
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
|
||||
(mAP) of an object detection model.
|
||||
|
|
@ -585,10 +581,6 @@ class DetMetrics:
|
|||
self.box = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def process(self, tp, conf, pred_cls, target_cls):
|
||||
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
|
||||
names=self.names)[2:]
|
||||
|
|
@ -622,7 +614,7 @@ class DetMetrics:
|
|||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
|
||||
|
||||
class SegmentMetrics:
|
||||
class SegmentMetrics(SimpleClass):
|
||||
"""
|
||||
Calculates and aggregates detection and segmentation metrics over a given set of classes.
|
||||
|
||||
|
|
@ -657,10 +649,6 @@ class SegmentMetrics:
|
|||
self.seg = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
|
||||
"""
|
||||
Processes the detection and segmentation metrics over the given set of predictions.
|
||||
|
|
@ -724,7 +712,7 @@ class SegmentMetrics:
|
|||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
|
||||
|
||||
class ClassifyMetrics:
|
||||
class ClassifyMetrics(SimpleClass):
|
||||
"""
|
||||
Class for computing classification metrics including top-1 and top-5 accuracy.
|
||||
|
||||
|
|
@ -747,10 +735,6 @@ class ClassifyMetrics:
|
|||
self.top5 = 0
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def process(self, targets, pred):
|
||||
# target classes and predicted classes
|
||||
pred, targets = torch.cat(pred), torch.cat(targets)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue