Store Classify Val results in INT32 on CPU (#10125)
This commit is contained in:
parent
1f4bed233a
commit
f9461d50b0
1 changed files with 2 additions and 2 deletions
|
|
@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator):
|
||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
"""Updates running metrics with model predictions and batch targets."""
|
"""Updates running metrics with model predictions and batch targets."""
|
||||||
n5 = min(len(self.names), 5)
|
n5 = min(len(self.names), 5)
|
||||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
|
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
|
||||||
self.targets.append(batch["cls"])
|
self.targets.append(batch["cls"].type(torch.int32).cpu())
|
||||||
|
|
||||||
def finalize_metrics(self, *args, **kwargs):
|
def finalize_metrics(self, *args, **kwargs):
|
||||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue