Fix torch.amp.autocast('cuda') warnings (#14633)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
parent 23ce08791f
commit 0d7bf447eb
7 changed files with 51 additions and 7 deletions
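Recent PyTorch releases deprecate the CUDA-specific torch.cuda.amp.autocast(...) context manager in favor of the unified torch.amp.autocast("cuda", ...) API and emit FutureWarning messages for the old spelling. The commit silences those warnings by routing call sites such as VarifocalLoss.forward (shown in the diff below) through an autocast helper imported from ultralytics.utils.torch_utils. The helper's body is not part of this excerpt; the following is only a minimal sketch of a version-aware wrapper of that kind, assuming the autocast(enabled, device="cuda") signature implied by the call sites and a packaging-based version check, not the verbatim Ultralytics implementation.

    import torch
    from packaging import version  # common helper for robust version comparisons

    # True on PyTorch releases that expose the unified torch.amp.autocast(device_type, ...) API
    TORCH_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0")


    def autocast(enabled: bool, device: str = "cuda"):
        """Return an AMP autocast context manager without triggering deprecation warnings."""
        if TORCH_1_13:
            # Unified API on newer PyTorch: device type is passed explicitly
            return torch.amp.autocast(device, enabled=enabled)
        # Fallback for older PyTorch where only the CUDA-specific manager exists
        return torch.cuda.amp.autocast(enabled)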
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ultralytics.utils.metrics import OKS_SIGMA
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import autocast
 
 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
@@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module):
     def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
         """Computes varfocal loss."""
         weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
-        with torch.cuda.amp.autocast(enabled=False):
+        with autocast(enabled=False):
             loss = (
                 (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                 .mean(1)
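In VarifocalLoss the context manager is entered with enabled=False so the binary cross-entropy term runs in full precision; only the spelling of the context manager changes. For reference, the two calls below behave the same on recent PyTorch, but only the first triggers the deprecation warning named in the commit title. This is an illustrative sketch; the exact warning text and the behaviour on machines without CUDA vary by PyTorch version.

    import torch

    # Deprecated spelling: recent PyTorch emits a FutureWarning pointing to torch.amp.autocast("cuda", ...)
    with torch.cuda.amp.autocast(enabled=False):
        a = torch.ones(2, 2) @ torch.ones(2, 2)

    # Recommended spelling: device type passed explicitly, no deprecation warning
    with torch.amp.autocast("cuda", enabled=False):
        b = torch.ones(2, 2) @ torch.ones(2, 2)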