update segment training (#57)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing 2022-11-29 05:30:08 -06:00 committed by GitHub
parent d0b0fe2592
commit 3a241e4cea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 460 additions and 144 deletions

View file

@ -6,23 +6,24 @@ import torch.nn.functional as F
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
fitness_segmentation, mask_iou)
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationValidator(BaseValidator):
def __init__(self, dataloader, pbar=None, logger=None, args=None):
super().__init__(dataloader, pbar, logger, args)
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
if self.args.save_json:
check_requirements(['pycocotools'])
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
self.data_dict = yaml_load(self.args.data) if self.args.data else None
self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
self.is_coco = False
self.class_map = None
self.targets = None
@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator):
self.loss = torch.zeros(4, device=self.device)
self.jdict = []
self.stats = []
self.plot_masks = []
def get_desc(self):
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator):
def update_metrics(self, preds, batch):
# Metrics
plot_masks = [] # masks for plotting
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
labels = self.targets[self.targets[:, 0] == si, 1:]
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
shape = batch["shape"][si]
shape = batch["ori_shape"][si]
# path = batch["shape"][si][0]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator):
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# TODO: Save/log
'''
@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator):
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
'''
# TODO Plot images
'''
if self.args.plots and self.batch_i < 3:
if len(plot_masks):
plot_masks = torch.cat(plot_masks, dim=0)
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
'''
def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
# TODO: save_dir
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
self.metrics.update(results)
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
metrics |= zip(keys, self.metrics.mean_results())
metrics |= zip(self.metric_keys, self.metrics.mean_results())
return metrics
def print_results(self):
@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator):
for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
# plot TODO: save_dir
if self.args.plots:
self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator):
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
@property
def metric_keys(self):
return [
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP_0.5(B)",
"metrics/mAP_0.5:0.95(B)", # metrics
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP_0.5(M)",
"metrics/mAP_0.5:0.95(M)",]
def plot_val_samples(self, batch, ni):
images = batch["img"]
masks = batch["masks"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
paths,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names)
def plot_predictions(self, batch, preds, ni):
images = batch["img"]
paths = batch["im_file"]
if len(self.plot_masks):
plot_masks = torch.cat(self.plot_masks, dim=0)
batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
self.plot_masks.clear()