diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py index de3cff2b..98d5f788 100644 --- a/ultralytics/models/yolo/classify/val.py +++ b/ultralytics/models/yolo/classify/val.py @@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator): def update_metrics(self, preds, batch): """Updates running metrics with model predictions and batch targets.""" n5 = min(len(self.names), 5) - self.pred.append(preds.argsort(1, descending=True)[:, :n5]) - self.targets.append(batch["cls"]) + self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()) + self.targets.append(batch["cls"].type(torch.int32).cpu()) def finalize_metrics(self, *args, **kwargs): """Finalizes metrics of the model such as confusion_matrix and speed."""