New dataset fraction=1.0 argument (#2860)

This commit is contained in:
Glenn Jocher 2023-05-28 02:13:46 +02:00 committed by GitHub
parent 61fa5efe6d
commit 0bdd4ad379
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 16 additions and 5 deletions

View file

@ -36,6 +36,7 @@ class BaseDataset(Dataset):
pad (float, optional): Padding. Defaults to 0.0.
single_cls (bool, optional): If True, single class training is used. Defaults to False.
classes (list): List of included classes. Default is None.
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
Attributes:
im_files (list): List of image file paths.
@ -58,13 +59,15 @@ class BaseDataset(Dataset):
stride=32,
pad=0.5,
single_cls=False,
classes=None):
classes=None,
fraction=1.0):
super().__init__()
self.img_path = img_path
self.imgsz = imgsz
self.augment = augment
self.single_cls = single_cls
self.prefix = prefix
self.fraction = fraction
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
self.update_labels(include_class=classes) # single_cls and include_class
@ -114,6 +117,8 @@ class BaseDataset(Dataset):
assert im_files, f'{self.prefix}No images found'
except Exception as e:
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
if self.fraction < 1:
im_files = im_files[:round(len(im_files) * self.fraction)]
return im_files
def update_labels(self, include_class: Optional[list]):