General cleanup (#69)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
7ae45c6cc4
commit
d63ee112d4
13 changed files with 265 additions and 433 deletions
|
|
@ -7,17 +7,17 @@ import torch.nn.functional as F
|
|||
|
||||
from ultralytics.yolo.data import build_dataloader
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
|
||||
fitness_segmentation, mask_iou)
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
|
||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
from ..detect import DetectionValidator
|
||||
|
||||
class SegmentationValidator(BaseValidator):
|
||||
|
||||
class SegmentationValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
|
|
@ -65,7 +65,7 @@ class SegmentationValidator(BaseValidator):
|
|||
self.niou = self.iouv.numel()
|
||||
self.seen = 0
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
|
||||
self.metrics = Metrics()
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots, names=self.names)
|
||||
self.loss = torch.zeros(4, device=self.device)
|
||||
self.jdict = []
|
||||
self.stats = []
|
||||
|
|
@ -150,16 +150,6 @@ class SegmentationValidator(BaseValidator):
|
|||
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
|
||||
'''
|
||||
|
||||
def get_stats(self):
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
|
||||
self.metrics.update(results)
|
||||
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
|
||||
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
|
||||
metrics |= zip(self.metric_keys, self.metrics.mean_results())
|
||||
return metrics
|
||||
|
||||
def print_results(self):
|
||||
pf = '%22s' + '%11i' * 2 + '%11.3g' * 8 # print format
|
||||
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||
|
|
@ -218,6 +208,7 @@ class SegmentationValidator(BaseValidator):
|
|||
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
|
||||
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
|
||||
|
||||
# TODO: probably add this to class Metrics
|
||||
@property
|
||||
def metric_keys(self):
|
||||
return [
|
||||
|
|
@ -237,23 +228,22 @@ class SegmentationValidator(BaseValidator):
|
|||
bboxes = batch["bboxes"]
|
||||
paths = batch["im_file"]
|
||||
batch_idx = batch["batch_idx"]
|
||||
plot_images_and_masks(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes,
|
||||
masks,
|
||||
paths=paths,
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names)
|
||||
plot_images(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes,
|
||||
masks,
|
||||
paths=paths,
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
images = batch["img"]
|
||||
paths = batch["im_file"]
|
||||
if len(self.plot_masks):
|
||||
plot_masks = torch.cat(self.plot_masks, dim=0)
|
||||
batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
|
||||
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, conf, paths,
|
||||
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
|
||||
plot_images(images, *output_to_target(preds[0], max_det=15), plot_masks, paths,
|
||||
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
|
||||
self.plot_masks.clear()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue