Fix dataloader (#32)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Laughing 2022-11-04 03:00:40 -05:00 committed by GitHub
parent 92c60758dd
commit 2e9b18ce4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 404 additions and 41 deletions

View file

@ -132,7 +132,12 @@ class YOLODataset(BaseDataset):
transforms = affine_transforms(self.img_size, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
transforms.append(
Format(bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True))
return transforms
def update_labels_info(self, label):
@ -140,7 +145,7 @@ class YOLODataset(BaseDataset):
# NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
bboxes = label.pop("bboxes")
segments = label.pop("segments", None)
segments = label.pop("segments")
keypoints = label.pop("keypoints", None)
bbox_format = label.pop("bbox_format")
normalized = label.pop("normalized")
@ -158,9 +163,9 @@ class YOLODataset(BaseDataset):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in ["mask", "keypoint", "bboxes", "cls"]:
if k in ["masks", "keypoints", "bboxes", "cls"]:
value = torch.cat(value, 0)
new_batch[k] = values[i]
new_batch[k] = value
new_batch["batch_idx"] = list(new_batch["batch_idx"])
for i in range(len(new_batch["batch_idx"])):
new_batch["batch_idx"][i] += i # add target image index for build_targets()