ultralytics 8.0.81 single-line docstring updates (#2061)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-17 00:45:36 +02:00 committed by GitHub
parent 5bce1c3021
commit a38f227672
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 620 additions and 58 deletions

View file

@ -15,20 +15,24 @@ from ultralytics.yolo.v8.detect import DetectionValidator
class PoseValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir)
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch)
batch['keypoints'] = batch['keypoints'].to(self.device).float()
return batch
def get_desc(self):
"""Returns description of evaluation metrics in string format."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
@ -40,6 +44,7 @@ class PoseValidator(DetectionValidator):
return preds
def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
self.kpt_shape = self.data['kpt_shape']
is_pose = self.kpt_shape == [17, 3]
@ -137,6 +142,7 @@ class PoseValidator(DetectionValidator):
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
@ -147,6 +153,7 @@ class PoseValidator(DetectionValidator):
names=self.names)
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0)
plot_images(batch['img'],
*output_to_target(preds, max_det=15),
@ -156,6 +163,7 @@ class PoseValidator(DetectionValidator):
names=self.names) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
@ -169,6 +177,7 @@ class PoseValidator(DetectionValidator):
'score': round(p[4], 5)})
def eval_json(self, stats):
"""Evaluates object detection model using COCO JSON format."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
@ -197,6 +206,7 @@ class PoseValidator(DetectionValidator):
def val(cfg=DEFAULT_CFG, use_python=False):
"""Performs validation on YOLO model using given data."""
model = cfg.model or 'yolov8n-pose.pt'
data = cfg.data or 'coco8-pose.yaml'