ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
parent
e795277391
commit
fe27db2f6e
139 changed files with 6870 additions and 5125 deletions
|
|
@ -4,5 +4,12 @@ from .base import BaseDataset
|
|||
from .build import build_dataloader, build_yolo_dataset, load_inference_source
|
||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||
|
||||
__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
|
||||
'build_dataloader', 'load_inference_source')
|
||||
__all__ = (
|
||||
"BaseDataset",
|
||||
"ClassificationDataset",
|
||||
"SemanticDataset",
|
||||
"YOLODataset",
|
||||
"build_yolo_dataset",
|
||||
"build_dataloader",
|
||||
"load_inference_source",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pathlib import Path
|
|||
from ultralytics import SAM, YOLO
|
||||
|
||||
|
||||
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
|
||||
def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None):
|
||||
"""
|
||||
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
|
|||
|
||||
data = Path(data)
|
||||
if not output_dir:
|
||||
output_dir = data.parent / f'{data.stem}_auto_annotate_labels'
|
||||
output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
det_results = det_model(data, stream=True, device=device)
|
||||
|
|
@ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
|
|||
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
|
||||
segments = sam_results[0].masks.xyn # noqa
|
||||
|
||||
with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f:
|
||||
with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f:
|
||||
for i in range(len(segments)):
|
||||
s = segments[i]
|
||||
if len(s) == 0:
|
||||
continue
|
||||
segment = map(str, segments[i].reshape(-1).tolist())
|
||||
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
|
||||
f.write(f"{class_ids[i]} " + " ".join(segment) + "\n")
|
||||
|
|
|
|||
|
|
@ -117,11 +117,11 @@ class BaseMixTransform:
|
|||
if self.pre_transform is not None:
|
||||
for i, data in enumerate(mix_labels):
|
||||
mix_labels[i] = self.pre_transform(data)
|
||||
labels['mix_labels'] = mix_labels
|
||||
labels["mix_labels"] = mix_labels
|
||||
|
||||
# Mosaic or MixUp
|
||||
labels = self._mix_transform(labels)
|
||||
labels.pop('mix_labels', None)
|
||||
labels.pop("mix_labels", None)
|
||||
return labels
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
|
|
@ -149,8 +149,8 @@ class Mosaic(BaseMixTransform):
|
|||
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
||||
"""Initializes the object with a dataset, image size, probability, and border."""
|
||||
assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
|
||||
assert n in (4, 9), 'grid must be equal to 4 or 9.'
|
||||
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
|
||||
assert n in (4, 9), "grid must be equal to 4 or 9."
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
|
|
@ -166,20 +166,21 @@ class Mosaic(BaseMixTransform):
|
|||
|
||||
def _mix_transform(self, labels):
|
||||
"""Apply mixup transformation to the input image and labels."""
|
||||
assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
|
||||
assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
|
||||
return self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(
|
||||
labels) # This code is modified for mosaic3 method.
|
||||
assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
|
||||
assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
|
||||
return (
|
||||
self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
|
||||
) # This code is modified for mosaic3 method.
|
||||
|
||||
def _mosaic3(self, labels):
|
||||
"""Create a 1x3 image mosaic."""
|
||||
mosaic_labels = []
|
||||
s = self.imgsz
|
||||
for i in range(3):
|
||||
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
||||
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
|
||||
# Load image
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
img = labels_patch["img"]
|
||||
h, w = labels_patch.pop("resized_shape")
|
||||
|
||||
# Place img in img3
|
||||
if i == 0: # center
|
||||
|
|
@ -194,7 +195,7 @@ class Mosaic(BaseMixTransform):
|
|||
padw, padh = c[:2]
|
||||
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
|
||||
|
||||
img3[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img3[ymin:ymax, xmin:xmax]
|
||||
img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax]
|
||||
# hp, wp = h, w # height, width previous for next iteration
|
||||
|
||||
# Labels assuming imgsz*2 mosaic size
|
||||
|
|
@ -202,7 +203,7 @@ class Mosaic(BaseMixTransform):
|
|||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
|
||||
final_labels['img'] = img3[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
|
||||
final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
|
||||
return final_labels
|
||||
|
||||
def _mosaic4(self, labels):
|
||||
|
|
@ -211,10 +212,10 @@ class Mosaic(BaseMixTransform):
|
|||
s = self.imgsz
|
||||
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
||||
for i in range(4):
|
||||
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
||||
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
|
||||
# Load image
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
img = labels_patch["img"]
|
||||
h, w = labels_patch.pop("resized_shape")
|
||||
|
||||
# Place img in img4
|
||||
if i == 0: # top left
|
||||
|
|
@ -238,7 +239,7 @@ class Mosaic(BaseMixTransform):
|
|||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
||||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
final_labels['img'] = img4
|
||||
final_labels["img"] = img4
|
||||
return final_labels
|
||||
|
||||
def _mosaic9(self, labels):
|
||||
|
|
@ -247,10 +248,10 @@ class Mosaic(BaseMixTransform):
|
|||
s = self.imgsz
|
||||
hp, wp = -1, -1 # height, width previous
|
||||
for i in range(9):
|
||||
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
||||
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
|
||||
# Load image
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
img = labels_patch["img"]
|
||||
h, w = labels_patch.pop("resized_shape")
|
||||
|
||||
# Place img in img9
|
||||
if i == 0: # center
|
||||
|
|
@ -278,7 +279,7 @@ class Mosaic(BaseMixTransform):
|
|||
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
|
||||
|
||||
# Image
|
||||
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
|
||||
img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax]
|
||||
hp, wp = h, w # height, width previous for next iteration
|
||||
|
||||
# Labels assuming imgsz*2 mosaic size
|
||||
|
|
@ -286,16 +287,16 @@ class Mosaic(BaseMixTransform):
|
|||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
|
||||
final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
|
||||
final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
|
||||
return final_labels
|
||||
|
||||
@staticmethod
|
||||
def _update_labels(labels, padw, padh):
|
||||
"""Update labels."""
|
||||
nh, nw = labels['img'].shape[:2]
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
labels['instances'].denormalize(nw, nh)
|
||||
labels['instances'].add_padding(padw, padh)
|
||||
nh, nw = labels["img"].shape[:2]
|
||||
labels["instances"].convert_bbox(format="xyxy")
|
||||
labels["instances"].denormalize(nw, nh)
|
||||
labels["instances"].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
def _cat_labels(self, mosaic_labels):
|
||||
|
|
@ -306,18 +307,20 @@ class Mosaic(BaseMixTransform):
|
|||
instances = []
|
||||
imgsz = self.imgsz * 2 # mosaic imgsz
|
||||
for labels in mosaic_labels:
|
||||
cls.append(labels['cls'])
|
||||
instances.append(labels['instances'])
|
||||
cls.append(labels["cls"])
|
||||
instances.append(labels["instances"])
|
||||
# Final labels
|
||||
final_labels = {
|
||||
'im_file': mosaic_labels[0]['im_file'],
|
||||
'ori_shape': mosaic_labels[0]['ori_shape'],
|
||||
'resized_shape': (imgsz, imgsz),
|
||||
'cls': np.concatenate(cls, 0),
|
||||
'instances': Instances.concatenate(instances, axis=0),
|
||||
'mosaic_border': self.border} # final_labels
|
||||
final_labels['instances'].clip(imgsz, imgsz)
|
||||
good = final_labels['instances'].remove_zero_area_boxes()
|
||||
final_labels['cls'] = final_labels['cls'][good]
|
||||
"im_file": mosaic_labels[0]["im_file"],
|
||||
"ori_shape": mosaic_labels[0]["ori_shape"],
|
||||
"resized_shape": (imgsz, imgsz),
|
||||
"cls": np.concatenate(cls, 0),
|
||||
"instances": Instances.concatenate(instances, axis=0),
|
||||
"mosaic_border": self.border,
|
||||
}
|
||||
final_labels["instances"].clip(imgsz, imgsz)
|
||||
good = final_labels["instances"].remove_zero_area_boxes()
|
||||
final_labels["cls"] = final_labels["cls"][good]
|
||||
return final_labels
|
||||
|
||||
|
||||
|
|
@ -335,10 +338,10 @@ class MixUp(BaseMixTransform):
|
|||
def _mix_transform(self, labels):
|
||||
"""Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf."""
|
||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
||||
labels2 = labels['mix_labels'][0]
|
||||
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
|
||||
labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
|
||||
labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
|
||||
labels2 = labels["mix_labels"][0]
|
||||
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
|
||||
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
|
||||
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
|
||||
return labels
|
||||
|
||||
|
||||
|
|
@ -366,14 +369,9 @@ class RandomPerspective:
|
|||
box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
degrees=0.0,
|
||||
translate=0.1,
|
||||
scale=0.5,
|
||||
shear=0.0,
|
||||
perspective=0.0,
|
||||
border=(0, 0),
|
||||
pre_transform=None):
|
||||
def __init__(
|
||||
self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
|
||||
):
|
||||
"""Initializes RandomPerspective object with transformation parameters."""
|
||||
|
||||
self.degrees = degrees
|
||||
|
|
@ -519,18 +517,18 @@ class RandomPerspective:
|
|||
Args:
|
||||
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
|
||||
"""
|
||||
if self.pre_transform and 'mosaic_border' not in labels:
|
||||
if self.pre_transform and "mosaic_border" not in labels:
|
||||
labels = self.pre_transform(labels)
|
||||
labels.pop('ratio_pad', None) # do not need ratio pad
|
||||
labels.pop("ratio_pad", None) # do not need ratio pad
|
||||
|
||||
img = labels['img']
|
||||
cls = labels['cls']
|
||||
instances = labels.pop('instances')
|
||||
img = labels["img"]
|
||||
cls = labels["cls"]
|
||||
instances = labels.pop("instances")
|
||||
# Make sure the coord formats are right
|
||||
instances.convert_bbox(format='xyxy')
|
||||
instances.convert_bbox(format="xyxy")
|
||||
instances.denormalize(*img.shape[:2][::-1])
|
||||
|
||||
border = labels.pop('mosaic_border', self.border)
|
||||
border = labels.pop("mosaic_border", self.border)
|
||||
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
|
||||
# M is affine matrix
|
||||
# Scale for func:`box_candidates`
|
||||
|
|
@ -546,20 +544,20 @@ class RandomPerspective:
|
|||
|
||||
if keypoints is not None:
|
||||
keypoints = self.apply_keypoints(keypoints, M)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
|
||||
# Clip
|
||||
new_instances.clip(*self.size)
|
||||
|
||||
# Filter instances
|
||||
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
|
||||
# Make the bboxes have the same scale with new_bboxes
|
||||
i = self.box_candidates(box1=instances.bboxes.T,
|
||||
box2=new_instances.bboxes.T,
|
||||
area_thr=0.01 if len(segments) else 0.10)
|
||||
labels['instances'] = new_instances[i]
|
||||
labels['cls'] = cls[i]
|
||||
labels['img'] = img
|
||||
labels['resized_shape'] = img.shape[:2]
|
||||
i = self.box_candidates(
|
||||
box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
|
||||
)
|
||||
labels["instances"] = new_instances[i]
|
||||
labels["cls"] = cls[i]
|
||||
labels["img"] = img
|
||||
labels["resized_shape"] = img.shape[:2]
|
||||
return labels
|
||||
|
||||
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
|
||||
|
|
@ -611,7 +609,7 @@ class RandomHSV:
|
|||
|
||||
The modified image replaces the original image in the input 'labels' dict.
|
||||
"""
|
||||
img = labels['img']
|
||||
img = labels["img"]
|
||||
if self.hgain or self.sgain or self.vgain:
|
||||
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
||||
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
||||
|
|
@ -634,7 +632,7 @@ class RandomFlip:
|
|||
Also updates any instances (bounding boxes, keypoints, etc.) accordingly.
|
||||
"""
|
||||
|
||||
def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
|
||||
def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
|
||||
"""
|
||||
Initializes the RandomFlip class with probability and direction.
|
||||
|
||||
|
|
@ -644,7 +642,7 @@ class RandomFlip:
|
|||
Default is 'horizontal'.
|
||||
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
|
||||
"""
|
||||
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
|
||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
|
|
@ -662,25 +660,25 @@ class RandomFlip:
|
|||
Returns:
|
||||
(dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys.
|
||||
"""
|
||||
img = labels['img']
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xywh')
|
||||
img = labels["img"]
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format="xywh")
|
||||
h, w = img.shape[:2]
|
||||
h = 1 if instances.normalized else h
|
||||
w = 1 if instances.normalized else w
|
||||
|
||||
# Flip up-down
|
||||
if self.direction == 'vertical' and random.random() < self.p:
|
||||
if self.direction == "vertical" and random.random() < self.p:
|
||||
img = np.flipud(img)
|
||||
instances.flipud(h)
|
||||
if self.direction == 'horizontal' and random.random() < self.p:
|
||||
if self.direction == "horizontal" and random.random() < self.p:
|
||||
img = np.fliplr(img)
|
||||
instances.fliplr(w)
|
||||
# For keypoints
|
||||
if self.flip_idx is not None and instances.keypoints is not None:
|
||||
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
|
||||
labels['img'] = np.ascontiguousarray(img)
|
||||
labels['instances'] = instances
|
||||
labels["img"] = np.ascontiguousarray(img)
|
||||
labels["instances"] = instances
|
||||
return labels
|
||||
|
||||
|
||||
|
|
@ -700,9 +698,9 @@ class LetterBox:
|
|||
"""Return updated labels and image with added border."""
|
||||
if labels is None:
|
||||
labels = {}
|
||||
img = labels.get('img') if image is None else image
|
||||
img = labels.get("img") if image is None else image
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
new_shape = labels.pop('rect_shape', self.new_shape)
|
||||
new_shape = labels.pop("rect_shape", self.new_shape)
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
|
|
@ -730,25 +728,26 @@ class LetterBox:
|
|||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
|
||||
value=(114, 114, 114)) # add border
|
||||
if labels.get('ratio_pad'):
|
||||
labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation
|
||||
img = cv2.copyMakeBorder(
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
||||
) # add border
|
||||
if labels.get("ratio_pad"):
|
||||
labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
|
||||
|
||||
if len(labels):
|
||||
labels = self._update_labels(labels, ratio, dw, dh)
|
||||
labels['img'] = img
|
||||
labels['resized_shape'] = new_shape
|
||||
labels["img"] = img
|
||||
labels["resized_shape"] = new_shape
|
||||
return labels
|
||||
else:
|
||||
return img
|
||||
|
||||
def _update_labels(self, labels, ratio, padw, padh):
|
||||
"""Update labels."""
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
|
||||
labels['instances'].scale(*ratio)
|
||||
labels['instances'].add_padding(padw, padh)
|
||||
labels["instances"].convert_bbox(format="xyxy")
|
||||
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
|
||||
labels["instances"].scale(*ratio)
|
||||
labels["instances"].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
|
||||
|
|
@ -785,11 +784,11 @@ class CopyPaste:
|
|||
1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work.
|
||||
2. This method modifies the input dictionary 'labels' in place.
|
||||
"""
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
h, w = im.shape[:2]
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xyxy')
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format="xyxy")
|
||||
instances.denormalize(w, h)
|
||||
if self.p and len(instances.segments):
|
||||
n = len(instances)
|
||||
|
|
@ -812,9 +811,9 @@ class CopyPaste:
|
|||
i = cv2.flip(im_new, 1).astype(bool)
|
||||
im[i] = result[i]
|
||||
|
||||
labels['img'] = im
|
||||
labels['cls'] = cls
|
||||
labels['instances'] = instances
|
||||
labels["img"] = im
|
||||
labels["cls"] = cls
|
||||
labels["instances"] = instances
|
||||
return labels
|
||||
|
||||
|
||||
|
|
@ -831,12 +830,13 @@ class Albumentations:
|
|||
"""Initialize the transform object for YOLO bbox formatted params."""
|
||||
self.p = p
|
||||
self.transform = None
|
||||
prefix = colorstr('albumentations: ')
|
||||
prefix = colorstr("albumentations: ")
|
||||
try:
|
||||
import albumentations as A
|
||||
|
||||
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
||||
|
||||
# Transforms
|
||||
T = [
|
||||
A.Blur(p=0.01),
|
||||
A.MedianBlur(p=0.01),
|
||||
|
|
@ -844,31 +844,32 @@ class Albumentations:
|
|||
A.CLAHE(p=0.01),
|
||||
A.RandomBrightnessContrast(p=0.0),
|
||||
A.RandomGamma(p=0.0),
|
||||
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
||||
A.ImageCompression(quality_lower=75, p=0.0),
|
||||
]
|
||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
|
||||
|
||||
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
LOGGER.info(f"{prefix}{e}")
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Generates object detections and returns a dictionary with detection results."""
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
if len(cls):
|
||||
labels['instances'].convert_bbox('xywh')
|
||||
labels['instances'].normalize(*im.shape[:2][::-1])
|
||||
bboxes = labels['instances'].bboxes
|
||||
labels["instances"].convert_bbox("xywh")
|
||||
labels["instances"].normalize(*im.shape[:2][::-1])
|
||||
bboxes = labels["instances"].bboxes
|
||||
# TODO: add supports of segments and keypoints
|
||||
if self.transform and random.random() < self.p:
|
||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||
if len(new['class_labels']) > 0: # skip update if no bbox in new im
|
||||
labels['img'] = new['image']
|
||||
labels['cls'] = np.array(new['class_labels'])
|
||||
bboxes = np.array(new['bboxes'], dtype=np.float32)
|
||||
labels['instances'].update(bboxes=bboxes)
|
||||
if len(new["class_labels"]) > 0: # skip update if no bbox in new im
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
bboxes = np.array(new["bboxes"], dtype=np.float32)
|
||||
labels["instances"].update(bboxes=bboxes)
|
||||
return labels
|
||||
|
||||
|
||||
|
|
@ -888,15 +889,17 @@ class Format:
|
|||
batch_idx (bool): Keep batch indexes. Default is True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=False,
|
||||
return_keypoint=False,
|
||||
return_obb=False,
|
||||
mask_ratio=4,
|
||||
mask_overlap=True,
|
||||
batch_idx=True):
|
||||
def __init__(
|
||||
self,
|
||||
bbox_format="xywh",
|
||||
normalize=True,
|
||||
return_mask=False,
|
||||
return_keypoint=False,
|
||||
return_obb=False,
|
||||
mask_ratio=4,
|
||||
mask_overlap=True,
|
||||
batch_idx=True,
|
||||
):
|
||||
"""Initializes the Format class with given parameters."""
|
||||
self.bbox_format = bbox_format
|
||||
self.normalize = normalize
|
||||
|
|
@ -909,10 +912,10 @@ class Format:
|
|||
|
||||
def __call__(self, labels):
|
||||
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
|
||||
img = labels.pop('img')
|
||||
img = labels.pop("img")
|
||||
h, w = img.shape[:2]
|
||||
cls = labels.pop('cls')
|
||||
instances = labels.pop('instances')
|
||||
cls = labels.pop("cls")
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format=self.bbox_format)
|
||||
instances.denormalize(w, h)
|
||||
nl = len(instances)
|
||||
|
|
@ -922,22 +925,24 @@ class Format:
|
|||
masks, instances, cls = self._format_segments(instances, cls, w, h)
|
||||
masks = torch.from_numpy(masks)
|
||||
else:
|
||||
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
|
||||
img.shape[1] // self.mask_ratio)
|
||||
labels['masks'] = masks
|
||||
masks = torch.zeros(
|
||||
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
|
||||
)
|
||||
labels["masks"] = masks
|
||||
if self.normalize:
|
||||
instances.normalize(w, h)
|
||||
labels['img'] = self._format_img(img)
|
||||
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
labels["img"] = self._format_img(img)
|
||||
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
if self.return_keypoint:
|
||||
labels['keypoints'] = torch.from_numpy(instances.keypoints)
|
||||
labels["keypoints"] = torch.from_numpy(instances.keypoints)
|
||||
if self.return_obb:
|
||||
labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(
|
||||
instances.segments) else torch.zeros((0, 5))
|
||||
labels["bboxes"] = (
|
||||
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
|
||||
)
|
||||
# Then we can use collate_fn
|
||||
if self.batch_idx:
|
||||
labels['batch_idx'] = torch.zeros(nl)
|
||||
labels["batch_idx"] = torch.zeros(nl)
|
||||
return labels
|
||||
|
||||
def _format_img(self, img):
|
||||
|
|
@ -964,33 +969,39 @@ class Format:
|
|||
|
||||
def v8_transforms(dataset, imgsz, hyp, stretch=False):
|
||||
"""Convert images to a size suitable for YOLOv8 training."""
|
||||
pre_transform = Compose([
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
translate=hyp.translate,
|
||||
scale=hyp.scale,
|
||||
shear=hyp.shear,
|
||||
perspective=hyp.perspective,
|
||||
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
|
||||
)])
|
||||
flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
|
||||
pre_transform = Compose(
|
||||
[
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
translate=hyp.translate,
|
||||
scale=hyp.scale,
|
||||
shear=hyp.shear,
|
||||
perspective=hyp.perspective,
|
||||
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
|
||||
),
|
||||
]
|
||||
)
|
||||
flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
|
||||
if dataset.use_keypoints:
|
||||
kpt_shape = dataset.data.get('kpt_shape', None)
|
||||
kpt_shape = dataset.data.get("kpt_shape", None)
|
||||
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
|
||||
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
|
||||
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
|
||||
raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}")
|
||||
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction='vertical', p=hyp.flipud),
|
||||
RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
|
||||
return Compose(
|
||||
[
|
||||
pre_transform,
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
|
||||
]
|
||||
) # transforms
|
||||
|
||||
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
|
|
@ -1031,10 +1042,13 @@ def classify_transforms(
|
|||
tfl = [T.Resize(scale_size)]
|
||||
tfl += [T.CenterCrop(size)]
|
||||
|
||||
tfl += [T.ToTensor(), T.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std),
|
||||
)]
|
||||
tfl += [
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std),
|
||||
),
|
||||
]
|
||||
|
||||
return T.Compose(tfl)
|
||||
|
||||
|
|
@ -1053,7 +1067,7 @@ def classify_augmentations(
|
|||
hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v=0.4, # image HSV-Value augmentation (fraction)
|
||||
force_color_jitter=False,
|
||||
erasing=0.,
|
||||
erasing=0.0,
|
||||
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
|
||||
):
|
||||
"""
|
||||
|
|
@ -1080,13 +1094,13 @@ def classify_augmentations(
|
|||
"""
|
||||
# Transforms to apply if albumentations not installed
|
||||
if not isinstance(size, int):
|
||||
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
|
||||
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
||||
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
||||
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
|
||||
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
|
||||
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
|
||||
if hflip > 0.:
|
||||
if hflip > 0.0:
|
||||
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
|
||||
if vflip > 0.:
|
||||
if vflip > 0.0:
|
||||
primary_tfl += [T.RandomVerticalFlip(p=vflip)]
|
||||
|
||||
secondary_tfl = []
|
||||
|
|
@ -1097,27 +1111,29 @@ def classify_augmentations(
|
|||
# this allows override without breaking old hparm cfgs
|
||||
disable_color_jitter = not force_color_jitter
|
||||
|
||||
if auto_augment == 'randaugment':
|
||||
if auto_augment == "randaugment":
|
||||
if TORCHVISION_0_11:
|
||||
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
|
||||
else:
|
||||
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
|
||||
|
||||
elif auto_augment == 'augmix':
|
||||
elif auto_augment == "augmix":
|
||||
if TORCHVISION_0_13:
|
||||
secondary_tfl += [T.AugMix(interpolation=interpolation)]
|
||||
else:
|
||||
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
|
||||
|
||||
elif auto_augment == 'autoaugment':
|
||||
elif auto_augment == "autoaugment":
|
||||
if TORCHVISION_0_10:
|
||||
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
|
||||
else:
|
||||
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
|
||||
|
||||
else:
|
||||
raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
|
||||
f'"augmix", "autoaugment" or None')
|
||||
raise ValueError(
|
||||
f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
|
||||
f'"augmix", "autoaugment" or None'
|
||||
)
|
||||
|
||||
if not disable_color_jitter:
|
||||
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
|
||||
|
|
@ -1125,7 +1141,8 @@ def classify_augmentations(
|
|||
final_tfl = [
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
|
||||
T.RandomErasing(p=erasing, inplace=True)]
|
||||
T.RandomErasing(p=erasing, inplace=True),
|
||||
]
|
||||
|
||||
return T.Compose(primary_tfl + secondary_tfl + final_tfl)
|
||||
|
||||
|
|
@ -1177,7 +1194,7 @@ class ClassifyLetterBox:
|
|||
|
||||
# Create padded image
|
||||
im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
|
||||
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
return im_out
|
||||
|
||||
|
||||
|
|
@ -1205,7 +1222,7 @@ class CenterCrop:
|
|||
imh, imw = im.shape[:2]
|
||||
m = min(imh, imw) # min dimension
|
||||
top, left = (imh - m) // 2, (imw - m) // 2
|
||||
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||
return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
# NOTE: keep this class for backward compatibility
|
||||
|
|
|
|||
|
|
@ -47,20 +47,22 @@ class BaseDataset(Dataset):
|
|||
transforms (callable): Image transformation function.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_path,
|
||||
imgsz=640,
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=DEFAULT_CFG,
|
||||
prefix='',
|
||||
rect=False,
|
||||
batch_size=16,
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
single_cls=False,
|
||||
classes=None,
|
||||
fraction=1.0):
|
||||
def __init__(
|
||||
self,
|
||||
img_path,
|
||||
imgsz=640,
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=DEFAULT_CFG,
|
||||
prefix="",
|
||||
rect=False,
|
||||
batch_size=16,
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
single_cls=False,
|
||||
classes=None,
|
||||
fraction=1.0,
|
||||
):
|
||||
"""Initialize BaseDataset with given configuration and options."""
|
||||
super().__init__()
|
||||
self.img_path = img_path
|
||||
|
|
@ -86,10 +88,10 @@ class BaseDataset(Dataset):
|
|||
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
||||
|
||||
# Cache images
|
||||
if cache == 'ram' and not self.check_cache_ram():
|
||||
if cache == "ram" and not self.check_cache_ram():
|
||||
cache = False
|
||||
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
||||
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
||||
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
|
||||
if cache:
|
||||
self.cache_images(cache)
|
||||
|
||||
|
|
@ -103,23 +105,23 @@ class BaseDataset(Dataset):
|
|||
for p in img_path if isinstance(img_path, list) else [img_path]:
|
||||
p = Path(p) # os-agnostic
|
||||
if p.is_dir(): # dir
|
||||
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
|
||||
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
|
||||
# F = list(p.rglob('*.*')) # pathlib
|
||||
elif p.is_file(): # file
|
||||
with open(p) as t:
|
||||
t = t.read().strip().splitlines()
|
||||
parent = str(p.parent) + os.sep
|
||||
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
|
||||
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
|
||||
# F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
||||
else:
|
||||
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
|
||||
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
|
||||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f'{self.prefix}No images found in {img_path}'
|
||||
assert im_files, f"{self.prefix}No images found in {img_path}"
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
if self.fraction < 1:
|
||||
im_files = im_files[:round(len(im_files) * self.fraction)]
|
||||
im_files = im_files[: round(len(im_files) * self.fraction)]
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
|
|
@ -127,19 +129,19 @@ class BaseDataset(Dataset):
|
|||
include_class_array = np.array(include_class).reshape(1, -1)
|
||||
for i in range(len(self.labels)):
|
||||
if include_class is not None:
|
||||
cls = self.labels[i]['cls']
|
||||
bboxes = self.labels[i]['bboxes']
|
||||
segments = self.labels[i]['segments']
|
||||
keypoints = self.labels[i]['keypoints']
|
||||
cls = self.labels[i]["cls"]
|
||||
bboxes = self.labels[i]["bboxes"]
|
||||
segments = self.labels[i]["segments"]
|
||||
keypoints = self.labels[i]["keypoints"]
|
||||
j = (cls == include_class_array).any(1)
|
||||
self.labels[i]['cls'] = cls[j]
|
||||
self.labels[i]['bboxes'] = bboxes[j]
|
||||
self.labels[i]["cls"] = cls[j]
|
||||
self.labels[i]["bboxes"] = bboxes[j]
|
||||
if segments:
|
||||
self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
|
||||
self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
|
||||
if keypoints is not None:
|
||||
self.labels[i]['keypoints'] = keypoints[j]
|
||||
self.labels[i]["keypoints"] = keypoints[j]
|
||||
if self.single_cls:
|
||||
self.labels[i]['cls'][:, 0] = 0
|
||||
self.labels[i]["cls"][:, 0] = 0
|
||||
|
||||
def load_image(self, i, rect_mode=True):
|
||||
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
||||
|
|
@ -149,13 +151,13 @@ class BaseDataset(Dataset):
|
|||
try:
|
||||
im = np.load(fn)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
|
||||
Path(fn).unlink(missing_ok=True)
|
||||
im = cv2.imread(f) # BGR
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if im is None:
|
||||
raise FileNotFoundError(f'Image Not Found {f}')
|
||||
raise FileNotFoundError(f"Image Not Found {f}")
|
||||
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
|
||||
|
|
@ -181,17 +183,17 @@ class BaseDataset(Dataset):
|
|||
def cache_images(self, cache):
|
||||
"""Cache images to memory or disk."""
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
||||
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(self.ni))
|
||||
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == 'disk':
|
||||
if cache == "disk":
|
||||
b += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
b += self.ims[i].nbytes
|
||||
pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
|
||||
pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
|
||||
pbar.close()
|
||||
|
||||
def cache_images_to_disk(self, i):
|
||||
|
|
@ -207,15 +209,17 @@ class BaseDataset(Dataset):
|
|||
for _ in range(n):
|
||||
im = cv2.imread(random.choice(self.im_files)) # sample image
|
||||
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
|
||||
b += im.nbytes * ratio ** 2
|
||||
b += im.nbytes * ratio**2
|
||||
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
|
||||
mem = psutil.virtual_memory()
|
||||
cache = mem_required < mem.available # to cache or not to cache, that is the question
|
||||
if not cache:
|
||||
LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
|
||||
f'with {int(safety_margin * 100)}% safety margin but only '
|
||||
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
||||
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
|
||||
LOGGER.info(
|
||||
f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
|
||||
f'with {int(safety_margin * 100)}% safety margin but only '
|
||||
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
||||
f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
|
||||
)
|
||||
return cache
|
||||
|
||||
def set_rectangle(self):
|
||||
|
|
@ -223,7 +227,7 @@ class BaseDataset(Dataset):
|
|||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
s = np.array([x.pop('shape') for x in self.labels]) # hw
|
||||
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||
irect = ar.argsort()
|
||||
self.im_files = [self.im_files[i] for i in irect]
|
||||
|
|
@ -250,12 +254,14 @@ class BaseDataset(Dataset):
|
|||
def get_image_and_label(self, index):
|
||||
"""Get and return label information from the dataset."""
|
||||
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
||||
label.pop('shape', None) # shape is for rect, remove it
|
||||
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
|
||||
label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
|
||||
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
||||
label.pop("shape", None) # shape is for rect, remove it
|
||||
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
||||
label["ratio_pad"] = (
|
||||
label["resized_shape"][0] / label["ori_shape"][0],
|
||||
label["resized_shape"][1] / label["ori_shape"][1],
|
||||
) # for evaluation
|
||||
if self.rect:
|
||||
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
||||
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
||||
return self.update_labels_info(label)
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
|||
|
|
@ -9,8 +9,16 @@ import torch
|
|||
from PIL import Image
|
||||
from torch.utils.data import dataloader, distributed
|
||||
|
||||
from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
|
||||
SourceTypes, autocast_list)
|
||||
from ultralytics.data.loaders import (
|
||||
LOADERS,
|
||||
LoadImages,
|
||||
LoadPilAndNumpy,
|
||||
LoadScreenshots,
|
||||
LoadStreams,
|
||||
LoadTensor,
|
||||
SourceTypes,
|
||||
autocast_list,
|
||||
)
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.utils import RANK, colorstr
|
||||
from ultralytics.utils.checks import check_file
|
||||
|
|
@ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
|||
def __init__(self, *args, **kwargs):
|
||||
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
|
|
@ -70,29 +78,30 @@ class _RepeatSampler:
|
|||
|
||||
def seed_worker(worker_id): # noqa
|
||||
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
worker_seed = torch.initial_seed() % 2**32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
|
||||
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
|
||||
"""Build YOLO Dataset."""
|
||||
return YOLODataset(
|
||||
img_path=img_path,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == 'train', # augmentation
|
||||
augment=mode == "train", # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect or rect, # rectangular batches
|
||||
cache=cfg.cache or None,
|
||||
single_cls=cfg.single_cls or False,
|
||||
stride=int(stride),
|
||||
pad=0.0 if mode == 'train' else 0.5,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
task=cfg.task,
|
||||
classes=cfg.classes,
|
||||
data=data,
|
||||
fraction=cfg.fraction if mode == 'train' else 1.0)
|
||||
fraction=cfg.fraction if mode == "train" else 1.0,
|
||||
)
|
||||
|
||||
|
||||
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
||||
|
|
@ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
|||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return InfiniteDataLoader(dataset=dataset,
|
||||
batch_size=batch,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, 'collate_fn', None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator)
|
||||
return InfiniteDataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, "collate_fn", None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
|
||||
def check_source(source):
|
||||
|
|
@ -120,9 +131,9 @@ def check_source(source):
|
|||
if isinstance(source, (str, int, Path)): # int for local usb camera
|
||||
source = str(source)
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://'))
|
||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||
screenshot = source.lower() == 'screen'
|
||||
is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
|
||||
webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
|
||||
screenshot = source.lower() == "screen"
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
elif isinstance(source, LOADERS):
|
||||
|
|
@ -135,7 +146,7 @@ def check_source(source):
|
|||
elif isinstance(source, torch.Tensor):
|
||||
tensor = True
|
||||
else:
|
||||
raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
|
||||
raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")
|
||||
|
||||
return source, webcam, screenshot, from_img, in_memory, tensor
|
||||
|
||||
|
|
@ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
|
|||
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
|
||||
|
||||
# Attach source types to the dataset
|
||||
setattr(dataset, 'source_type', source_type)
|
||||
setattr(dataset, "source_type", source_type)
|
||||
|
||||
return dataset
|
||||
|
|
|
|||
|
|
@ -20,10 +20,98 @@ def coco91_to_coco80_class():
|
|||
corresponding 91-index class ID.
|
||||
"""
|
||||
return [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
|
||||
None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
|
||||
51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
|
||||
None, 73, 74, 75, 76, 77, 78, 79, None]
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
None,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
None,
|
||||
24,
|
||||
25,
|
||||
None,
|
||||
None,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
None,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
54,
|
||||
55,
|
||||
56,
|
||||
57,
|
||||
58,
|
||||
59,
|
||||
None,
|
||||
60,
|
||||
None,
|
||||
None,
|
||||
61,
|
||||
None,
|
||||
62,
|
||||
63,
|
||||
64,
|
||||
65,
|
||||
66,
|
||||
67,
|
||||
68,
|
||||
69,
|
||||
70,
|
||||
71,
|
||||
72,
|
||||
None,
|
||||
73,
|
||||
74,
|
||||
75,
|
||||
76,
|
||||
77,
|
||||
78,
|
||||
79,
|
||||
None,
|
||||
]
|
||||
|
||||
|
||||
def coco80_to_coco91_class():
|
||||
|
|
@ -42,16 +130,96 @@ def coco80_to_coco91_class():
|
|||
```
|
||||
"""
|
||||
return [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
||||
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
||||
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
27,
|
||||
28,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
54,
|
||||
55,
|
||||
56,
|
||||
57,
|
||||
58,
|
||||
59,
|
||||
60,
|
||||
61,
|
||||
62,
|
||||
63,
|
||||
64,
|
||||
65,
|
||||
67,
|
||||
70,
|
||||
72,
|
||||
73,
|
||||
74,
|
||||
75,
|
||||
76,
|
||||
77,
|
||||
78,
|
||||
79,
|
||||
80,
|
||||
81,
|
||||
82,
|
||||
84,
|
||||
85,
|
||||
86,
|
||||
87,
|
||||
88,
|
||||
89,
|
||||
90,
|
||||
]
|
||||
|
||||
|
||||
def convert_coco(labels_dir='../coco/annotations/',
|
||||
save_dir='coco_converted/',
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
cls91to80=True):
|
||||
def convert_coco(
|
||||
labels_dir="../coco/annotations/",
|
||||
save_dir="coco_converted/",
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
cls91to80=True,
|
||||
):
|
||||
"""
|
||||
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
|
||||
|
||||
|
|
@ -75,76 +243,78 @@ def convert_coco(labels_dir='../coco/annotations/',
|
|||
|
||||
# Create dataset directory
|
||||
save_dir = increment_path(save_dir) # increment if save directory already exists
|
||||
for p in save_dir / 'labels', save_dir / 'images':
|
||||
for p in save_dir / "labels", save_dir / "images":
|
||||
p.mkdir(parents=True, exist_ok=True) # make dir
|
||||
|
||||
# Convert classes
|
||||
coco80 = coco91_to_coco80_class()
|
||||
|
||||
# Import json
|
||||
for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
|
||||
fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name
|
||||
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
|
||||
fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name
|
||||
fn.mkdir(parents=True, exist_ok=True)
|
||||
with open(json_file) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Create image dict
|
||||
images = {f'{x["id"]:d}': x for x in data['images']}
|
||||
images = {f'{x["id"]:d}': x for x in data["images"]}
|
||||
# Create image-annotations dict
|
||||
imgToAnns = defaultdict(list)
|
||||
for ann in data['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
for ann in data["annotations"]:
|
||||
imgToAnns[ann["image_id"]].append(ann)
|
||||
|
||||
# Write labels file
|
||||
for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'):
|
||||
img = images[f'{img_id:d}']
|
||||
h, w, f = img['height'], img['width'], img['file_name']
|
||||
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
|
||||
img = images[f"{img_id:d}"]
|
||||
h, w, f = img["height"], img["width"], img["file_name"]
|
||||
|
||||
bboxes = []
|
||||
segments = []
|
||||
keypoints = []
|
||||
for ann in anns:
|
||||
if ann['iscrowd']:
|
||||
if ann["iscrowd"]:
|
||||
continue
|
||||
# The COCO box format is [top left x, top left y, width, height]
|
||||
box = np.array(ann['bbox'], dtype=np.float64)
|
||||
box = np.array(ann["bbox"], dtype=np.float64)
|
||||
box[:2] += box[2:] / 2 # xy top-left corner to center
|
||||
box[[0, 2]] /= w # normalize x
|
||||
box[[1, 3]] /= h # normalize y
|
||||
if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
|
||||
continue
|
||||
|
||||
cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class
|
||||
cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1 # class
|
||||
box = [cls] + box.tolist()
|
||||
if box not in bboxes:
|
||||
bboxes.append(box)
|
||||
if use_segments and ann.get('segmentation') is not None:
|
||||
if len(ann['segmentation']) == 0:
|
||||
if use_segments and ann.get("segmentation") is not None:
|
||||
if len(ann["segmentation"]) == 0:
|
||||
segments.append([])
|
||||
continue
|
||||
elif len(ann['segmentation']) > 1:
|
||||
s = merge_multi_segment(ann['segmentation'])
|
||||
elif len(ann["segmentation"]) > 1:
|
||||
s = merge_multi_segment(ann["segmentation"])
|
||||
s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
|
||||
else:
|
||||
s = [j for i in ann['segmentation'] for j in i] # all segments concatenated
|
||||
s = [j for i in ann["segmentation"] for j in i] # all segments concatenated
|
||||
s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
|
||||
s = [cls] + s
|
||||
segments.append(s)
|
||||
if use_keypoints and ann.get('keypoints') is not None:
|
||||
keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) /
|
||||
np.array([w, h, 1])).reshape(-1).tolist())
|
||||
if use_keypoints and ann.get("keypoints") is not None:
|
||||
keypoints.append(
|
||||
box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
|
||||
)
|
||||
|
||||
# Write
|
||||
with open((fn / f).with_suffix('.txt'), 'a') as file:
|
||||
with open((fn / f).with_suffix(".txt"), "a") as file:
|
||||
for i in range(len(bboxes)):
|
||||
if use_keypoints:
|
||||
line = *(keypoints[i]), # cls, box, keypoints
|
||||
line = (*(keypoints[i]),) # cls, box, keypoints
|
||||
else:
|
||||
line = *(segments[i]
|
||||
if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
|
||||
file.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
line = (
|
||||
*(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),
|
||||
) # cls, box or segments
|
||||
file.write(("%g " * len(line)).rstrip() % line + "\n")
|
||||
|
||||
LOGGER.info(f'COCO data converted successfully.\nResults saved to {save_dir.resolve()}')
|
||||
LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
|
||||
|
||||
|
||||
def convert_dota_to_yolo_obb(dota_root_path: str):
|
||||
|
|
@ -184,31 +354,32 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
|
|||
|
||||
# Class names to indices mapping
|
||||
class_mapping = {
|
||||
'plane': 0,
|
||||
'ship': 1,
|
||||
'storage-tank': 2,
|
||||
'baseball-diamond': 3,
|
||||
'tennis-court': 4,
|
||||
'basketball-court': 5,
|
||||
'ground-track-field': 6,
|
||||
'harbor': 7,
|
||||
'bridge': 8,
|
||||
'large-vehicle': 9,
|
||||
'small-vehicle': 10,
|
||||
'helicopter': 11,
|
||||
'roundabout': 12,
|
||||
'soccer-ball-field': 13,
|
||||
'swimming-pool': 14,
|
||||
'container-crane': 15,
|
||||
'airport': 16,
|
||||
'helipad': 17}
|
||||
"plane": 0,
|
||||
"ship": 1,
|
||||
"storage-tank": 2,
|
||||
"baseball-diamond": 3,
|
||||
"tennis-court": 4,
|
||||
"basketball-court": 5,
|
||||
"ground-track-field": 6,
|
||||
"harbor": 7,
|
||||
"bridge": 8,
|
||||
"large-vehicle": 9,
|
||||
"small-vehicle": 10,
|
||||
"helicopter": 11,
|
||||
"roundabout": 12,
|
||||
"soccer-ball-field": 13,
|
||||
"swimming-pool": 14,
|
||||
"container-crane": 15,
|
||||
"airport": 16,
|
||||
"helipad": 17,
|
||||
}
|
||||
|
||||
def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
|
||||
"""Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory."""
|
||||
orig_label_path = orig_label_dir / f'{image_name}.txt'
|
||||
save_path = save_dir / f'{image_name}.txt'
|
||||
orig_label_path = orig_label_dir / f"{image_name}.txt"
|
||||
save_path = save_dir / f"{image_name}.txt"
|
||||
|
||||
with orig_label_path.open('r') as f, save_path.open('w') as g:
|
||||
with orig_label_path.open("r") as f, save_path.open("w") as g:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
|
|
@ -218,20 +389,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
|
|||
class_idx = class_mapping[class_name]
|
||||
coords = [float(p) for p in parts[:8]]
|
||||
normalized_coords = [
|
||||
coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)]
|
||||
formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords]
|
||||
coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
|
||||
]
|
||||
formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords]
|
||||
g.write(f"{class_idx} {' '.join(formatted_coords)}\n")
|
||||
|
||||
for phase in ['train', 'val']:
|
||||
image_dir = dota_root_path / 'images' / phase
|
||||
orig_label_dir = dota_root_path / 'labels' / f'{phase}_original'
|
||||
save_dir = dota_root_path / 'labels' / phase
|
||||
for phase in ["train", "val"]:
|
||||
image_dir = dota_root_path / "images" / phase
|
||||
orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
|
||||
save_dir = dota_root_path / "labels" / phase
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
image_paths = list(image_dir.iterdir())
|
||||
for image_path in TQDM(image_paths, desc=f'Processing {phase} images'):
|
||||
if image_path.suffix != '.png':
|
||||
for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
|
||||
if image_path.suffix != ".png":
|
||||
continue
|
||||
image_name_without_ext = image_path.stem
|
||||
img = cv2.imread(str(image_path))
|
||||
|
|
@ -293,7 +465,7 @@ def merge_multi_segment(segments):
|
|||
s.append(segments[i])
|
||||
else:
|
||||
idx = [0, idx[1] - idx[0]]
|
||||
s.append(segments[i][idx[0]:idx[1] + 1])
|
||||
s.append(segments[i][idx[0] : idx[1] + 1])
|
||||
|
||||
else:
|
||||
for i in range(len(idx_list) - 1, -1, -1):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from .base import BaseDataset
|
|||
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
|
||||
|
||||
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
|
||||
DATASET_CACHE_VERSION = '1.0.3'
|
||||
DATASET_CACHE_VERSION = "1.0.3"
|
||||
|
||||
|
||||
class YOLODataset(BaseDataset):
|
||||
|
|
@ -33,16 +33,16 @@ class YOLODataset(BaseDataset):
|
|||
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, data=None, task='detect', **kwargs):
|
||||
def __init__(self, *args, data=None, task="detect", **kwargs):
|
||||
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
|
||||
self.use_segments = task == 'segment'
|
||||
self.use_keypoints = task == 'pose'
|
||||
self.use_obb = task == 'obb'
|
||||
self.use_segments = task == "segment"
|
||||
self.use_keypoints = task == "pose"
|
||||
self.use_obb = task == "obb"
|
||||
self.data = data
|
||||
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
||||
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def cache_labels(self, path=Path('./labels.cache')):
|
||||
def cache_labels(self, path=Path("./labels.cache")):
|
||||
"""
|
||||
Cache dataset labels, check images and read shapes.
|
||||
|
||||
|
|
@ -51,19 +51,29 @@ class YOLODataset(BaseDataset):
|
|||
Returns:
|
||||
(dict): labels.
|
||||
"""
|
||||
x = {'labels': []}
|
||||
x = {"labels": []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
total = len(self.im_files)
|
||||
nkpt, ndim = self.data.get('kpt_shape', (0, 0))
|
||||
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
||||
raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
|
||||
raise ValueError(
|
||||
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
|
||||
)
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image_label,
|
||||
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
||||
repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
|
||||
repeat(ndim)))
|
||||
results = pool.imap(
|
||||
func=verify_image_label,
|
||||
iterable=zip(
|
||||
self.im_files,
|
||||
self.label_files,
|
||||
repeat(self.prefix),
|
||||
repeat(self.use_keypoints),
|
||||
repeat(len(self.data["names"])),
|
||||
repeat(nkpt),
|
||||
repeat(ndim),
|
||||
),
|
||||
)
|
||||
pbar = TQDM(results, desc=desc, total=total)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
|
|
@ -71,7 +81,7 @@ class YOLODataset(BaseDataset):
|
|||
ne += ne_f
|
||||
nc += nc_f
|
||||
if im_file:
|
||||
x['labels'].append(
|
||||
x["labels"].append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape,
|
||||
|
|
@ -80,60 +90,63 @@ class YOLODataset(BaseDataset):
|
|||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format='xywh'))
|
||||
bbox_format="xywh",
|
||||
)
|
||||
)
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
pbar.close()
|
||||
|
||||
if msgs:
|
||||
LOGGER.info('\n'.join(msgs))
|
||||
LOGGER.info("\n".join(msgs))
|
||||
if nf == 0:
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
|
||||
x['hash'] = get_hash(self.label_files + self.im_files)
|
||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||
x['msgs'] = msgs # warnings
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
|
||||
x["hash"] = get_hash(self.label_files + self.im_files)
|
||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||
x["msgs"] = msgs # warnings
|
||||
save_dataset_cache_file(self.prefix, path, x)
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
"""Returns dictionary of labels for YOLO training."""
|
||||
self.label_files = img2label_paths(self.im_files)
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
|
||||
try:
|
||||
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
|
||||
assert cache['version'] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in (-1, 0):
|
||||
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
|
||||
if cache['msgs']:
|
||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
|
||||
# Read cache
|
||||
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
||||
labels = cache['labels']
|
||||
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
||||
labels = cache["labels"]
|
||||
if not labels:
|
||||
LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}')
|
||||
self.im_files = [lb['im_file'] for lb in labels] # update im_files
|
||||
LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
|
||||
self.im_files = [lb["im_file"] for lb in labels] # update im_files
|
||||
|
||||
# Check if the dataset is all boxes or all segments
|
||||
lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
|
||||
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
|
||||
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
|
||||
if len_segments and len_boxes != len_segments:
|
||||
LOGGER.warning(
|
||||
f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
|
||||
f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
|
||||
'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
|
||||
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
|
||||
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
|
||||
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
|
||||
)
|
||||
for lb in labels:
|
||||
lb['segments'] = []
|
||||
lb["segments"] = []
|
||||
if len_cls == 0:
|
||||
LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}')
|
||||
LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
|
||||
return labels
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
|
|
@ -145,14 +158,17 @@ class YOLODataset(BaseDataset):
|
|||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
||||
transforms.append(
|
||||
Format(bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
return_obb=self.use_obb,
|
||||
batch_idx=True,
|
||||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask))
|
||||
Format(
|
||||
bbox_format="xywh",
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
return_obb=self.use_obb,
|
||||
batch_idx=True,
|
||||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask,
|
||||
)
|
||||
)
|
||||
return transforms
|
||||
|
||||
def close_mosaic(self, hyp):
|
||||
|
|
@ -166,11 +182,11 @@ class YOLODataset(BaseDataset):
|
|||
"""Custom your label format here."""
|
||||
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
||||
# We can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||
bboxes = label.pop('bboxes')
|
||||
segments = label.pop('segments', [])
|
||||
keypoints = label.pop('keypoints', None)
|
||||
bbox_format = label.pop('bbox_format')
|
||||
normalized = label.pop('normalized')
|
||||
bboxes = label.pop("bboxes")
|
||||
segments = label.pop("segments", [])
|
||||
keypoints = label.pop("keypoints", None)
|
||||
bbox_format = label.pop("bbox_format")
|
||||
normalized = label.pop("normalized")
|
||||
|
||||
# NOTE: do NOT resample oriented boxes
|
||||
segment_resamples = 100 if self.use_obb else 1000
|
||||
|
|
@ -180,7 +196,7 @@ class YOLODataset(BaseDataset):
|
|||
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
|
||||
else:
|
||||
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
|
||||
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -191,15 +207,15 @@ class YOLODataset(BaseDataset):
|
|||
values = list(zip(*[list(b.values()) for b in batch]))
|
||||
for i, k in enumerate(keys):
|
||||
value = values[i]
|
||||
if k == 'img':
|
||||
if k == "img":
|
||||
value = torch.stack(value, 0)
|
||||
if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']:
|
||||
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = value
|
||||
new_batch['batch_idx'] = list(new_batch['batch_idx'])
|
||||
for i in range(len(new_batch['batch_idx'])):
|
||||
new_batch['batch_idx'][i] += i # add target image index for build_targets()
|
||||
new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
for i in range(len(new_batch["batch_idx"])):
|
||||
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
||||
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
|
||||
return new_batch
|
||||
|
||||
|
||||
|
|
@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
|
||||
"""
|
||||
|
||||
def __init__(self, root, args, augment=False, cache=False, prefix=''):
|
||||
def __init__(self, root, args, augment=False, cache=False, prefix=""):
|
||||
"""
|
||||
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
||||
|
||||
|
|
@ -231,23 +247,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
"""
|
||||
super().__init__(root=root)
|
||||
if augment and args.fraction < 1.0: # reduce training fraction
|
||||
self.samples = self.samples[:round(len(self.samples) * args.fraction)]
|
||||
self.prefix = colorstr(f'{prefix}: ') if prefix else ''
|
||||
self.cache_ram = cache is True or cache == 'ram'
|
||||
self.cache_disk = cache == 'disk'
|
||||
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
|
||||
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
|
||||
self.cache_ram = cache is True or cache == "ram"
|
||||
self.cache_disk = cache == "disk"
|
||||
self.samples = self.verify_images() # filter out bad images
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
||||
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
|
||||
self.torch_transforms = classify_augmentations(size=args.imgsz,
|
||||
scale=scale,
|
||||
hflip=args.fliplr,
|
||||
vflip=args.flipud,
|
||||
erasing=args.erasing,
|
||||
auto_augment=args.auto_augment,
|
||||
hsv_h=args.hsv_h,
|
||||
hsv_s=args.hsv_s,
|
||||
hsv_v=args.hsv_v) if augment else classify_transforms(
|
||||
size=args.imgsz, crop_fraction=args.crop_fraction)
|
||||
self.torch_transforms = (
|
||||
classify_augmentations(
|
||||
size=args.imgsz,
|
||||
scale=scale,
|
||||
hflip=args.fliplr,
|
||||
vflip=args.flipud,
|
||||
erasing=args.erasing,
|
||||
auto_augment=args.auto_augment,
|
||||
hsv_h=args.hsv_h,
|
||||
hsv_s=args.hsv_s,
|
||||
hsv_v=args.hsv_v,
|
||||
)
|
||||
if augment
|
||||
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
|
||||
)
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Returns subset of data and targets corresponding to given indices."""
|
||||
|
|
@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
# Convert NumPy array to PIL image
|
||||
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
|
||||
sample = self.torch_transforms(im)
|
||||
return {'img': sample, 'cls': j}
|
||||
return {"img": sample, "cls": j}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the total number of samples in the dataset."""
|
||||
|
|
@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
|
||||
def verify_images(self):
|
||||
"""Verify all images in dataset."""
|
||||
desc = f'{self.prefix}Scanning {self.root}...'
|
||||
path = Path(self.root).with_suffix('.cache') # *.cache file path
|
||||
desc = f"{self.prefix}Scanning {self.root}..."
|
||||
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
||||
|
||||
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
|
||||
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
||||
assert cache['version'] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache['hash'] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total
|
||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if LOCAL_RANK in (-1, 0):
|
||||
d = f'{desc} {nf} images, {nc} corrupt'
|
||||
d = f"{desc} {nf} images, {nc} corrupt"
|
||||
TQDM(None, desc=d, total=n, initial=n)
|
||||
if cache['msgs']:
|
||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
return samples
|
||||
|
||||
# Run scan if *.cache retrieval failed
|
||||
|
|
@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
msgs.append(msg)
|
||||
nf += nf_f
|
||||
nc += nc_f
|
||||
pbar.desc = f'{desc} {nf} images, {nc} corrupt'
|
||||
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
|
||||
pbar.close()
|
||||
if msgs:
|
||||
LOGGER.info('\n'.join(msgs))
|
||||
x['hash'] = get_hash([x[0] for x in self.samples])
|
||||
x['results'] = nf, nc, len(samples), samples
|
||||
x['msgs'] = msgs # warnings
|
||||
LOGGER.info("\n".join(msgs))
|
||||
x["hash"] = get_hash([x[0] for x in self.samples])
|
||||
x["results"] = nf, nc, len(samples), samples
|
||||
x["msgs"] = msgs # warnings
|
||||
save_dataset_cache_file(self.prefix, path, x)
|
||||
return samples
|
||||
|
||||
|
|
@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
def load_dataset_cache_file(path):
|
||||
"""Load an Ultralytics *.cache dictionary from path."""
|
||||
import gc
|
||||
|
||||
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
|
||||
cache = np.load(str(path), allow_pickle=True).item() # load dict
|
||||
gc.enable()
|
||||
|
|
@ -320,15 +342,15 @@ def load_dataset_cache_file(path):
|
|||
|
||||
def save_dataset_cache_file(prefix, path, x):
|
||||
"""Save an Ultralytics dataset *.cache dictionary x to path."""
|
||||
x['version'] = DATASET_CACHE_VERSION # add cache version
|
||||
x["version"] = DATASET_CACHE_VERSION # add cache version
|
||||
if is_dir_writeable(path.parent):
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||
LOGGER.info(f'{prefix}New cache created: {path}')
|
||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||
LOGGER.info(f"{prefix}New cache created: {path}")
|
||||
else:
|
||||
LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
|
||||
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
||||
|
||||
|
||||
# TODO: support semantic segmentation
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
from .utils import plot_query_result
|
||||
|
||||
__all__ = ['plot_query_result']
|
||||
__all__ = ["plot_query_result"]
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
|
|||
|
||||
|
||||
class ExplorerDataset(YOLODataset):
|
||||
|
||||
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
||||
super().__init__(*args, data=data, **kwargs)
|
||||
|
||||
|
|
@ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset):
|
|||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if im is None:
|
||||
raise FileNotFoundError(f'Image Not Found {f}')
|
||||
raise FileNotFoundError(f"Image Not Found {f}")
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
return im, (h0, w0), im.shape[:2]
|
||||
|
||||
|
|
@ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset):
|
|||
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
||||
"""Creates transforms for dataset images without resizing."""
|
||||
return Format(
|
||||
bbox_format='xyxy',
|
||||
bbox_format="xyxy",
|
||||
normalize=False,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
|
|
@ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset):
|
|||
|
||||
|
||||
class Explorer:
|
||||
|
||||
def __init__(self,
|
||||
data: Union[str, Path] = 'coco128.yaml',
|
||||
model: str = 'yolov8n.pt',
|
||||
uri: str = '~/ultralytics/explorer') -> None:
|
||||
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
|
||||
def __init__(
|
||||
self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
|
||||
) -> None:
|
||||
checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
|
||||
import lancedb
|
||||
|
||||
self.connection = lancedb.connect(uri)
|
||||
self.table_name = Path(data).name.lower() + '_' + model.lower()
|
||||
self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
|
||||
self.table_name = Path(data).name.lower() + "_" + model.lower()
|
||||
self.sim_idx_base_name = (
|
||||
f"{self.table_name}_sim_idx".lower()
|
||||
) # Use this name and append thres and top_k to reuse the table
|
||||
self.model = YOLO(model)
|
||||
self.data = data # None
|
||||
|
|
@ -74,7 +72,7 @@ class Explorer:
|
|||
self.table = None
|
||||
self.progress = 0
|
||||
|
||||
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
|
||||
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
|
||||
"""
|
||||
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
||||
already exists. Pass force=True to overwrite the existing table.
|
||||
|
|
@ -90,20 +88,20 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
if self.table is not None and not force:
|
||||
LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
|
||||
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
|
||||
return
|
||||
if self.table_name in self.connection.table_names() and not force:
|
||||
LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
|
||||
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
|
||||
self.table = self.connection.open_table(self.table_name)
|
||||
self.progress = 1
|
||||
return
|
||||
if self.data is None:
|
||||
raise ValueError('Data must be provided to create embeddings table')
|
||||
raise ValueError("Data must be provided to create embeddings table")
|
||||
|
||||
data_info = check_det_dataset(self.data)
|
||||
if split not in data_info:
|
||||
raise ValueError(
|
||||
f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
|
||||
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
|
||||
)
|
||||
|
||||
choice_set = data_info[split]
|
||||
|
|
@ -113,13 +111,16 @@ class Explorer:
|
|||
|
||||
# Create the table schema
|
||||
batch = dataset[0]
|
||||
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
|
||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
|
||||
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
|
||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
|
||||
table.add(
|
||||
self._yield_batches(dataset,
|
||||
data_info,
|
||||
self.model,
|
||||
exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
|
||||
self._yield_batches(
|
||||
dataset,
|
||||
data_info,
|
||||
self.model,
|
||||
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
|
||||
)
|
||||
)
|
||||
|
||||
self.table = table
|
||||
|
||||
|
|
@ -131,12 +132,12 @@ class Explorer:
|
|||
for k in exclude_keys:
|
||||
batch.pop(k, None)
|
||||
batch = sanitize_batch(batch, data_info)
|
||||
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
|
||||
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
|
||||
yield [batch]
|
||||
|
||||
def query(self,
|
||||
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
limit: int = 25) -> Any: # pyarrow.Table
|
||||
def query(
|
||||
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
|
||||
) -> Any: # pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
|
|
@ -157,18 +158,18 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
if self.table is None:
|
||||
raise ValueError('Table is not created. Please create the table first.')
|
||||
raise ValueError("Table is not created. Please create the table first.")
|
||||
if isinstance(imgs, str):
|
||||
imgs = [imgs]
|
||||
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
|
||||
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
|
||||
embeds = self.model.embed(imgs)
|
||||
# Get avg if multiple images are passed (len > 1)
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
return self.table.search(embeds).limit(limit).to_arrow()
|
||||
|
||||
def sql_query(self,
|
||||
query: str,
|
||||
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
|
||||
def sql_query(
|
||||
self, query: str, return_type: str = "pandas"
|
||||
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
|
||||
"""
|
||||
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
||||
|
||||
|
|
@ -187,27 +188,29 @@ class Explorer:
|
|||
result = exp.sql_query(query)
|
||||
```
|
||||
"""
|
||||
assert return_type in ['pandas',
|
||||
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
|
||||
assert return_type in [
|
||||
"pandas",
|
||||
"arrow",
|
||||
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
||||
import duckdb
|
||||
|
||||
if self.table is None:
|
||||
raise ValueError('Table is not created. Please create the table first.')
|
||||
raise ValueError("Table is not created. Please create the table first.")
|
||||
|
||||
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
||||
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
||||
if not query.startswith('SELECT') and not query.startswith('WHERE'):
|
||||
if not query.startswith("SELECT") and not query.startswith("WHERE"):
|
||||
raise ValueError(
|
||||
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
|
||||
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
|
||||
)
|
||||
if query.startswith('WHERE'):
|
||||
if query.startswith("WHERE"):
|
||||
query = f"SELECT * FROM 'table' {query}"
|
||||
LOGGER.info(f'Running query: {query}')
|
||||
LOGGER.info(f"Running query: {query}")
|
||||
|
||||
rs = duckdb.sql(query)
|
||||
if return_type == 'pandas':
|
||||
if return_type == "pandas":
|
||||
return rs.df()
|
||||
elif return_type == 'arrow':
|
||||
elif return_type == "arrow":
|
||||
return rs.arrow()
|
||||
|
||||
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
||||
|
|
@ -228,18 +231,20 @@ class Explorer:
|
|||
result = exp.plot_sql_query(query)
|
||||
```
|
||||
"""
|
||||
result = self.sql_query(query, return_type='arrow')
|
||||
result = self.sql_query(query, return_type="arrow")
|
||||
if len(result) == 0:
|
||||
LOGGER.info('No results found.')
|
||||
LOGGER.info("No results found.")
|
||||
return None
|
||||
img = plot_query_result(result, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
|
||||
def get_similar(
|
||||
self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
return_type: str = "pandas",
|
||||
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
|
|
@ -259,21 +264,25 @@ class Explorer:
|
|||
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
|
||||
```
|
||||
"""
|
||||
assert return_type in ['pandas',
|
||||
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
|
||||
assert return_type in [
|
||||
"pandas",
|
||||
"arrow",
|
||||
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
||||
img = self._check_imgs_or_idxs(img, idx)
|
||||
similar = self.query(img, limit=limit)
|
||||
|
||||
if return_type == 'pandas':
|
||||
if return_type == "pandas":
|
||||
return similar.to_pandas()
|
||||
elif return_type == 'arrow':
|
||||
elif return_type == "arrow":
|
||||
return similar
|
||||
|
||||
def plot_similar(self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
labels: bool = True) -> Image.Image:
|
||||
def plot_similar(
|
||||
self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
labels: bool = True,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Plot the similar images. Accepts images or indexes.
|
||||
|
||||
|
|
@ -293,9 +302,9 @@ class Explorer:
|
|||
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
|
||||
```
|
||||
"""
|
||||
similar = self.get_similar(img, idx, limit, return_type='arrow')
|
||||
similar = self.get_similar(img, idx, limit, return_type="arrow")
|
||||
if len(similar) == 0:
|
||||
LOGGER.info('No results found.')
|
||||
LOGGER.info("No results found.")
|
||||
return None
|
||||
img = plot_query_result(similar, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
|
@ -323,34 +332,37 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
if self.table is None:
|
||||
raise ValueError('Table is not created. Please create the table first.')
|
||||
sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
|
||||
raise ValueError("Table is not created. Please create the table first.")
|
||||
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
|
||||
if sim_idx_table_name in self.connection.table_names() and not force:
|
||||
LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
|
||||
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
|
||||
return self.connection.open_table(sim_idx_table_name).to_pandas()
|
||||
|
||||
if top_k and not (1.0 >= top_k >= 0.0):
|
||||
raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
|
||||
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
|
||||
if max_dist < 0.0:
|
||||
raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
|
||||
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
|
||||
|
||||
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
|
||||
top_k = max(top_k, 1)
|
||||
features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
|
||||
im_files = features['im_file']
|
||||
embeddings = features['vector']
|
||||
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
|
||||
im_files = features["im_file"]
|
||||
embeddings = features["vector"]
|
||||
|
||||
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
|
||||
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
|
||||
|
||||
def _yield_sim_idx():
|
||||
"""Generates a dataframe with similarity indices and distances for images."""
|
||||
for i in tqdm(range(len(embeddings))):
|
||||
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
|
||||
yield [{
|
||||
'idx': i,
|
||||
'im_file': im_files[i],
|
||||
'count': len(sim_idx),
|
||||
'sim_im_files': sim_idx['im_file'].tolist()}]
|
||||
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
|
||||
yield [
|
||||
{
|
||||
"idx": i,
|
||||
"im_file": im_files[i],
|
||||
"count": len(sim_idx),
|
||||
"sim_im_files": sim_idx["im_file"].tolist(),
|
||||
}
|
||||
]
|
||||
|
||||
sim_table.add(_yield_sim_idx())
|
||||
self.sim_index = sim_table
|
||||
|
|
@ -381,7 +393,7 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
||||
sim_count = sim_idx['count'].tolist()
|
||||
sim_count = sim_idx["count"].tolist()
|
||||
sim_count = np.array(sim_count)
|
||||
|
||||
indices = np.arange(len(sim_count))
|
||||
|
|
@ -390,25 +402,26 @@ class Explorer:
|
|||
plt.bar(indices, sim_count)
|
||||
|
||||
# Customize the plot (optional)
|
||||
plt.xlabel('data idx')
|
||||
plt.ylabel('Count')
|
||||
plt.title('Similarity Count')
|
||||
plt.xlabel("data idx")
|
||||
plt.ylabel("Count")
|
||||
plt.title("Similarity Count")
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='png')
|
||||
plt.savefig(buffer, format="png")
|
||||
buffer.seek(0)
|
||||
|
||||
# Use Pillow to open the image from the buffer
|
||||
return Image.fromarray(np.array(Image.open(buffer)))
|
||||
|
||||
def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
|
||||
idx: Union[None, int, List[int]]) -> List[np.ndarray]:
|
||||
def _check_imgs_or_idxs(
|
||||
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
|
||||
) -> List[np.ndarray]:
|
||||
if img is None and idx is None:
|
||||
raise ValueError('Either img or idx must be provided.')
|
||||
raise ValueError("Either img or idx must be provided.")
|
||||
if img is not None and idx is not None:
|
||||
raise ValueError('Only one of img or idx must be provided.')
|
||||
raise ValueError("Only one of img or idx must be provided.")
|
||||
if idx is not None:
|
||||
idx = idx if isinstance(idx, list) else [idx]
|
||||
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
|
||||
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
|
||||
|
||||
return img if isinstance(img, list) else [img]
|
||||
|
||||
|
|
@ -433,7 +446,7 @@ class Explorer:
|
|||
try:
|
||||
df = self.sql_query(result)
|
||||
except Exception as e:
|
||||
LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
|
||||
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
||||
LOGGER.error(e)
|
||||
return None
|
||||
return df
|
||||
|
|
|
|||
|
|
@ -9,100 +9,114 @@ from ultralytics import Explorer
|
|||
from ultralytics.utils import ROOT, SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2'))
|
||||
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
|
||||
|
||||
import streamlit as st
|
||||
from streamlit_select import image_select
|
||||
|
||||
|
||||
def _get_explorer():
|
||||
"""Initializes and returns an instance of the Explorer class."""
|
||||
exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
|
||||
thread = Thread(target=exp.create_embeddings_table,
|
||||
kwargs={'force': st.session_state.get('force_recreate_embeddings')})
|
||||
exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
|
||||
thread = Thread(
|
||||
target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
|
||||
)
|
||||
thread.start()
|
||||
progress_bar = st.progress(0, text='Creating embeddings table...')
|
||||
progress_bar = st.progress(0, text="Creating embeddings table...")
|
||||
while exp.progress < 1:
|
||||
time.sleep(0.1)
|
||||
progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
|
||||
progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
|
||||
thread.join()
|
||||
st.session_state['explorer'] = exp
|
||||
st.session_state["explorer"] = exp
|
||||
progress_bar.empty()
|
||||
|
||||
|
||||
def init_explorer_form():
|
||||
"""Initializes an Explorer instance and creates embeddings table with progress tracking."""
|
||||
datasets = ROOT / 'cfg' / 'datasets'
|
||||
ds = [d.name for d in datasets.glob('*.yaml')]
|
||||
datasets = ROOT / "cfg" / "datasets"
|
||||
ds = [d.name for d in datasets.glob("*.yaml")]
|
||||
models = [
|
||||
'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
|
||||
'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
|
||||
'yolov8l-pose.pt', 'yolov8x-pose.pt']
|
||||
with st.form(key='explorer_init_form'):
|
||||
"yolov8n.pt",
|
||||
"yolov8s.pt",
|
||||
"yolov8m.pt",
|
||||
"yolov8l.pt",
|
||||
"yolov8x.pt",
|
||||
"yolov8n-seg.pt",
|
||||
"yolov8s-seg.pt",
|
||||
"yolov8m-seg.pt",
|
||||
"yolov8l-seg.pt",
|
||||
"yolov8x-seg.pt",
|
||||
"yolov8n-pose.pt",
|
||||
"yolov8s-pose.pt",
|
||||
"yolov8m-pose.pt",
|
||||
"yolov8l-pose.pt",
|
||||
"yolov8x-pose.pt",
|
||||
]
|
||||
with st.form(key="explorer_init_form"):
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
|
||||
st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
|
||||
with col2:
|
||||
st.selectbox('Select model', models, key='model')
|
||||
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
|
||||
st.selectbox("Select model", models, key="model")
|
||||
st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
|
||||
|
||||
st.form_submit_button('Explore', on_click=_get_explorer)
|
||||
st.form_submit_button("Explore", on_click=_get_explorer)
|
||||
|
||||
|
||||
def query_form():
|
||||
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
|
||||
with st.form('query_form'):
|
||||
with st.form("query_form"):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
st.text_input('Query',
|
||||
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
|
||||
label_visibility='collapsed',
|
||||
key='query')
|
||||
st.text_input(
|
||||
"Query",
|
||||
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
|
||||
label_visibility="collapsed",
|
||||
key="query",
|
||||
)
|
||||
with col2:
|
||||
st.form_submit_button('Query', on_click=run_sql_query)
|
||||
st.form_submit_button("Query", on_click=run_sql_query)
|
||||
|
||||
|
||||
def ai_query_form():
|
||||
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
|
||||
with st.form('ai_query_form'):
|
||||
with st.form("ai_query_form"):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
|
||||
st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
|
||||
with col2:
|
||||
st.form_submit_button('Ask AI', on_click=run_ai_query)
|
||||
st.form_submit_button("Ask AI", on_click=run_ai_query)
|
||||
|
||||
|
||||
def find_similar_imgs(imgs):
|
||||
"""Initializes a Streamlit form for AI-based image querying with custom input."""
|
||||
exp = st.session_state['explorer']
|
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
|
||||
paths = similar.to_pydict()['im_file']
|
||||
st.session_state['imgs'] = paths
|
||||
exp = st.session_state["explorer"]
|
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
|
||||
paths = similar.to_pydict()["im_file"]
|
||||
st.session_state["imgs"] = paths
|
||||
|
||||
|
||||
def similarity_form(selected_imgs):
|
||||
"""Initializes a form for AI-based image querying with custom input in Streamlit."""
|
||||
st.write('Similarity Search')
|
||||
with st.form('similarity_form'):
|
||||
st.write("Similarity Search")
|
||||
with st.form("similarity_form"):
|
||||
subcol1, subcol2 = st.columns([1, 1])
|
||||
with subcol1:
|
||||
st.number_input('limit',
|
||||
min_value=None,
|
||||
max_value=None,
|
||||
value=25,
|
||||
label_visibility='collapsed',
|
||||
key='limit')
|
||||
st.number_input(
|
||||
"limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
|
||||
)
|
||||
|
||||
with subcol2:
|
||||
disabled = not len(selected_imgs)
|
||||
st.write('Selected: ', len(selected_imgs))
|
||||
st.write("Selected: ", len(selected_imgs))
|
||||
st.form_submit_button(
|
||||
'Search',
|
||||
"Search",
|
||||
disabled=disabled,
|
||||
on_click=find_similar_imgs,
|
||||
args=(selected_imgs, ),
|
||||
args=(selected_imgs,),
|
||||
)
|
||||
if disabled:
|
||||
st.error('Select at least one image to search.')
|
||||
st.error("Select at least one image to search.")
|
||||
|
||||
|
||||
# def persist_reset_form():
|
||||
|
|
@ -117,100 +131,108 @@ def similarity_form(selected_imgs):
|
|||
|
||||
def run_sql_query():
|
||||
"""Executes an SQL query and returns the results."""
|
||||
st.session_state['error'] = None
|
||||
query = st.session_state.get('query')
|
||||
st.session_state["error"] = None
|
||||
query = st.session_state.get("query")
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state['explorer']
|
||||
res = exp.sql_query(query, return_type='arrow')
|
||||
st.session_state['imgs'] = res.to_pydict()['im_file']
|
||||
exp = st.session_state["explorer"]
|
||||
res = exp.sql_query(query, return_type="arrow")
|
||||
st.session_state["imgs"] = res.to_pydict()["im_file"]
|
||||
|
||||
|
||||
def run_ai_query():
|
||||
"""Execute SQL query and update session state with query results."""
|
||||
if not SETTINGS['openai_api_key']:
|
||||
if not SETTINGS["openai_api_key"]:
|
||||
st.session_state[
|
||||
'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
||||
"error"
|
||||
] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
||||
return
|
||||
st.session_state['error'] = None
|
||||
query = st.session_state.get('ai_query')
|
||||
st.session_state["error"] = None
|
||||
query = st.session_state.get("ai_query")
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state['explorer']
|
||||
exp = st.session_state["explorer"]
|
||||
res = exp.ask_ai(query)
|
||||
if not isinstance(res, pd.DataFrame) or res.empty:
|
||||
st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
|
||||
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
|
||||
return
|
||||
st.session_state['imgs'] = res['im_file'].to_list()
|
||||
st.session_state["imgs"] = res["im_file"].to_list()
|
||||
|
||||
|
||||
def reset_explorer():
|
||||
"""Resets the explorer to its initial state by clearing session variables."""
|
||||
st.session_state['explorer'] = None
|
||||
st.session_state['imgs'] = None
|
||||
st.session_state['error'] = None
|
||||
st.session_state["explorer"] = None
|
||||
st.session_state["imgs"] = None
|
||||
st.session_state["error"] = None
|
||||
|
||||
|
||||
def utralytics_explorer_docs_callback():
|
||||
"""Resets the explorer to its initial state by clearing session variables."""
|
||||
with st.container(border=True):
|
||||
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
|
||||
width=100)
|
||||
st.image(
|
||||
"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
|
||||
width=100,
|
||||
)
|
||||
st.markdown(
|
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
|
||||
unsafe_allow_html=True,
|
||||
help=None)
|
||||
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
|
||||
help=None,
|
||||
)
|
||||
st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
|
||||
|
||||
|
||||
def layout():
|
||||
"""Resets explorer session variables and provides documentation with a link to API docs."""
|
||||
st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
|
||||
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
|
||||
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
|
||||
|
||||
if st.session_state.get('explorer') is None:
|
||||
if st.session_state.get("explorer") is None:
|
||||
init_explorer_form()
|
||||
return
|
||||
|
||||
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
|
||||
exp = st.session_state.get('explorer')
|
||||
col1, col2 = st.columns([0.75, 0.25], gap='small')
|
||||
st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
|
||||
exp = st.session_state.get("explorer")
|
||||
col1, col2 = st.columns([0.75, 0.25], gap="small")
|
||||
imgs = []
|
||||
if st.session_state.get('error'):
|
||||
st.error(st.session_state['error'])
|
||||
if st.session_state.get("error"):
|
||||
st.error(st.session_state["error"])
|
||||
else:
|
||||
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
|
||||
imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
|
||||
total_imgs, selected_imgs = len(imgs), []
|
||||
with col1:
|
||||
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
|
||||
with subcol1:
|
||||
st.write('Max Images Displayed:')
|
||||
st.write("Max Images Displayed:")
|
||||
with subcol2:
|
||||
num = st.number_input('Max Images Displayed',
|
||||
min_value=0,
|
||||
max_value=total_imgs,
|
||||
value=min(500, total_imgs),
|
||||
key='num_imgs_displayed',
|
||||
label_visibility='collapsed')
|
||||
num = st.number_input(
|
||||
"Max Images Displayed",
|
||||
min_value=0,
|
||||
max_value=total_imgs,
|
||||
value=min(500, total_imgs),
|
||||
key="num_imgs_displayed",
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
with subcol3:
|
||||
st.write('Start Index:')
|
||||
st.write("Start Index:")
|
||||
with subcol4:
|
||||
start_idx = st.number_input('Start Index',
|
||||
min_value=0,
|
||||
max_value=total_imgs,
|
||||
value=0,
|
||||
key='start_index',
|
||||
label_visibility='collapsed')
|
||||
start_idx = st.number_input(
|
||||
"Start Index",
|
||||
min_value=0,
|
||||
max_value=total_imgs,
|
||||
value=0,
|
||||
key="start_index",
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
with subcol5:
|
||||
reset = st.button('Reset', use_container_width=False, key='reset')
|
||||
reset = st.button("Reset", use_container_width=False, key="reset")
|
||||
if reset:
|
||||
st.session_state['imgs'] = None
|
||||
st.session_state["imgs"] = None
|
||||
st.experimental_rerun()
|
||||
|
||||
query_form()
|
||||
ai_query_form()
|
||||
if total_imgs:
|
||||
imgs_displayed = imgs[start_idx:start_idx + num]
|
||||
imgs_displayed = imgs[start_idx : start_idx + num]
|
||||
selected_imgs = image_select(
|
||||
f'Total samples: {total_imgs}',
|
||||
f"Total samples: {total_imgs}",
|
||||
images=imgs_displayed,
|
||||
use_container_width=False,
|
||||
# indices=[i for i in range(num)] if select_all else None,
|
||||
|
|
@ -222,5 +244,5 @@ def layout():
|
|||
utralytics_explorer_docs_callback()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
layout()
|
||||
|
|
|
|||
|
|
@ -46,14 +46,13 @@ def get_sim_index_schema():
|
|||
|
||||
def sanitize_batch(batch, dataset_info):
|
||||
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
|
||||
batch['cls'] = batch['cls'].flatten().int().tolist()
|
||||
box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
|
||||
batch['bboxes'] = [box for box, _ in box_cls_pair]
|
||||
batch['cls'] = [cls for _, cls in box_cls_pair]
|
||||
batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
|
||||
batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
|
||||
batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
|
||||
|
||||
batch["cls"] = batch["cls"].flatten().int().tolist()
|
||||
box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
|
||||
batch["bboxes"] = [box for box, _ in box_cls_pair]
|
||||
batch["cls"] = [cls for _, cls in box_cls_pair]
|
||||
batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
|
||||
batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
|
||||
batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
|
||||
return batch
|
||||
|
||||
|
||||
|
|
@ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
|
|||
similar_set (list): Pyarrow or pandas object containing the similar data points
|
||||
plot_labels (bool): Whether to plot labels or not
|
||||
"""
|
||||
similar_set = similar_set.to_dict(
|
||||
orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
|
||||
similar_set = (
|
||||
similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
|
||||
)
|
||||
empty_masks = [[[]]]
|
||||
empty_boxes = [[]]
|
||||
images = similar_set.get('im_file', [])
|
||||
bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
|
||||
masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
|
||||
kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
|
||||
cls = similar_set.get('cls', [])
|
||||
images = similar_set.get("im_file", [])
|
||||
bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
|
||||
masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
|
||||
kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
|
||||
cls = similar_set.get("cls", [])
|
||||
|
||||
plot_size = 640
|
||||
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
|
||||
|
|
@ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
|
|||
batch_idx = np.concatenate(batch_idx, axis=0)
|
||||
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
|
||||
|
||||
return plot_images(imgs,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes=boxes,
|
||||
masks=masks,
|
||||
kpts=kpts,
|
||||
max_subplots=len(images),
|
||||
save=False,
|
||||
threaded=False)
|
||||
return plot_images(
|
||||
imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
|
||||
)
|
||||
|
||||
|
||||
def prompt_sql_query(query):
|
||||
"""Plots images with optional labels from a similar data set."""
|
||||
check_requirements('openai>=1.6.1')
|
||||
check_requirements("openai>=1.6.1")
|
||||
from openai import OpenAI
|
||||
|
||||
if not SETTINGS['openai_api_key']:
|
||||
logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
|
||||
openai_api_key = getpass.getpass('OpenAI API key: ')
|
||||
SETTINGS.update({'openai_api_key': openai_api_key})
|
||||
openai = OpenAI(api_key=SETTINGS['openai_api_key'])
|
||||
if not SETTINGS["openai_api_key"]:
|
||||
logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
|
||||
openai_api_key = getpass.getpass("OpenAI API key: ")
|
||||
SETTINGS.update({"openai_api_key": openai_api_key})
|
||||
openai = OpenAI(api_key=SETTINGS["openai_api_key"])
|
||||
|
||||
messages = [
|
||||
{
|
||||
'role':
|
||||
'system',
|
||||
'content':
|
||||
'''
|
||||
"role": "system",
|
||||
"content": """
|
||||
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
|
||||
the following schema and a user request. You only need to output the format with fixed selection
|
||||
statement that selects everything from "'table'", like `SELECT * from 'table'`
|
||||
|
|
@ -165,10 +157,10 @@ def prompt_sql_query(query):
|
|||
request - Get all data points that contain 2 or more people and at least one dog
|
||||
correct query-
|
||||
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
|
||||
'''},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f'{query}'}, ]
|
||||
""",
|
||||
},
|
||||
{"role": "user", "content": f"{query}"},
|
||||
]
|
||||
|
||||
response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
|
||||
response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
|
||||
return response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from ultralytics.utils.checks import check_requirements
|
|||
@dataclass
|
||||
class SourceTypes:
|
||||
"""Class to represent various types of input sources for predictions."""
|
||||
|
||||
webcam: bool = False
|
||||
screenshot: bool = False
|
||||
from_img: bool = False
|
||||
|
|
@ -59,12 +60,12 @@ class LoadStreams:
|
|||
__len__: Return the length of the sources object.
|
||||
"""
|
||||
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False):
|
||||
def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False):
|
||||
"""Initialize instance variables and check for consistent input stream shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.buffer = buffer # buffer input streams
|
||||
self.running = True # running flag for Thread
|
||||
self.mode = 'stream'
|
||||
self.mode = "stream"
|
||||
self.imgsz = imgsz
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
|
||||
|
|
@ -79,33 +80,36 @@ class LoadStreams:
|
|||
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
|
||||
for i, s in enumerate(sources): # index, source
|
||||
# Start thread to read frames from video stream
|
||||
st = f'{i + 1}/{n}: {s}... '
|
||||
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
|
||||
st = f"{i + 1}/{n}: {s}... "
|
||||
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
|
||||
s = get_best_youtube_url(s)
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
if s == 0 and (is_colab() or is_kaggle()):
|
||||
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
||||
"Try running 'source=0' in a local environment.")
|
||||
raise NotImplementedError(
|
||||
"'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
||||
"Try running 'source=0' in a local environment."
|
||||
)
|
||||
self.caps[i] = cv2.VideoCapture(s) # store video capture object
|
||||
if not self.caps[i].isOpened():
|
||||
raise ConnectionError(f'{st}Failed to open {s}')
|
||||
raise ConnectionError(f"{st}Failed to open {s}")
|
||||
w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
|
||||
self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
|
||||
'inf') # infinite stream fallback
|
||||
"inf"
|
||||
) # infinite stream fallback
|
||||
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
|
||||
|
||||
success, im = self.caps[i].read() # guarantee first frame
|
||||
if not success or im is None:
|
||||
raise ConnectionError(f'{st}Failed to read images from {s}')
|
||||
raise ConnectionError(f"{st}Failed to read images from {s}")
|
||||
self.imgs[i].append(im)
|
||||
self.shape[i] = im.shape
|
||||
self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
|
||||
LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
|
||||
LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
|
||||
self.threads[i].start()
|
||||
LOGGER.info('') # newline
|
||||
LOGGER.info("") # newline
|
||||
|
||||
# Check for common shapes
|
||||
self.bs = self.__len__()
|
||||
|
|
@ -121,7 +125,7 @@ class LoadStreams:
|
|||
success, im = cap.retrieve()
|
||||
if not success:
|
||||
im = np.zeros(self.shape[i], dtype=np.uint8)
|
||||
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
|
||||
LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
|
||||
cap.open(stream) # re-open stream if signal was lost
|
||||
if self.buffer:
|
||||
self.imgs[i].append(im)
|
||||
|
|
@ -140,7 +144,7 @@ class LoadStreams:
|
|||
try:
|
||||
cap.release() # release video capture
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ Could not release VideoCapture object: {e}')
|
||||
LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
def __iter__(self):
|
||||
|
|
@ -154,16 +158,15 @@ class LoadStreams:
|
|||
|
||||
images = []
|
||||
for i, x in enumerate(self.imgs):
|
||||
|
||||
# Wait until a frame is available in each buffer
|
||||
while not x:
|
||||
if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'): # q to quit
|
||||
if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit
|
||||
self.close()
|
||||
raise StopIteration
|
||||
time.sleep(1 / min(self.fps))
|
||||
x = self.imgs[i]
|
||||
if not x:
|
||||
LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}')
|
||||
LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")
|
||||
|
||||
# Get and remove the first frame from imgs buffer
|
||||
if self.buffer:
|
||||
|
|
@ -174,7 +177,7 @@ class LoadStreams:
|
|||
images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
|
||||
x.clear()
|
||||
|
||||
return self.sources, images, None, ''
|
||||
return self.sources, images, None, ""
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the sources object."""
|
||||
|
|
@ -209,7 +212,7 @@ class LoadScreenshots:
|
|||
|
||||
def __init__(self, source, imgsz=640):
|
||||
"""Source = [screen_number left top width height] (pixels)."""
|
||||
check_requirements('mss')
|
||||
check_requirements("mss")
|
||||
import mss # noqa
|
||||
|
||||
source, *params = source.split()
|
||||
|
|
@ -221,18 +224,18 @@ class LoadScreenshots:
|
|||
elif len(params) == 5:
|
||||
self.screen, left, top, width, height = (int(x) for x in params)
|
||||
self.imgsz = imgsz
|
||||
self.mode = 'stream'
|
||||
self.mode = "stream"
|
||||
self.frame = 0
|
||||
self.sct = mss.mss()
|
||||
self.bs = 1
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
||||
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
||||
self.width = width or monitor['width']
|
||||
self.height = height or monitor['height']
|
||||
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
||||
self.width = width or monitor["width"]
|
||||
self.height = height or monitor["height"]
|
||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator of the object."""
|
||||
|
|
@ -241,7 +244,7 @@ class LoadScreenshots:
|
|||
def __next__(self):
|
||||
"""mss screen capture: get raw pixels from the screen as np array."""
|
||||
im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
|
||||
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
||||
|
||||
self.frame += 1
|
||||
return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string
|
||||
|
|
@ -274,32 +277,32 @@ class LoadImages:
|
|||
def __init__(self, path, imgsz=640, vid_stride=1):
|
||||
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
|
||||
parent = None
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
|
||||
parent = Path(path).parent
|
||||
path = Path(path).read_text().splitlines() # list of sources
|
||||
files = []
|
||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||
a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
|
||||
if '*' in a:
|
||||
if "*" in a:
|
||||
files.extend(sorted(glob.glob(a, recursive=True))) # glob
|
||||
elif os.path.isdir(a):
|
||||
files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir
|
||||
files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir
|
||||
elif os.path.isfile(a):
|
||||
files.append(a) # files (absolute or relative to CWD)
|
||||
elif parent and (parent / p).is_file():
|
||||
files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
|
||||
else:
|
||||
raise FileNotFoundError(f'{p} does not exist')
|
||||
raise FileNotFoundError(f"{p} does not exist")
|
||||
|
||||
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
|
||||
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
|
||||
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
|
||||
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
|
||||
ni, nv = len(images), len(videos)
|
||||
|
||||
self.imgsz = imgsz
|
||||
self.files = images + videos
|
||||
self.nf = ni + nv # number of files
|
||||
self.video_flag = [False] * ni + [True] * nv
|
||||
self.mode = 'image'
|
||||
self.mode = "image"
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
self.bs = 1
|
||||
if any(videos):
|
||||
|
|
@ -307,8 +310,10 @@ class LoadImages:
|
|||
else:
|
||||
self.cap = None
|
||||
if self.nf == 0:
|
||||
raise FileNotFoundError(f'No images or videos found in {p}. '
|
||||
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
|
||||
raise FileNotFoundError(
|
||||
f"No images or videos found in {p}. "
|
||||
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object for VideoStream or ImageFolder."""
|
||||
|
|
@ -323,7 +328,7 @@ class LoadImages:
|
|||
|
||||
if self.video_flag[self.count]:
|
||||
# Read video
|
||||
self.mode = 'video'
|
||||
self.mode = "video"
|
||||
for _ in range(self.vid_stride):
|
||||
self.cap.grab()
|
||||
success, im0 = self.cap.retrieve()
|
||||
|
|
@ -338,15 +343,15 @@ class LoadImages:
|
|||
|
||||
self.frame += 1
|
||||
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
|
||||
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
||||
s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
|
||||
|
||||
else:
|
||||
# Read image
|
||||
self.count += 1
|
||||
im0 = cv2.imread(path) # BGR
|
||||
if im0 is None:
|
||||
raise FileNotFoundError(f'Image Not Found {path}')
|
||||
s = f'image {self.count}/{self.nf} {path}: '
|
||||
raise FileNotFoundError(f"Image Not Found {path}")
|
||||
s = f"image {self.count}/{self.nf} {path}: "
|
||||
|
||||
return [path], [im0], self.cap, s
|
||||
|
||||
|
|
@ -385,20 +390,20 @@ class LoadPilAndNumpy:
|
|||
"""Initialize PIL and Numpy Dataloader."""
|
||||
if not isinstance(im0, list):
|
||||
im0 = [im0]
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
|
||||
self.im0 = [self._single_check(im) for im in im0]
|
||||
self.imgsz = imgsz
|
||||
self.mode = 'image'
|
||||
self.mode = "image"
|
||||
# Generate fake paths
|
||||
self.bs = len(self.im0)
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
"""Validate and format an image to numpy array."""
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
|
||||
if isinstance(im, Image.Image):
|
||||
if im.mode != 'RGB':
|
||||
im = im.convert('RGB')
|
||||
if im.mode != "RGB":
|
||||
im = im.convert("RGB")
|
||||
im = np.asarray(im)[:, :, ::-1]
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
return im
|
||||
|
|
@ -412,7 +417,7 @@ class LoadPilAndNumpy:
|
|||
if self.count == 1: # loop only once as it's batch inference
|
||||
raise StopIteration
|
||||
self.count += 1
|
||||
return self.paths, self.im0, None, ''
|
||||
return self.paths, self.im0, None, ""
|
||||
|
||||
def __iter__(self):
|
||||
"""Enables iteration for class LoadPilAndNumpy."""
|
||||
|
|
@ -441,14 +446,16 @@ class LoadTensor:
|
|||
"""Initialize Tensor Dataloader."""
|
||||
self.im0 = self._single_check(im0)
|
||||
self.bs = self.im0.shape[0]
|
||||
self.mode = 'image'
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
self.mode = "image"
|
||||
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im, stride=32):
|
||||
"""Validate and format an image to torch.Tensor."""
|
||||
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
|
||||
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
|
||||
s = (
|
||||
f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
|
||||
f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
|
||||
)
|
||||
if len(im.shape) != 4:
|
||||
if len(im.shape) != 3:
|
||||
raise ValueError(s)
|
||||
|
|
@ -457,8 +464,10 @@ class LoadTensor:
|
|||
if im.shape[2] % stride or im.shape[3] % stride:
|
||||
raise ValueError(s)
|
||||
if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07
|
||||
LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
|
||||
f'Dividing input by 255.')
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
|
||||
f"Dividing input by 255."
|
||||
)
|
||||
im = im.float() / 255.0
|
||||
|
||||
return im
|
||||
|
|
@ -473,7 +482,7 @@ class LoadTensor:
|
|||
if self.count == 1:
|
||||
raise StopIteration
|
||||
self.count += 1
|
||||
return self.paths, self.im0, None, ''
|
||||
return self.paths, self.im0, None, ""
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the batch size."""
|
||||
|
|
@ -485,12 +494,14 @@ def autocast_list(source):
|
|||
files = []
|
||||
for im in source:
|
||||
if isinstance(im, (str, Path)): # filename or uri
|
||||
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
|
||||
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im))
|
||||
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
||||
files.append(im)
|
||||
else:
|
||||
raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
|
||||
f'See https://docs.ultralytics.com/modes/predict for supported source types.')
|
||||
raise TypeError(
|
||||
f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
|
||||
f"See https://docs.ultralytics.com/modes/predict for supported source types."
|
||||
)
|
||||
|
||||
return files
|
||||
|
||||
|
|
@ -513,16 +524,18 @@ def get_best_youtube_url(url, use_pafy=True):
|
|||
(str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
|
||||
"""
|
||||
if use_pafy:
|
||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||
check_requirements(("pafy", "youtube_dl==2020.12.2"))
|
||||
import pafy # noqa
|
||||
return pafy.new(url).getbestvideo(preftype='mp4').url
|
||||
|
||||
return pafy.new(url).getbestvideo(preftype="mp4").url
|
||||
else:
|
||||
check_requirements('yt-dlp')
|
||||
check_requirements("yt-dlp")
|
||||
import yt_dlp
|
||||
with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
|
||||
|
||||
with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
|
||||
info_dict = ydl.extract_info(url, download=False) # extract info
|
||||
for f in reversed(info_dict.get('formats', [])): # reversed because best is usually last
|
||||
for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last
|
||||
# Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
|
||||
good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080
|
||||
if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
|
||||
return f.get('url')
|
||||
good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
|
||||
if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
|
||||
return f.get("url")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from tqdm import tqdm
|
|||
from ultralytics.data.utils import exif_size, img2label_paths
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements('shapely')
|
||||
check_requirements("shapely")
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
|
||||
|
|
@ -54,7 +54,7 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
|
|||
return outputs
|
||||
|
||||
|
||||
def load_yolo_dota(data_root, split='train'):
|
||||
def load_yolo_dota(data_root, split="train"):
|
||||
"""
|
||||
Load DOTA dataset.
|
||||
|
||||
|
|
@ -72,10 +72,10 @@ def load_yolo_dota(data_root, split='train'):
|
|||
- train
|
||||
- val
|
||||
"""
|
||||
assert split in ['train', 'val']
|
||||
im_dir = os.path.join(data_root, f'images/{split}')
|
||||
assert split in ["train", "val"]
|
||||
im_dir = os.path.join(data_root, f"images/{split}")
|
||||
assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root."
|
||||
im_files = glob(os.path.join(data_root, f'images/{split}/*'))
|
||||
im_files = glob(os.path.join(data_root, f"images/{split}/*"))
|
||||
lb_files = img2label_paths(im_files)
|
||||
annos = []
|
||||
for im_file, lb_file in zip(im_files, lb_files):
|
||||
|
|
@ -100,7 +100,7 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
|
|||
h, w = im_size
|
||||
windows = []
|
||||
for crop_size, gap in zip(crop_sizes, gaps):
|
||||
assert crop_size > gap, f'invaild crop_size gap pair [{crop_size} {gap}]'
|
||||
assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
|
||||
step = crop_size - gap
|
||||
|
||||
xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
|
||||
|
|
@ -132,8 +132,8 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
|
|||
|
||||
def get_window_obj(anno, windows, iof_thr=0.7):
|
||||
"""Get objects for each window."""
|
||||
h, w = anno['ori_size']
|
||||
label = anno['label']
|
||||
h, w = anno["ori_size"]
|
||||
label = anno["label"]
|
||||
if len(label):
|
||||
label[:, 1::2] *= w
|
||||
label[:, 2::2] *= h
|
||||
|
|
@ -166,15 +166,15 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
|
|||
- train
|
||||
- val
|
||||
"""
|
||||
im = cv2.imread(anno['filepath'])
|
||||
name = Path(anno['filepath']).stem
|
||||
im = cv2.imread(anno["filepath"])
|
||||
name = Path(anno["filepath"]).stem
|
||||
for i, window in enumerate(windows):
|
||||
x_start, y_start, x_stop, y_stop = window.tolist()
|
||||
new_name = name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start)
|
||||
new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
|
||||
patch_im = im[y_start:y_stop, x_start:x_stop]
|
||||
ph, pw = patch_im.shape[:2]
|
||||
|
||||
cv2.imwrite(os.path.join(im_dir, f'{new_name}.jpg'), patch_im)
|
||||
cv2.imwrite(os.path.join(im_dir, f"{new_name}.jpg"), patch_im)
|
||||
label = window_objs[i]
|
||||
if len(label) == 0:
|
||||
continue
|
||||
|
|
@ -183,13 +183,13 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
|
|||
label[:, 1::2] /= pw
|
||||
label[:, 2::2] /= ph
|
||||
|
||||
with open(os.path.join(lb_dir, f'{new_name}.txt'), 'w') as f:
|
||||
with open(os.path.join(lb_dir, f"{new_name}.txt"), "w") as f:
|
||||
for lb in label:
|
||||
formatted_coords = ['{:.6g}'.format(coord) for coord in lb[1:]]
|
||||
formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
|
||||
f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
|
||||
|
||||
|
||||
def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024], gaps=[200]):
|
||||
def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]):
|
||||
"""
|
||||
Split both images and labels.
|
||||
|
||||
|
|
@ -207,14 +207,14 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024
|
|||
- labels
|
||||
- split
|
||||
"""
|
||||
im_dir = Path(save_dir) / 'images' / split
|
||||
im_dir = Path(save_dir) / "images" / split
|
||||
im_dir.mkdir(parents=True, exist_ok=True)
|
||||
lb_dir = Path(save_dir) / 'labels' / split
|
||||
lb_dir = Path(save_dir) / "labels" / split
|
||||
lb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
annos = load_yolo_dota(data_root, split=split)
|
||||
for anno in tqdm(annos, total=len(annos), desc=split):
|
||||
windows = get_windows(anno['ori_size'], crop_sizes, gaps)
|
||||
windows = get_windows(anno["ori_size"], crop_sizes, gaps)
|
||||
window_objs = get_window_obj(anno, windows)
|
||||
crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
|
||||
|
||||
|
|
@ -245,7 +245,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
|
|||
for r in rates:
|
||||
crop_sizes.append(int(crop_size / r))
|
||||
gaps.append(int(gap / r))
|
||||
for split in ['train', 'val']:
|
||||
for split in ["train", "val"]:
|
||||
split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
|
||||
|
||||
|
||||
|
|
@ -267,30 +267,30 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
|
|||
for r in rates:
|
||||
crop_sizes.append(int(crop_size / r))
|
||||
gaps.append(int(gap / r))
|
||||
save_dir = Path(save_dir) / 'images' / 'test'
|
||||
save_dir = Path(save_dir) / "images" / "test"
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
im_dir = Path(os.path.join(data_root, 'images/test'))
|
||||
im_dir = Path(os.path.join(data_root, "images/test"))
|
||||
assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
|
||||
im_files = glob(str(im_dir / '*'))
|
||||
for im_file in tqdm(im_files, total=len(im_files), desc='test'):
|
||||
im_files = glob(str(im_dir / "*"))
|
||||
for im_file in tqdm(im_files, total=len(im_files), desc="test"):
|
||||
w, h = exif_size(Image.open(im_file))
|
||||
windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
|
||||
im = cv2.imread(im_file)
|
||||
name = Path(im_file).stem
|
||||
for window in windows:
|
||||
x_start, y_start, x_stop, y_stop = window.tolist()
|
||||
new_name = (name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start))
|
||||
new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
|
||||
patch_im = im[y_start:y_stop, x_start:x_stop]
|
||||
cv2.imwrite(os.path.join(str(save_dir), f'{new_name}.jpg'), patch_im)
|
||||
cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
split_trainval(
|
||||
data_root='DOTAv2',
|
||||
save_dir='DOTAv2-split',
|
||||
data_root="DOTAv2",
|
||||
save_dir="DOTAv2-split",
|
||||
)
|
||||
split_test(
|
||||
data_root='DOTAv2',
|
||||
save_dir='DOTAv2-split',
|
||||
data_root="DOTAv2",
|
||||
save_dir="DOTAv2-split",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,36 +17,47 @@ import numpy as np
|
|||
from PIL import Image, ImageOps
|
||||
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr,
|
||||
emojis, yaml_load, yaml_save)
|
||||
from ultralytics.utils import (
|
||||
DATASETS_DIR,
|
||||
LOGGER,
|
||||
NUM_THREADS,
|
||||
ROOT,
|
||||
SETTINGS_YAML,
|
||||
TQDM,
|
||||
clean_url,
|
||||
colorstr,
|
||||
emojis,
|
||||
yaml_load,
|
||||
yaml_save,
|
||||
)
|
||||
from ultralytics.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
|
||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
|
||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||
HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
|
||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # image suffixes
|
||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm" # video suffixes
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
|
||||
|
||||
def img2label_paths(img_paths):
|
||||
"""Define label paths as a function of image paths."""
|
||||
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
|
||||
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
|
||||
|
||||
|
||||
def get_hash(paths):
|
||||
"""Returns a single hash value of a list of paths (files or dirs)."""
|
||||
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
|
||||
h = hashlib.sha256(str(size).encode()) # hash sizes
|
||||
h.update(''.join(paths).encode()) # hash paths
|
||||
h.update("".join(paths).encode()) # hash paths
|
||||
return h.hexdigest() # return hash
|
||||
|
||||
|
||||
def exif_size(img: Image.Image):
|
||||
"""Returns exif-corrected PIL size."""
|
||||
s = img.size # (width, height)
|
||||
if img.format == 'JPEG': # only support JPEG images
|
||||
if img.format == "JPEG": # only support JPEG images
|
||||
with contextlib.suppress(Exception):
|
||||
exif = img.getexif()
|
||||
if exif:
|
||||
|
|
@ -60,24 +71,24 @@ def verify_image(args):
|
|||
"""Verify one image."""
|
||||
(im_file, cls), prefix = args
|
||||
# Number (found, corrupt), message
|
||||
nf, nc, msg = 0, 0, ''
|
||||
nf, nc, msg = 0, 0, ""
|
||||
try:
|
||||
im = Image.open(im_file)
|
||||
im.verify() # PIL verify
|
||||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
|
||||
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
|
||||
if im.format.lower() in ('jpg', 'jpeg'):
|
||||
with open(im_file, 'rb') as f:
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b'\xff\xd9': # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
|
||||
nf = 1
|
||||
except Exception as e:
|
||||
nc = 1
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
|
||||
return (im_file, cls), nf, nc, msg
|
||||
|
||||
|
||||
|
|
@ -85,21 +96,21 @@ def verify_image_label(args):
|
|||
"""Verify one image-label pair."""
|
||||
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
|
||||
# Number (missing, found, empty, corrupt), message, segments, keypoints
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
|
||||
try:
|
||||
# Verify images
|
||||
im = Image.open(im_file)
|
||||
im.verify() # PIL verify
|
||||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
|
||||
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
|
||||
if im.format.lower() in ('jpg', 'jpeg'):
|
||||
with open(im_file, 'rb') as f:
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b'\xff\xd9': # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
|
||||
|
||||
# Verify labels
|
||||
if os.path.isfile(lb_file):
|
||||
|
|
@ -114,25 +125,26 @@ def verify_image_label(args):
|
|||
nl = len(lb)
|
||||
if nl:
|
||||
if keypoint:
|
||||
assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
|
||||
assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
|
||||
points = lb[:, 5:].reshape(-1, ndim)[:, :2]
|
||||
else:
|
||||
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
|
||||
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
|
||||
points = lb[:, 1:]
|
||||
assert points.max() <= 1, f'non-normalized or out of bounds coordinates {points[points > 1]}'
|
||||
assert lb.min() >= 0, f'negative label values {lb[lb < 0]}'
|
||||
assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
|
||||
assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
|
||||
|
||||
# All labels
|
||||
max_cls = lb[:, 0].max() # max label count
|
||||
assert max_cls <= num_cls, \
|
||||
f'Label class {int(max_cls)} exceeds dataset class count {num_cls}. ' \
|
||||
f'Possible class labels are 0-{num_cls - 1}'
|
||||
assert max_cls <= num_cls, (
|
||||
f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
|
||||
f"Possible class labels are 0-{num_cls - 1}"
|
||||
)
|
||||
_, i = np.unique(lb, axis=0, return_index=True)
|
||||
if len(i) < nl: # duplicate row check
|
||||
lb = lb[i] # remove duplicates
|
||||
if segments:
|
||||
segments = [segments[x] for x in i]
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
|
||||
else:
|
||||
ne = 1 # label empty
|
||||
lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
|
||||
|
|
@ -148,7 +160,7 @@ def verify_image_label(args):
|
|||
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
||||
except Exception as e:
|
||||
nc = 1
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
|
||||
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
||||
|
||||
|
||||
|
|
@ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
|
|||
|
||||
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
||||
"""Return a (640, 640) overlap mask."""
|
||||
masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
|
||||
dtype=np.int32 if len(segments) > 255 else np.uint8)
|
||||
masks = np.zeros(
|
||||
(imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
|
||||
dtype=np.int32 if len(segments) > 255 else np.uint8,
|
||||
)
|
||||
areas = []
|
||||
ms = []
|
||||
for si in range(len(segments)):
|
||||
|
|
@ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path:
|
|||
Returns:
|
||||
(Path): The path of the found YAML file.
|
||||
"""
|
||||
files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml')) # try root level first and then recursive
|
||||
files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive
|
||||
assert files, f"No YAML file found in '{path.resolve()}'"
|
||||
if len(files) > 1:
|
||||
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
|
||||
|
|
@ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True):
|
|||
file = check_file(dataset)
|
||||
|
||||
# Download (optional)
|
||||
extract_dir = ''
|
||||
extract_dir = ""
|
||||
if zipfile.is_zipfile(file) or is_tarfile(file):
|
||||
new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
file = find_dataset_yaml(DATASETS_DIR / new_dir)
|
||||
|
|
@ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True):
|
|||
data = yaml_load(file, append_filename=True) # dictionary
|
||||
|
||||
# Checks
|
||||
for k in 'train', 'val':
|
||||
for k in "train", "val":
|
||||
if k not in data:
|
||||
if k != 'val' or 'validation' not in data:
|
||||
if k != "val" or "validation" not in data:
|
||||
raise SyntaxError(
|
||||
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
|
||||
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
|
||||
)
|
||||
LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
|
||||
data['val'] = data.pop('validation') # replace 'validation' key with 'val' key
|
||||
if 'names' not in data and 'nc' not in data:
|
||||
data["val"] = data.pop("validation") # replace 'validation' key with 'val' key
|
||||
if "names" not in data and "nc" not in data:
|
||||
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
|
||||
if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
|
||||
if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
|
||||
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
|
||||
if 'names' not in data:
|
||||
data['names'] = [f'class_{i}' for i in range(data['nc'])]
|
||||
if "names" not in data:
|
||||
data["names"] = [f"class_{i}" for i in range(data["nc"])]
|
||||
else:
|
||||
data['nc'] = len(data['names'])
|
||||
data["nc"] = len(data["names"])
|
||||
|
||||
data['names'] = check_class_names(data['names'])
|
||||
data["names"] = check_class_names(data["names"])
|
||||
|
||||
# Resolve paths
|
||||
path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
|
||||
path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root
|
||||
if not path.is_absolute():
|
||||
path = (DATASETS_DIR / path).resolve()
|
||||
|
||||
# Set paths
|
||||
data['path'] = path # download scripts
|
||||
for k in 'train', 'val', 'test':
|
||||
data["path"] = path # download scripts
|
||||
for k in "train", "val", "test":
|
||||
if data.get(k): # prepend path
|
||||
if isinstance(data[k], str):
|
||||
x = (path / data[k]).resolve()
|
||||
if not x.exists() and data[k].startswith('../'):
|
||||
if not x.exists() and data[k].startswith("../"):
|
||||
x = (path / data[k][3:]).resolve()
|
||||
data[k] = str(x)
|
||||
else:
|
||||
data[k] = [str((path / x).resolve()) for x in data[k]]
|
||||
|
||||
# Parse YAML
|
||||
val, s = (data.get(x) for x in ('val', 'download'))
|
||||
val, s = (data.get(x) for x in ("val", "download"))
|
||||
if val:
|
||||
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
||||
if not all(x.exists() for x in val):
|
||||
|
|
@ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True):
|
|||
raise FileNotFoundError(m)
|
||||
t = time.time()
|
||||
r = None # success
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
if s.startswith("http") and s.endswith(".zip"): # URL
|
||||
safe_download(url=s, dir=DATASETS_DIR, delete=True)
|
||||
elif s.startswith('bash '): # bash script
|
||||
LOGGER.info(f'Running {s} ...')
|
||||
elif s.startswith("bash "): # bash script
|
||||
LOGGER.info(f"Running {s} ...")
|
||||
r = os.system(s)
|
||||
else: # python script
|
||||
exec(s, {'yaml': data})
|
||||
dt = f'({round(time.time() - t, 1)}s)'
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
|
||||
LOGGER.info(f'Dataset download {s}\n')
|
||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
||||
exec(s, {"yaml": data})
|
||||
dt = f"({round(time.time() - t, 1)}s)"
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
LOGGER.info(f"Dataset download {s}\n")
|
||||
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
|
||||
|
||||
return data # dictionary
|
||||
|
||||
|
||||
def check_cls_dataset(dataset, split=''):
|
||||
def check_cls_dataset(dataset, split=""):
|
||||
"""
|
||||
Checks a classification dataset such as Imagenet.
|
||||
|
||||
|
|
@ -348,54 +363,59 @@ def check_cls_dataset(dataset, split=''):
|
|||
"""
|
||||
|
||||
# Download (optional if dataset=https://file.zip is passed directly)
|
||||
if str(dataset).startswith(('http:/', 'https:/')):
|
||||
if str(dataset).startswith(("http:/", "https:/")):
|
||||
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
|
||||
dataset = Path(dataset)
|
||||
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
|
||||
if not data_dir.is_dir():
|
||||
LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||
LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
|
||||
t = time.time()
|
||||
if str(dataset) == 'imagenet':
|
||||
if str(dataset) == "imagenet":
|
||||
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
else:
|
||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||
url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
|
||||
download(url, dir=data_dir.parent)
|
||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||
LOGGER.info(s)
|
||||
train_set = data_dir / 'train'
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \
|
||||
(data_dir / 'validation').exists() else None # data/test or data/val
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
||||
if split == 'val' and not val_set:
|
||||
train_set = data_dir / "train"
|
||||
val_set = (
|
||||
data_dir / "val"
|
||||
if (data_dir / "val").exists()
|
||||
else data_dir / "validation"
|
||||
if (data_dir / "validation").exists()
|
||||
else None
|
||||
) # data/test or data/val
|
||||
test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test
|
||||
if split == "val" and not val_set:
|
||||
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
||||
elif split == 'test' and not test_set:
|
||||
elif split == "test" and not test_set:
|
||||
LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
|
||||
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||
nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list
|
||||
names = dict(enumerate(sorted(names)))
|
||||
|
||||
# Print to console
|
||||
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
|
||||
for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
|
||||
prefix = f'{colorstr(f"{k}:")} {v}...'
|
||||
if v is None:
|
||||
LOGGER.info(prefix)
|
||||
else:
|
||||
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
|
||||
files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
|
||||
nf = len(files) # number of files
|
||||
nd = len({file.parent for file in files}) # number of directories
|
||||
if nf == 0:
|
||||
if k == 'train':
|
||||
if k == "train":
|
||||
raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
|
||||
else:
|
||||
LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
|
||||
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
|
||||
elif nd != nc:
|
||||
LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
|
||||
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
|
||||
else:
|
||||
LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
|
||||
LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")
|
||||
|
||||
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
|
||||
return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
|
||||
|
||||
|
||||
class HUBDatasetStats:
|
||||
|
|
@ -423,42 +443,43 @@ class HUBDatasetStats:
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self, path='coco8.yaml', task='detect', autodownload=False):
|
||||
def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
|
||||
"""Initialize class."""
|
||||
path = Path(path).resolve()
|
||||
LOGGER.info(f'Starting HUB dataset checks for {path}....')
|
||||
LOGGER.info(f"Starting HUB dataset checks for {path}....")
|
||||
|
||||
self.task = task # detect, segment, pose, classify
|
||||
if self.task == 'classify':
|
||||
if self.task == "classify":
|
||||
unzip_dir = unzip_file(path)
|
||||
data = check_cls_dataset(unzip_dir)
|
||||
data['path'] = unzip_dir
|
||||
data["path"] = unzip_dir
|
||||
else: # detect, segment, pose
|
||||
_, data_dir, yaml_path = self._unzip(Path(path))
|
||||
try:
|
||||
# Load YAML with checks
|
||||
data = yaml_load(yaml_path)
|
||||
data['path'] = '' # strip path since YAML should be in dataset root for all HUB datasets
|
||||
data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets
|
||||
yaml_save(yaml_path, data)
|
||||
data = check_det_dataset(yaml_path, autodownload) # dict
|
||||
data['path'] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
|
||||
data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
|
||||
except Exception as e:
|
||||
raise Exception('error/HUB/dataset_stats/init') from e
|
||||
raise Exception("error/HUB/dataset_stats/init") from e
|
||||
|
||||
self.hub_dir = Path(f'{data["path"]}-hub')
|
||||
self.im_dir = self.hub_dir / 'images'
|
||||
self.im_dir = self.hub_dir / "images"
|
||||
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
||||
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
|
||||
self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary
|
||||
self.data = data
|
||||
|
||||
@staticmethod
|
||||
def _unzip(path):
|
||||
"""Unzip data.zip."""
|
||||
if not str(path).endswith('.zip'): # path is data.yaml
|
||||
if not str(path).endswith(".zip"): # path is data.yaml
|
||||
return False, None, path
|
||||
unzip_dir = unzip_file(path, path=path.parent)
|
||||
assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
|
||||
f'path/to/abc.zip MUST unzip to path/to/abc/'
|
||||
assert unzip_dir.is_dir(), (
|
||||
f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
|
||||
)
|
||||
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
|
||||
|
||||
def _hub_ops(self, f):
|
||||
|
|
@ -470,31 +491,31 @@ class HUBDatasetStats:
|
|||
|
||||
def _round(labels):
|
||||
"""Update labels to integer class and 4 decimal place floats."""
|
||||
if self.task == 'detect':
|
||||
coordinates = labels['bboxes']
|
||||
elif self.task == 'segment':
|
||||
coordinates = [x.flatten() for x in labels['segments']]
|
||||
elif self.task == 'pose':
|
||||
n = labels['keypoints'].shape[0]
|
||||
coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
|
||||
if self.task == "detect":
|
||||
coordinates = labels["bboxes"]
|
||||
elif self.task == "segment":
|
||||
coordinates = [x.flatten() for x in labels["segments"]]
|
||||
elif self.task == "pose":
|
||||
n = labels["keypoints"].shape[0]
|
||||
coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
|
||||
else:
|
||||
raise ValueError('Undefined dataset task.')
|
||||
zipped = zip(labels['cls'], coordinates)
|
||||
raise ValueError("Undefined dataset task.")
|
||||
zipped = zip(labels["cls"], coordinates)
|
||||
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
|
||||
|
||||
for split in 'train', 'val', 'test':
|
||||
for split in "train", "val", "test":
|
||||
self.stats[split] = None # predefine
|
||||
path = self.data.get(split)
|
||||
|
||||
# Check split
|
||||
if path is None: # no split
|
||||
continue
|
||||
files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
|
||||
files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
|
||||
if not files: # no images
|
||||
continue
|
||||
|
||||
# Get dataset statistics
|
||||
if self.task == 'classify':
|
||||
if self.task == "classify":
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
dataset = ImageFolder(self.data[split])
|
||||
|
|
@ -504,38 +525,35 @@ class HUBDatasetStats:
|
|||
x[im[1]] += 1
|
||||
|
||||
self.stats[split] = {
|
||||
'instance_stats': {
|
||||
'total': len(dataset),
|
||||
'per_class': x.tolist()},
|
||||
'image_stats': {
|
||||
'total': len(dataset),
|
||||
'unlabelled': 0,
|
||||
'per_class': x.tolist()},
|
||||
'labels': [{
|
||||
Path(k).name: v} for k, v in dataset.imgs]}
|
||||
"instance_stats": {"total": len(dataset), "per_class": x.tolist()},
|
||||
"image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
|
||||
"labels": [{Path(k).name: v} for k, v in dataset.imgs],
|
||||
}
|
||||
else:
|
||||
from ultralytics.data import YOLODataset
|
||||
|
||||
dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
|
||||
x = np.array([
|
||||
np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
|
||||
for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
|
||||
x = np.array(
|
||||
[
|
||||
np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
|
||||
for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
|
||||
]
|
||||
) # shape(128x80)
|
||||
self.stats[split] = {
|
||||
'instance_stats': {
|
||||
'total': int(x.sum()),
|
||||
'per_class': x.sum(0).tolist()},
|
||||
'image_stats': {
|
||||
'total': len(dataset),
|
||||
'unlabelled': int(np.all(x == 0, 1).sum()),
|
||||
'per_class': (x > 0).sum(0).tolist()},
|
||||
'labels': [{
|
||||
Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
|
||||
"instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
|
||||
"image_stats": {
|
||||
"total": len(dataset),
|
||||
"unlabelled": int(np.all(x == 0, 1).sum()),
|
||||
"per_class": (x > 0).sum(0).tolist(),
|
||||
},
|
||||
"labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
|
||||
}
|
||||
|
||||
# Save, print and return
|
||||
if save:
|
||||
stats_path = self.hub_dir / 'stats.json'
|
||||
LOGGER.info(f'Saving {stats_path.resolve()}...')
|
||||
with open(stats_path, 'w') as f:
|
||||
stats_path = self.hub_dir / "stats.json"
|
||||
LOGGER.info(f"Saving {stats_path.resolve()}...")
|
||||
with open(stats_path, "w") as f:
|
||||
json.dump(self.stats, f) # save stats.json
|
||||
if verbose:
|
||||
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
|
||||
|
|
@ -545,14 +563,14 @@ class HUBDatasetStats:
|
|||
"""Compress images for Ultralytics HUB."""
|
||||
from ultralytics.data import YOLODataset # ClassificationDataset
|
||||
|
||||
for split in 'train', 'val', 'test':
|
||||
for split in "train", "val", "test":
|
||||
if self.data.get(split) is None:
|
||||
continue
|
||||
dataset = YOLODataset(img_path=self.data[split], data=self.data)
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
|
||||
for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
|
||||
pass
|
||||
LOGGER.info(f'Done. All images saved to {self.im_dir}')
|
||||
LOGGER.info(f"Done. All images saved to {self.im_dir}")
|
||||
return self.im_dir
|
||||
|
||||
|
||||
|
|
@ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
|||
r = max_dim / max(im.height, im.width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = im.resize((int(im.width * r), int(im.height * r)))
|
||||
im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
|
||||
im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save
|
||||
except Exception as e: # use OpenCV
|
||||
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
|
||||
LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
|
||||
im = cv2.imread(f)
|
||||
im_height, im_width = im.shape[:2]
|
||||
r = max_dim / max(im_height, im_width) # ratio
|
||||
|
|
@ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
|||
cv2.imwrite(str(f_new or f), im)
|
||||
|
||||
|
||||
def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
|
||||
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
|
||||
"""
|
||||
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
|
||||
|
||||
|
|
@ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot
|
|||
"""
|
||||
|
||||
path = Path(path) # images dir
|
||||
files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
|
||||
files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only
|
||||
n = len(files) # number of files
|
||||
random.seed(0) # for reproducibility
|
||||
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
||||
|
||||
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
||||
txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files
|
||||
for x in txt:
|
||||
if (path.parent / x).exists():
|
||||
(path.parent / x).unlink() # remove existing
|
||||
|
||||
LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
||||
LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
|
||||
for i, img in TQDM(zip(indices, files), total=n):
|
||||
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
||||
with open(path.parent / txt[i], 'a') as f:
|
||||
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
|
||||
with open(path.parent / txt[i], "a") as f:
|
||||
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue