Buffered Mosaic for reduced HDD reads (#2791)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher authored 2023-05-25 00:42:13 +02:00 (committed by GitHub)
parent dada5b73c4
commit 07b57c03c8
3 changed files with 48 additions and 19 deletions
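The diff below covers only the dataset-side bookkeeping: BaseDataset now keeps a small FIFO buffer of indices whose images stay cached in RAM, so mosaic-style augmentation can reuse recently loaded tiles instead of re-reading each one from disk. As a rough illustration of the intent only (a minimal sketch, not code from this commit; the sampler class and method names are hypothetical, and it assumes nothing beyond a dataset.buffer list of cached indices):

import random

class BufferedMosaicSampler:
    """Hypothetical helper: pick mosaic tiles from the dataset's RAM buffer."""

    def __init__(self, dataset, tiles=4):
        self.dataset = dataset  # assumed to expose dataset.buffer (indices cached in dataset.ims)
        self.tiles = tiles

    def pick_indices(self, anchor_index):
        # The anchor image is loaded normally; the remaining tiles are drawn
        # from indices whose arrays are already held in memory, avoiding HDD reads.
        extra = random.choices(self.dataset.buffer, k=self.tiles - 1)
        return [anchor_index] + extra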


@@ -80,7 +80,7 @@ class BaseDataset(Dataset):
         # Cache stuff
         if cache == 'ram' and not self.check_cache_ram():
             cache = False
-        self.ims = [None] * self.ni
+        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
         self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
         if cache:
             self.cache_images(cache)
@@ -88,6 +88,10 @@ class BaseDataset(Dataset):
         # Transforms
         self.transforms = self.build_transforms(hyp=hyp)
 
+        # Buffer thread for mosaic images
+        self.buffer = []  # buffer size = batch size
+        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
+
     def get_img_files(self, img_path):
         """Read image files."""
         try:
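For a sense of scale, the buffer cap is the smallest of the dataset size, eight batches, and 1000 images, and it drops to zero when augmentation is off. The numbers below are illustrative only, not taken from the commit:

# Illustrative values: a 20,000-image dataset trained with batch_size=16.
ni, batch_size, augment = 20000, 16, True
max_buffer_length = min((ni, batch_size * 8, 1000)) if augment else 0
print(max_buffer_length)  # 128 -> at most 128 recently used images stay cached in RAM

# With augment=False (e.g. validation) the buffer is disabled entirely:
print(min((ni, batch_size * 8, 1000)) if False else 0)  # 0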
@@ -147,13 +151,22 @@
                 interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                 im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
                                 interpolation=interp)
-            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
-        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized
+
+            # Add to buffer if training with augmentations
+            if self.augment:
+                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
+                self.buffer.append(i)
+                if len(self.buffer) >= self.max_buffer_length:
+                    j = self.buffer.pop(0)
+                    self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
+
+            return im, (h0, w0), im.shape[:2]
+
+        return self.ims[i], self.im_hw0[i], self.im_hw[i]
 
     def cache_images(self, cache):
         """Cache images to memory or disk."""
         b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
-        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
         fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
         with ThreadPool(NUM_THREADS) as pool:
             results = pool.imap(fcn, range(self.ni))
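Stripped of the dataset plumbing, the logic added to load_image is a bounded FIFO cache: every freshly decoded image is stored, and once the buffer reaches its cap the oldest entry is released so memory stays roughly constant. A self-contained sketch of that pattern (hypothetical class and method names, not code from this commit):

import numpy as np

class FifoImageCache:
    """Bounded FIFO cache: keep the most recently loaded images, evict the oldest."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.order = []  # indices in insertion order (oldest first)
        self.store = {}  # index -> image array

    def add(self, index, image):
        self.store[index] = image
        self.order.append(index)
        if len(self.order) >= self.capacity:  # same >= threshold as the diff above
            oldest = self.order.pop(0)
            self.store.pop(oldest, None)  # free the oldest image for the GC

cache = FifoImageCache(capacity=4)
for i in range(6):
    cache.add(i, np.zeros((640, 640, 3), dtype=np.uint8))
print(sorted(cache.store))  # [3, 4, 5] -> older images 0..2 have been evicted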
@@ -218,9 +231,9 @@ class BaseDataset(Dataset):
     def __getitem__(self, index):
         """Returns transformed label information for given index."""
-        return self.transforms(self.get_label_info(index))
+        return self.transforms(self.get_image_and_label(index))
 
-    def get_label_info(self, index):
+    def get_image_and_label(self, index):
         """Get and return label information from the dataset."""
         label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
         label.pop('shape', None)  # shape is for rect, remove it
@@ -229,8 +242,7 @@
                               label['resized_shape'][1] / label['ori_shape'][1])  # for evaluation
         if self.rect:
             label['rect_shape'] = self.batch_shapes[self.batch[index]]
-        label = self.update_labels_info(label)
-        return label
+        return self.update_labels_info(label)
 
     def __len__(self):
         """Returns the length of the labels list for the dataset."""