Tests improvements (#4616)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-29 15:38:43 +02:00 committed by GitHub
parent 4bd62a299c
commit 896da0c0a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 19 additions and 15 deletions

View file

@ -19,7 +19,8 @@ class VarifocalLoss(nn.Module):
"""Initialize the VarifocalLoss class."""
super().__init__()
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
@staticmethod
def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
@ -28,14 +29,14 @@ class VarifocalLoss(nn.Module):
return loss
# Losses
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
super().__init__()
def forward(self, pred, label, gamma=1.5, alpha=0.25):
@staticmethod
def forward(pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
# p_t = torch.exp(-loss)
@ -89,6 +90,7 @@ class BboxLoss(nn.Module):
class KeypointLoss(nn.Module):
"""Criterion class for computing training losses."""
def __init__(self, sigmas) -> None:
super().__init__()
@ -103,8 +105,8 @@ class KeypointLoss(nn.Module):
return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
# Criterion class for computing Detection training losses
class v8DetectionLoss:
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
@ -199,8 +201,8 @@ class v8DetectionLoss:
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
# Criterion class for computing training losses
class v8SegmentationLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
super().__init__(model)
@ -294,8 +296,8 @@ class v8SegmentationLoss(v8DetectionLoss):
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
# Criterion class for computing training losses
class v8PoseLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
super().__init__(model)
@ -374,7 +376,8 @@ class v8PoseLoss(v8DetectionLoss):
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def kpts_decode(self, anchor_points, pred_kpts):
@staticmethod
def kpts_decode(anchor_points, pred_kpts):
"""Decodes predicted keypoints to image coordinates."""
y = pred_kpts.clone()
y[..., :2] *= 2.0
@ -384,6 +387,7 @@ class v8PoseLoss(v8DetectionLoss):
class v8ClassificationLoss:
"""Criterion class for computing training losses."""
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""