ultralytics 8.0.81 single-line docstring updates (#2061)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-17 00:45:36 +02:00 committed by GitHub
parent 5bce1c3021
commit a38f227672
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 620 additions and 58 deletions

View file

@ -17,16 +17,19 @@ from ultralytics.yolo.v8.detect import DetectionValidator
class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
return batch
def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
@ -36,10 +39,12 @@ class SegmentationValidator(DetectionValidator):
self.process = ops.process_mask # faster
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
@ -119,6 +124,7 @@ class SegmentationValidator(DetectionValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
@ -160,6 +166,7 @@ class SegmentationValidator(DetectionValidator):
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
@ -170,6 +177,7 @@ class SegmentationValidator(DetectionValidator):
names=self.names)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(batch['img'],
*output_to_target(preds[0], max_det=15),
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
@ -184,6 +192,7 @@ class SegmentationValidator(DetectionValidator):
from pycocotools.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8')
return rle
@ -204,6 +213,7 @@ class SegmentationValidator(DetectionValidator):
'segmentation': rles[i]})
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
@ -232,6 +242,7 @@ class SegmentationValidator(DetectionValidator):
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate trained YOLO model on validation data."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco128-seg.yaml'