General cleanup (#69)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2022-12-08 08:28:13 -06:00 committed by GitHub
parent 7ae45c6cc4
commit d63ee112d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 265 additions and 433 deletions

View file

@ -11,7 +11,7 @@ from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.metrics import ConfusionMatrix, Metric, ap_per_class, box_iou, fitness_detection
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
from ultralytics.yolo.utils.torch_utils import de_parallel
@ -62,7 +62,7 @@ class DetectionValidator(BaseValidator):
self.niou = self.iouv.numel()
self.seen = 0
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
self.metrics = Metric()
self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots, names=self.names)
self.loss = torch.zeros(3, device=self.device)
self.jdict = []
self.stats = []
@ -128,10 +128,9 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
results = ap_per_class(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
self.metrics.update(results[2:])
self.nt_per_class = np.bincount(stats[3].astype(int), minlength=self.nc) # number of targets per class
metrics = {"fitness": fitness_detection(np.array(self.metrics.mean_results()).reshape(1, -1))}
self.metrics.process(*stats)
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
metrics = {"fitness": self.metrics.fitness()}
metrics |= zip(self.metric_keys, self.metrics.mean_results())
return metrics
@ -203,8 +202,11 @@ class DetectionValidator(BaseValidator):
def plot_predictions(self, batch, preds, ni):
images = batch["img"]
paths = batch["im_file"]
plot_images(images, *output_to_target(preds, max_det=15), paths, self.save_dir / f'val_batch{ni}_pred.jpg',
self.names) # pred
plot_images(images,
*output_to_target(preds, max_det=15),
paths=paths,
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)