Fix Classification train logging (#157)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
d387359f74
commit
e79ea1666c
7 changed files with 86 additions and 40 deletions
|
|
@ -565,14 +565,8 @@ class SegmentMetrics:
|
|||
@property
|
||||
def keys(self):
|
||||
return [
|
||||
"metrics/precision(B)",
|
||||
"metrics/recall(B)",
|
||||
"metrics/mAP50(B)",
|
||||
"metrics/mAP50-95(B)", # metrics
|
||||
"metrics/precision(M)",
|
||||
"metrics/recall(M)",
|
||||
"metrics/mAP50(M)",
|
||||
"metrics/mAP50-95(M)"]
|
||||
"metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
|
||||
"metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
|
||||
|
||||
def mean_results(self):
|
||||
return self.metric_box.mean_results() + self.metric_mask.mean_results()
|
||||
|
|
@ -603,7 +597,10 @@ class ClassifyMetrics:
|
|||
self.top1 = 0
|
||||
self.top5 = 0
|
||||
|
||||
def process(self, correct):
|
||||
def process(self, targets, pred):
|
||||
# target classes and predicted classes
|
||||
pred, targets = torch.cat(pred), torch.cat(targets)
|
||||
correct = (targets[:, None] == pred).float()
|
||||
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
self.top1, self.top5 = acc.mean(0).tolist()
|
||||
|
||||
|
|
@ -617,4 +614,4 @@ class ClassifyMetrics:
|
|||
|
||||
@property
|
||||
def keys(self):
|
||||
return ["top1", "top5"]
|
||||
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue