diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index a1e465d7..e51349fa 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -1,5 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +from copy import copy + import torch from ultralytics.data import ClassificationDataset, build_dataloader @@ -107,7 +109,9 @@ class ClassificationTrainer(BaseTrainer): def get_validator(self): """Returns an instance of ClassificationValidator for validation.""" self.loss_names = ["loss"] - return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) + return yolo.classify.ClassificationValidator( + self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) def label_loss_items(self, loss_items=None, prefix="train"): """