ultralytics 8.0.216 fix hard-coded batch=64 cls loss (#6523)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: HDW AI group <huzhongshan@gmail.com>
This commit is contained in:
Glenn Jocher 2023-11-22 21:16:55 +01:00 committed by GitHub
parent 16a13a1ce0
commit 10f6ac5e9b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 115 additions and 6 deletions

View file

@ -523,6 +523,6 @@ class v8ClassificationLoss:
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
loss_items = loss.detach()
return loss, loss_items