New dataset fraction=1.0 argument (#2860)

This commit is contained in:
Glenn Jocher 2023-05-28 02:13:46 +02:00 committed by GitHub
parent 61fa5efe6d
commit 0bdd4ad379
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 16 additions and 5 deletions

View file

@ -226,6 +226,8 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
cache (Union[bool, str], optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
"""
super().__init__(root=root)
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[:round(len(self.samples) * args.fraction)]
self.cache_ram = cache is True or cache == 'ram'
self.cache_disk = cache == 'disk'
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
@ -269,4 +271,4 @@ class SemanticDataset(BaseDataset):
def __init__(self):
"""Initialize a SemanticDataset object."""
pass
super().__init__()