ultralytics 8.0.81 single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5bce1c3021
commit
a38f227672
64 changed files with 620 additions and 58 deletions
|
|
@ -12,9 +12,11 @@ class VarifocalLoss(nn.Module):
|
|||
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the VarifocalLoss class."""
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
||||
"""Computes varfocal loss."""
|
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
|
||||
|
|
@ -25,6 +27,7 @@ class VarifocalLoss(nn.Module):
|
|||
class BboxLoss(nn.Module):
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
self.use_dfl = use_dfl
|
||||
|
|
@ -64,6 +67,7 @@ class KeypointLoss(nn.Module):
|
|||
self.sigmas = sigmas
|
||||
|
||||
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
|
||||
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
|
||||
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
|
||||
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue