General cleanup (#69)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2022-12-08 08:28:13 -06:00 committed by GitHub
parent 7ae45c6cc4
commit d63ee112d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 265 additions and 433 deletions

View file

@ -7,17 +7,17 @@ import torch.nn.functional as F
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
fitness_segmentation, mask_iou)
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
from ultralytics.yolo.utils.torch_utils import de_parallel
from ..detect import DetectionValidator
class SegmentationValidator(BaseValidator):
class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
@ -65,7 +65,7 @@ class SegmentationValidator(BaseValidator):
self.niou = self.iouv.numel()
self.seen = 0
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
self.metrics = Metrics()
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots, names=self.names)
self.loss = torch.zeros(4, device=self.device)
self.jdict = []
self.stats = []
@ -150,16 +150,6 @@ class SegmentationValidator(BaseValidator):
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
'''
def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
self.metrics.update(results)
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
metrics |= zip(self.metric_keys, self.metrics.mean_results())
return metrics
def print_results(self):
pf = '%22s' + '%11i' * 2 + '%11.3g' * 8 # print format
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
@ -218,6 +208,7 @@ class SegmentationValidator(BaseValidator):
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
# TODO: probably add this to class Metrics
@property
def metric_keys(self):
return [
@ -237,23 +228,22 @@ class SegmentationValidator(BaseValidator):
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
paths=paths,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names)
plot_images(images,
batch_idx,
cls,
bboxes,
masks,
paths=paths,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names)
def plot_predictions(self, batch, preds, ni):
images = batch["img"]
paths = batch["im_file"]
if len(self.plot_masks):
plot_masks = torch.cat(self.plot_masks, dim=0)
batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, conf, paths,
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
plot_images(images, *output_to_target(preds[0], max_det=15), plot_masks, paths,
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
self.plot_masks.clear()