ultralytics 8.0.80 single-line docstring fixes (#2060)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-16 15:20:11 +02:00 committed by GitHub
parent 31db8ed163
commit 5bce1c3021
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
48 changed files with 418 additions and 420 deletions

View file

@ -66,7 +66,7 @@ class Compose:
class BaseMixTransform:
"""This implementation is from mmyolo"""
"""This implementation is from mmyolo."""
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
self.dataset = dataset
@ -77,12 +77,12 @@ class BaseMixTransform:
if random.uniform(0, 1) > self.p:
return labels
# get index of one or three other images
# Get index of one or three other images
indexes = self.get_indexes()
if isinstance(indexes, int):
indexes = [indexes]
# get images information will be used for Mosaic or MixUp
# Get images information will be used for Mosaic or MixUp
mix_labels = [self.dataset.get_label_info(i) for i in indexes]
if self.pre_transform is not None:
@ -132,7 +132,7 @@ class Mosaic(BaseMixTransform):
img = labels_patch['img']
h, w = labels_patch.pop('resized_shape')
# place img in img4
# Place img in img4
if i == 0: # top left
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
@ -158,7 +158,7 @@ class Mosaic(BaseMixTransform):
return final_labels
def _update_labels(self, labels, padw, padh):
"""Update labels"""
"""Update labels."""
nh, nw = labels['img'].shape[:2]
labels['instances'].convert_bbox(format='xyxy')
labels['instances'].denormalize(nw, nh)
@ -193,7 +193,7 @@ class MixUp(BaseMixTransform):
return random.randint(0, len(self.dataset) - 1)
def _mix_transform(self, labels):
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
"""Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf."""
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
labels2 = labels['mix_labels'][0]
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
@ -217,12 +217,12 @@ class RandomPerspective:
self.scale = scale
self.shear = shear
self.perspective = perspective
# mosaic border
# Mosaic border
self.border = border
self.pre_transform = pre_transform
def affine_transform(self, img, border):
# Center
"""Center."""
C = np.eye(3, dtype=np.float32)
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
@ -253,7 +253,7 @@ class RandomPerspective:
# Combined rotation matrix
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
# affine image
# Affine image
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
if self.perspective:
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
@ -281,7 +281,7 @@ class RandomPerspective:
xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
# create new boxes
# Create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
@ -348,7 +348,7 @@ class RandomPerspective:
img = labels['img']
cls = labels['cls']
instances = labels.pop('instances')
# make sure the coord formats are right
# Make sure the coord formats are right
instances.convert_bbox(format='xyxy')
instances.denormalize(*img.shape[:2][::-1])
@ -362,19 +362,19 @@ class RandomPerspective:
segments = instances.segments
keypoints = instances.keypoints
# update bboxes if there are segments.
# Update bboxes if there are segments.
if len(segments):
bboxes, segments = self.apply_segments(segments, M)
if keypoints is not None:
keypoints = self.apply_keypoints(keypoints, M)
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
# clip
# Clip
new_instances.clip(*self.size)
# filter instances
# Filter instances
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
# make the bboxes have the same scale with new_bboxes
# Make the bboxes have the same scale with new_bboxes
i = self.box_candidates(box1=instances.bboxes.T,
box2=new_instances.bboxes.T,
area_thr=0.01 if len(segments) else 0.10)
@ -441,7 +441,7 @@ class RandomFlip:
if self.direction == 'horizontal' and random.random() < self.p:
img = np.fliplr(img)
instances.fliplr(w)
# for keypoints
# For keypoints
if self.flip_idx is not None and instances.keypoints is not None:
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
labels['img'] = np.ascontiguousarray(img)
@ -450,7 +450,7 @@ class RandomFlip:
class LetterBox:
"""Resize image and padding for detection, instance segmentation, pose"""
"""Resize image and padding for detection, instance segmentation, pose."""
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
self.new_shape = new_shape
@ -505,7 +505,7 @@ class LetterBox:
return img
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels"""
"""Update labels."""
labels['instances'].convert_bbox(format='xyxy')
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
labels['instances'].scale(*ratio)
@ -519,7 +519,7 @@ class CopyPaste:
self.p = p
def __call__(self, labels):
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
"""Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)."""
im = labels['img']
cls = labels['cls']
h, w = im.shape[:2]
@ -531,7 +531,7 @@ class CopyPaste:
_, w, _ = im.shape # height, width, channels
im_new = np.zeros(im.shape, np.uint8)
# calculate ioa first then select indexes randomly
# Calculate ioa first then select indexes randomly
ins_flip = deepcopy(instances)
ins_flip.fliplr(w)
@ -641,7 +641,7 @@ class Format:
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels['keypoints'] = torch.from_numpy(instances.keypoints)
# then we can use collate_fn
# Then we can use collate_fn
if self.batch_idx:
labels['batch_idx'] = torch.zeros(nl)
return labels
@ -654,7 +654,7 @@ class Format:
return img
def _format_segments(self, instances, cls, w, h):
"""convert polygon points to bitmap"""
"""convert polygon points to bitmap."""
segments = instances.segments
if self.mask_overlap:
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)