From f9461d50b09e1ad124683d14c1bbd51553f9d7ae Mon Sep 17 00:00:00 2001 From: bobyard-com <154289614+bobyard-com@users.noreply.github.com> Date: Wed, 17 Apr 2024 19:39:44 -0700 Subject: [PATCH] Store Classify Val results in INT32 on CPU (#10125) --- ultralytics/models/yolo/classify/val.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py index de3cff2b..98d5f788 100644 --- a/ultralytics/models/yolo/classify/val.py +++ b/ultralytics/models/yolo/classify/val.py @@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator): def update_metrics(self, preds, batch): """Updates running metrics with model predictions and batch targets.""" n5 = min(len(self.names), 5) - self.pred.append(preds.argsort(1, descending=True)[:, :n5]) - self.targets.append(batch["cls"]) + self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()) + self.targets.append(batch["cls"].type(torch.int32).cpu()) def finalize_metrics(self, *args, **kwargs): """Finalizes metrics of the model such as confusion_matrix and speed."""