ultralytics 8.3.31 add max_num_obj factor for AutoBatch (#17514)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-11-14 06:51:24 +08:00 committed by GitHub
parent e100484422
commit 4453ddab93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 38 additions and 12 deletions

View file

@ -141,3 +141,10 @@ class DetectionTrainer(BaseTrainer):
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
def auto_batch(self):
"""Get batch size by calculating memory occupation of model."""
train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
# 4 for mosaic augmentation
max_num_obj = max(len(l["cls"]) for l in train_dataset.labels) * 4
return super().auto_batch(max_num_obj)