General cleanup (#69)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
7ae45c6cc4
commit
d63ee112d4
13 changed files with 265 additions and 433 deletions
|
|
@ -11,7 +11,7 @@ from ultralytics.yolo.engine.validator import BaseValidator
|
|||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, Metric, ap_per_class, box_iou, fitness_detection
|
||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ class DetectionValidator(BaseValidator):
|
|||
self.niou = self.iouv.numel()
|
||||
self.seen = 0
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
|
||||
self.metrics = Metric()
|
||||
self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots, names=self.names)
|
||||
self.loss = torch.zeros(3, device=self.device)
|
||||
self.jdict = []
|
||||
self.stats = []
|
||||
|
|
@ -128,10 +128,9 @@ class DetectionValidator(BaseValidator):
|
|||
def get_stats(self):
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
results = ap_per_class(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
|
||||
self.metrics.update(results[2:])
|
||||
self.nt_per_class = np.bincount(stats[3].astype(int), minlength=self.nc) # number of targets per class
|
||||
metrics = {"fitness": fitness_detection(np.array(self.metrics.mean_results()).reshape(1, -1))}
|
||||
self.metrics.process(*stats)
|
||||
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
|
||||
metrics = {"fitness": self.metrics.fitness()}
|
||||
metrics |= zip(self.metric_keys, self.metrics.mean_results())
|
||||
return metrics
|
||||
|
||||
|
|
@ -203,8 +202,11 @@ class DetectionValidator(BaseValidator):
|
|||
def plot_predictions(self, batch, preds, ni):
|
||||
images = batch["img"]
|
||||
paths = batch["im_file"]
|
||||
plot_images(images, *output_to_target(preds, max_det=15), paths, self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
self.names) # pred
|
||||
plot_images(images,
|
||||
*output_to_target(preds, max_det=15),
|
||||
paths=paths,
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names) # pred
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue