Pass args when creating validator for classification (#16025)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
93b80552fc
commit
d0dd2b9313
1 changed files with 5 additions and 1 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||||
|
|
@ -107,7 +109,9 @@ class ClassificationTrainer(BaseTrainer):
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
"""Returns an instance of ClassificationValidator for validation."""
|
"""Returns an instance of ClassificationValidator for validation."""
|
||||||
self.loss_names = ["loss"]
|
self.loss_names = ["loss"]
|
||||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
|
return yolo.classify.ClassificationValidator(
|
||||||
|
self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue