Fix dataloader (#32)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
92c60758dd
commit
2e9b18ce4e
8 changed files with 404 additions and 41 deletions
|
|
@ -132,7 +132,12 @@ class YOLODataset(BaseDataset):
|
|||
transforms = affine_transforms(self.img_size, hyp)
|
||||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
|
||||
transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
|
||||
transforms.append(
|
||||
Format(bbox_format="xywh",
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
batch_idx=True))
|
||||
return transforms
|
||||
|
||||
def update_labels_info(self, label):
|
||||
|
|
@ -140,7 +145,7 @@ class YOLODataset(BaseDataset):
|
|||
# NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
|
||||
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||
bboxes = label.pop("bboxes")
|
||||
segments = label.pop("segments", None)
|
||||
segments = label.pop("segments")
|
||||
keypoints = label.pop("keypoints", None)
|
||||
bbox_format = label.pop("bbox_format")
|
||||
normalized = label.pop("normalized")
|
||||
|
|
@ -158,9 +163,9 @@ class YOLODataset(BaseDataset):
|
|||
value = values[i]
|
||||
if k == "img":
|
||||
value = torch.stack(value, 0)
|
||||
if k in ["mask", "keypoint", "bboxes", "cls"]:
|
||||
if k in ["masks", "keypoints", "bboxes", "cls"]:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = values[i]
|
||||
new_batch[k] = value
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
for i in range(len(new_batch["batch_idx"])):
|
||||
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue