Simplify augmentations (#93)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
249dfbdc05
commit
ae05d44877
5 changed files with 36 additions and 68 deletions
|
|
@ -42,12 +42,13 @@ class BaseDataset(Dataset):
|
|||
self.imgsz = imgsz
|
||||
self.label_path = label_path
|
||||
self.augment = augment
|
||||
self.single_cls = single_cls
|
||||
self.prefix = prefix
|
||||
|
||||
self.im_files = self.get_img_files(self.img_path)
|
||||
self.labels = self.get_labels()
|
||||
if single_cls:
|
||||
self.update_labels(include_class=[], single_cls=single_cls)
|
||||
if self.single_cls:
|
||||
self.update_labels(include_class=[])
|
||||
|
||||
self.ni = len(self.im_files)
|
||||
|
||||
|
|
@ -173,10 +174,7 @@ class BaseDataset(Dataset):
|
|||
self.batch = bi # batch index of image
|
||||
|
||||
def __getitem__(self, index):
|
||||
label = self.get_label_info(index)
|
||||
if self.augment:
|
||||
label["dataset"] = self
|
||||
return self.transforms(label)
|
||||
return self.transforms(self.get_label_info(index))
|
||||
|
||||
def get_label_info(self, index):
|
||||
label = self.labels[index].copy()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue