Tests improvements (#4616)
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
parent
4bd62a299c
commit
896da0c0a0
6 changed files with 19 additions and 15 deletions
|
|
@ -19,7 +19,8 @@ class VarifocalLoss(nn.Module):
|
|||
"""Initialize the VarifocalLoss class."""
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
||||
@staticmethod
|
||||
def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
||||
"""Computes varfocal loss."""
|
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
|
|
@ -28,14 +29,14 @@ class VarifocalLoss(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
# Losses
|
||||
class FocalLoss(nn.Module):
|
||||
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
|
||||
|
||||
def __init__(self, ):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred, label, gamma=1.5, alpha=0.25):
|
||||
@staticmethod
|
||||
def forward(pred, label, gamma=1.5, alpha=0.25):
|
||||
"""Calculates and updates confusion matrix for object detection/classification tasks."""
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
|
||||
# p_t = torch.exp(-loss)
|
||||
|
|
@ -89,6 +90,7 @@ class BboxLoss(nn.Module):
|
|||
|
||||
|
||||
class KeypointLoss(nn.Module):
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, sigmas) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -103,8 +105,8 @@ class KeypointLoss(nn.Module):
|
|||
return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
|
||||
|
||||
|
||||
# Criterion class for computing Detection training losses
|
||||
class v8DetectionLoss:
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
|
||||
|
|
@ -199,8 +201,8 @@ class v8DetectionLoss:
|
|||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
|
||||
# Criterion class for computing training losses
|
||||
class v8SegmentationLoss(v8DetectionLoss):
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
super().__init__(model)
|
||||
|
|
@ -294,8 +296,8 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||
|
||||
|
||||
# Criterion class for computing training losses
|
||||
class v8PoseLoss(v8DetectionLoss):
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
super().__init__(model)
|
||||
|
|
@ -374,7 +376,8 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
def kpts_decode(self, anchor_points, pred_kpts):
|
||||
@staticmethod
|
||||
def kpts_decode(anchor_points, pred_kpts):
|
||||
"""Decodes predicted keypoints to image coordinates."""
|
||||
y = pred_kpts.clone()
|
||||
y[..., :2] *= 2.0
|
||||
|
|
@ -384,6 +387,7 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
|
||||
|
||||
class v8ClassificationLoss:
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue