diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index cd084f3e..cf715e81 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -2221,7 +2221,7 @@ class RandomLoadText: pos_labels = np.unique(cls).tolist() if len(pos_labels) > self.max_samples: - pos_labels = set(random.sample(pos_labels, k=self.max_samples)) + pos_labels = random.sample(pos_labels, k=self.max_samples) neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples)) neg_labels = [i for i in range(num_classes) if i not in pos_labels]