From d0dd2b9313b0d38b1ded9677eaab7af685ba410b Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Fri, 6 Sep 2024 21:45:59 +0800 Subject: [PATCH] Pass `args` when creating validator for classification (#16025) Co-authored-by: UltralyticsAssistant Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/models/yolo/classify/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index a1e465d7..e51349fa 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -1,5 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +from copy import copy + import torch from ultralytics.data import ClassificationDataset, build_dataloader @@ -107,7 +109,9 @@ class ClassificationTrainer(BaseTrainer): def get_validator(self): """Returns an instance of ClassificationValidator for validation.""" self.loss_names = ["loss"] - return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) + return yolo.classify.ClassificationValidator( + self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) def label_loss_items(self, loss_items=None, prefix="train"): """