Pass args when creating validator for classification (#16025)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
93b80552fc
commit
d0dd2b9313
1 changed files with 5 additions and 1 deletions
|
|
@ -1,5 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
|
|
@ -107,7 +109,9 @@ class ClassificationTrainer(BaseTrainer):
|
|||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ["loss"]
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
|
||||
return yolo.classify.ClassificationValidator(
|
||||
self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue