Store Classify Val results in INT32 on CPU (#10125)
This commit is contained in:
parent
1f4bed233a
commit
f9461d50b0
1 changed files with 2 additions and 2 deletions
|
|
@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator):
|
|||
def update_metrics(self, preds, batch):
|
||||
"""Updates running metrics with model predictions and batch targets."""
|
||||
n5 = min(len(self.names), 5)
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
|
||||
self.targets.append(batch["cls"])
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
|
||||
self.targets.append(batch["cls"].type(torch.int32).cpu())
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue