Update .pre-commit-config.yaml (#1026)
This commit is contained in:
parent
9047d737f4
commit
edd3ff1669
76 changed files with 928 additions and 935 deletions
|
|
@ -28,7 +28,7 @@ class YOLODataset(BaseDataset):
|
|||
cache=False,
|
||||
augment=True,
|
||||
hyp=None,
|
||||
prefix="",
|
||||
prefix='',
|
||||
rect=False,
|
||||
batch_size=None,
|
||||
stride=32,
|
||||
|
|
@ -40,14 +40,14 @@ class YOLODataset(BaseDataset):
|
|||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
self.names = names
|
||||
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
||||
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
||||
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
|
||||
|
||||
def cache_labels(self, path=Path("./labels.cache")):
|
||||
def cache_labels(self, path=Path('./labels.cache')):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
x = {"labels": []}
|
||||
x = {'labels': []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||
total = len(self.im_files)
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image_label,
|
||||
|
|
@ -60,7 +60,7 @@ class YOLODataset(BaseDataset):
|
|||
ne += ne_f
|
||||
nc += nc_f
|
||||
if im_file:
|
||||
x["labels"].append(
|
||||
x['labels'].append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape,
|
||||
|
|
@ -69,68 +69,68 @@ class YOLODataset(BaseDataset):
|
|||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format="xywh"))
|
||||
bbox_format='xywh'))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
pbar.close()
|
||||
|
||||
if msgs:
|
||||
LOGGER.info("\n".join(msgs))
|
||||
LOGGER.info('\n'.join(msgs))
|
||||
if nf == 0:
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
|
||||
x["hash"] = get_hash(self.label_files + self.im_files)
|
||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||
x["msgs"] = msgs # warnings
|
||||
x["version"] = self.cache_version # cache version
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
|
||||
x['hash'] = get_hash(self.label_files + self.im_files)
|
||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||
x['msgs'] = msgs # warnings
|
||||
x['version'] = self.cache_version # cache version
|
||||
if is_dir_writeable(path.parent):
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
||||
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||
LOGGER.info(f'{self.prefix}New cache created: {path}')
|
||||
else:
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
self.label_files = img2label_paths(self.im_files)
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
|
||||
try:
|
||||
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
||||
assert cache["version"] == self.cache_version # matches current version
|
||||
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
assert cache['version'] == self.cache_version # matches current version
|
||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
if cache['msgs']:
|
||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||
if nf == 0: # number of labels found
|
||||
raise FileNotFoundError(f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}")
|
||||
raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
|
||||
|
||||
# Read cache
|
||||
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
||||
labels = cache["labels"]
|
||||
self.im_files = [lb["im_file"] for lb in labels] # update im_files
|
||||
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
||||
labels = cache['labels']
|
||||
self.im_files = [lb['im_file'] for lb in labels] # update im_files
|
||||
|
||||
# Check if the dataset is all boxes or all segments
|
||||
len_cls = sum(len(lb["cls"]) for lb in labels)
|
||||
len_boxes = sum(len(lb["bboxes"]) for lb in labels)
|
||||
len_segments = sum(len(lb["segments"]) for lb in labels)
|
||||
len_cls = sum(len(lb['cls']) for lb in labels)
|
||||
len_boxes = sum(len(lb['bboxes']) for lb in labels)
|
||||
len_segments = sum(len(lb['segments']) for lb in labels)
|
||||
if len_segments and len_boxes != len_segments:
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
|
||||
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
|
||||
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
|
||||
f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
|
||||
f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
|
||||
'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
|
||||
for lb in labels:
|
||||
lb["segments"] = []
|
||||
lb['segments'] = []
|
||||
if len_cls == 0:
|
||||
raise ValueError(f"All labels empty in {cache_path}, can not start training without labels. {HELP_URL}")
|
||||
raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
|
||||
return labels
|
||||
|
||||
# TODO: use hyp config to set all these augmentations
|
||||
|
|
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
|
|||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
||||
transforms.append(
|
||||
Format(bbox_format="xywh",
|
||||
Format(bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
|
|
@ -161,12 +161,12 @@ class YOLODataset(BaseDataset):
|
|||
"""custom your label format here"""
|
||||
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
||||
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||
bboxes = label.pop("bboxes")
|
||||
segments = label.pop("segments")
|
||||
keypoints = label.pop("keypoints", None)
|
||||
bbox_format = label.pop("bbox_format")
|
||||
normalized = label.pop("normalized")
|
||||
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
bboxes = label.pop('bboxes')
|
||||
segments = label.pop('segments')
|
||||
keypoints = label.pop('keypoints', None)
|
||||
bbox_format = label.pop('bbox_format')
|
||||
normalized = label.pop('normalized')
|
||||
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -176,15 +176,15 @@ class YOLODataset(BaseDataset):
|
|||
values = list(zip(*[list(b.values()) for b in batch]))
|
||||
for i, k in enumerate(keys):
|
||||
value = values[i]
|
||||
if k == "img":
|
||||
if k == 'img':
|
||||
value = torch.stack(value, 0)
|
||||
if k in ["masks", "keypoints", "bboxes", "cls"]:
|
||||
if k in ['masks', 'keypoints', 'bboxes', 'cls']:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = value
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
for i in range(len(new_batch["batch_idx"])):
|
||||
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
||||
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
|
||||
new_batch['batch_idx'] = list(new_batch['batch_idx'])
|
||||
for i in range(len(new_batch['batch_idx'])):
|
||||
new_batch['batch_idx'][i] += i # add target image index for build_targets()
|
||||
new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
|
||||
return new_batch
|
||||
|
||||
|
||||
|
|
@ -202,9 +202,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
super().__init__(root=root)
|
||||
self.torch_transforms = classify_transforms(imgsz)
|
||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||
self.cache_ram = cache is True or cache == "ram"
|
||||
self.cache_disk = cache == "disk"
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
||||
self.cache_ram = cache is True or cache == 'ram'
|
||||
self.cache_disk = cache == 'disk'
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
|
||||
def __getitem__(self, i):
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
|
|
@ -217,7 +217,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if self.album_transforms:
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
||||
else:
|
||||
sample = self.torch_transforms(im)
|
||||
return {'img': sample, 'cls': j}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue