Fix HUBDatasetStats for no-label edge cases (#4583)

This commit is contained in:
Glenn Jocher 2023-08-26 19:38:02 +02:00 committed by GitHub
parent 2db35afad5
commit f755ba88c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 62 additions and 41 deletions

View file

@ -115,7 +115,7 @@ class BaseDataset(Dataset):
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f'{self.prefix}No images found'
assert im_files, f'{self.prefix}No images found in {img_path}'
except Exception as e:
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
if self.fraction < 1:

View file

@ -110,13 +110,12 @@ class YOLODataset(BaseDataset):
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display results
if cache['msgs']:
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
if nf == 0: # number of labels found
raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
# Read cache
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
labels = cache['labels']
assert len(labels), f'No valid labels found, please check your dataset. {HELP_URL}'
if not labels:
LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}')
self.im_files = [lb['im_file'] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
@ -130,10 +129,9 @@ class YOLODataset(BaseDataset):
for lb in labels:
lb['segments'] = []
if len_cls == 0:
raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}')
return labels
# TODO: use hyp config to set all these augmentations
def build_transforms(self, hyp=None):
"""Builds and appends transforms to the list."""
if self.augment:

View file

@ -447,10 +447,17 @@ class HUBDatasetStats:
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
self.stats[split] = None # i.e. no test set
self.stats[split] = None # predefine
path = self.data.get(split)
# Check split
if path is None: # no split
continue
files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
if not files: # no images
continue
# Get dataset statistics
dataset = YOLODataset(img_path=self.data[split],
data=self.data,
use_segments=self.task == 'segment',