Pass callbacks to validator (#7320)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Thomas de Lange 2024-01-05 09:08:17 +00:00 committed by GitHub
parent 072291bc78
commit 2f9ec8c0b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 4 deletions

View file

@ -49,7 +49,10 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
return yolo.pose.PoseValidator(self.test_loader,
save_dir=self.save_dir,
args=copy(self.args),
_callbacks=self.callbacks)
def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""