Store Classify Val results in INT32 on CPU (#10125)

This commit is contained in:
bobyard-com 2024-04-17 19:39:44 -07:00 committed by GitHub
parent 1f4bed233a
commit f9461d50b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator):
def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch["cls"])
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
self.targets.append(batch["cls"].type(torch.int32).cpu())
def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""