Buffered Mosaic for reduced HDD reads (#2791)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
dada5b73c4
commit
07b57c03c8
3 changed files with 48 additions and 19 deletions
|
|
@ -80,7 +80,7 @@ class BaseDataset(Dataset):
|
|||
# Cache stuff
|
||||
if cache == 'ram' and not self.check_cache_ram():
|
||||
cache = False
|
||||
self.ims = [None] * self.ni
|
||||
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
||||
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
||||
if cache:
|
||||
self.cache_images(cache)
|
||||
|
|
@ -88,6 +88,10 @@ class BaseDataset(Dataset):
|
|||
# Transforms
|
||||
self.transforms = self.build_transforms(hyp=hyp)
|
||||
|
||||
# Buffer thread for mosaic images
|
||||
self.buffer = [] # buffer size = batch size
|
||||
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
||||
|
||||
def get_img_files(self, img_path):
|
||||
"""Read image files."""
|
||||
try:
|
||||
|
|
@ -147,13 +151,22 @@ class BaseDataset(Dataset):
|
|||
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
||||
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
|
||||
interpolation=interp)
|
||||
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
||||
|
||||
# Add to buffer if training with augmentations
|
||||
if self.augment:
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||
self.buffer.append(i)
|
||||
if len(self.buffer) >= self.max_buffer_length:
|
||||
j = self.buffer.pop(0)
|
||||
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
|
||||
|
||||
return im, (h0, w0), im.shape[:2]
|
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def cache_images(self, cache):
|
||||
"""Cache images to memory or disk."""
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(self.ni))
|
||||
|
|
@ -218,9 +231,9 @@ class BaseDataset(Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
"""Returns transformed label information for given index."""
|
||||
return self.transforms(self.get_label_info(index))
|
||||
return self.transforms(self.get_image_and_label(index))
|
||||
|
||||
def get_label_info(self, index):
|
||||
def get_image_and_label(self, index):
|
||||
"""Get and return label information from the dataset."""
|
||||
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
||||
label.pop('shape', None) # shape is for rect, remove it
|
||||
|
|
@ -229,8 +242,7 @@ class BaseDataset(Dataset):
|
|||
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
||||
if self.rect:
|
||||
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
||||
label = self.update_labels_info(label)
|
||||
return label
|
||||
return self.update_labels_info(label)
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the labels list for the dataset."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue