Trainer + Dataloaders (#27)
Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ayush Chaurasia <ayushchaurasia@Ayushs-MacBook-Pro.local> Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
This commit is contained in:
parent
7a2e5fdfa3
commit
d0b3c9812b
27 changed files with 2885 additions and 9 deletions
37
ultralytics/yolo/data/dataset_wrappers.py
Normal file
37
ultralytics/yolo/data/dataset_wrappers.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import collections
|
||||
from copy import deepcopy
|
||||
|
||||
from .augment import LetterBox
|
||||
|
||||
|
||||
class MixAndRectDataset:
|
||||
"""A wrapper of multiple images mixed dataset.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`BaseDataset`): The dataset to be mixed.
|
||||
transforms (Sequence[dict]): config dict to be composed.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
self.img_size = dataset.img_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
labels = deepcopy(self.dataset[index])
|
||||
for transform in self.dataset.transforms.tolist():
|
||||
# mosaic and mixup
|
||||
if hasattr(transform, "get_indexes"):
|
||||
indexes = transform.get_indexes(self.dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = [indexes]
|
||||
mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
|
||||
labels["mix_labels"] = mix_labels
|
||||
if self.dataset.rect and isinstance(transform, LetterBox):
|
||||
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
|
||||
labels = transform(labels)
|
||||
if "mix_labels" in labels:
|
||||
labels.pop("mix_labels")
|
||||
return labels
|
||||
Loading…
Add table
Add a link
Reference in a new issue