Fix torch.amp.autocast('cuda') warnings (#14633)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
parent 23ce08791f
commit 0d7bf447eb
7 changed files with 51 additions and 7 deletions
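For context, this is the deprecation being addressed: on recent PyTorch releases (2.4 and later) the CUDA-specific autocast entry point emits a FutureWarning steering users toward the device-generic API. A minimal repro sketch (warning text paraphrased; exact wording may vary by PyTorch version):

```python
import torch

# Deprecated spelling: on torch>=2.4 this emits a FutureWarning along the lines of
# "torch.cuda.amp.autocast(args...) is deprecated.
#  Please use torch.amp.autocast('cuda', args...) instead."
with torch.cuda.amp.autocast(enabled=True):
    pass

# Device-generic spelling preferred by newer PyTorch:
with torch.amp.autocast("cuda", enabled=True):
    pass
```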
ultralytics/utils/autobatch.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 
 from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
-from ultralytics.utils.torch_utils import profile
+from ultralytics.utils.torch_utils import autocast, profile
 
 
 def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
@@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
         (int): Optimal batch size computed using the autobatch() function.
     """
 
-    with torch.cuda.amp.autocast(amp):
+    with autocast(enabled=amp):
         return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
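A usage sketch for the updated function (the `ultralytics.utils.autobatch` import path and the yolov8n.pt setup are assumptions for illustration; the signature comes from the diff above):

```python
from ultralytics import YOLO
from ultralytics.utils.autobatch import check_train_batch_size  # assumed module path

# Hypothetical setup: the batch-size search only makes sense on a CUDA device.
model = YOLO("yolov8n.pt").model.to("cuda")
batch = check_train_batch_size(model, imgsz=640, amp=True)  # profiling now runs under autocast(enabled=amp)
print(f"Suggested batch size: {batch}")
```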
ultralytics/utils/checks.py
@@ -641,6 +641,8 @@ def check_amp(model):
     Returns:
         (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
     """
+    from ultralytics.utils.torch_utils import autocast
+
     device = next(model.parameters()).device  # get model device
     if device.type in {"cpu", "mps"}:
         return False  # AMP only used on CUDA devices
@@ -648,7 +650,7 @@ def check_amp(model):
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
         a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
-        with torch.cuda.amp.autocast(True):
+        with autocast(enabled=True):
             b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
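A sketch of how the checker might be invoked (the `ultralytics.utils.checks` import path and weights file are assumptions; `check_amp` behavior is as shown in the diff):

```python
from ultralytics import YOLO
from ultralytics.utils.checks import check_amp  # assumed module path

model = YOLO("yolov8n.pt").model.to("cuda")  # hypothetical CUDA model; CPU/MPS returns False immediately
use_amp = check_amp(model)  # True when FP32 and AMP boxes match within atol=0.5
```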
ultralytics/utils/loss.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ultralytics.utils.metrics import OKS_SIGMA
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import autocast
 
 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
@@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module):
     def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
         """Computes varfocal loss."""
         weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
-        with torch.cuda.amp.autocast(enabled=False):
+        with autocast(enabled=False):
             loss = (
                 (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                 .mean(1)
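The `enabled=False` here matters: the surrounding forward pass may run under AMP, but disabling autocast and casting the inputs with `.float()` keeps the BCE loss in full FP32, avoiding reduced-precision loss values. A self-contained sketch of the pattern (tensor shapes are hypothetical stand-ins for class scores):

```python
import torch
import torch.nn.functional as F
from ultralytics.utils.torch_utils import autocast

pred_score = torch.randn(8, 80, device="cuda")  # hypothetical logits
gt_score = torch.rand(8, 80, device="cuda")     # hypothetical soft targets

with autocast(enabled=True):       # training step runs under AMP...
    with autocast(enabled=False):  # ...but the loss region is forced back to FP32
        loss = F.binary_cross_entropy_with_logits(
            pred_score.float(), gt_score.float(), reduction="none"
        ).mean()
print(loss.dtype)  # torch.float32
```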
ultralytics/utils/torch_utils.py
@@ -68,6 +68,37 @@ def smart_inference_mode():
     return decorate
 
 
+def autocast(enabled: bool, device: str = "cuda"):
+    """
+    Get the appropriate autocast context manager based on PyTorch version and AMP setting.
+
+    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
+    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
+
+    Args:
+        enabled (bool): Whether to enable automatic mixed precision.
+        device (str, optional): The device to use for autocast. Defaults to 'cuda'.
+
+    Returns:
+        (torch.amp.autocast): The appropriate autocast context manager.
+
+    Note:
+        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
+        - For older versions, it uses `torch.cuda.amp.autocast`.
+
+    Example:
+        ```python
+        with autocast(enabled=True):
+            # Your mixed precision operations here
+            pass
+        ```
+    """
+    if TORCH_1_13:
+        return torch.amp.autocast(device, enabled=enabled)
+    else:
+        return torch.cuda.amp.autocast(enabled)
+
+
 def get_cpu_info():
     """Return a string with system CPU information, i.e. 'Apple M2'."""
     import cpuinfo  # pip install py-cpuinfo
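A quick sketch of the new helper in action (assumes a CUDA machine; the dtype checks reflect standard autocast casting rules for convolutions):

```python
import torch
from ultralytics.utils.torch_utils import autocast

conv = torch.nn.Conv2d(3, 16, 3).cuda()
x = torch.randn(2, 3, 32, 32, device="cuda")

with autocast(enabled=True):  # dispatches to torch.amp.autocast("cuda", ...) on torch>=1.13
    y = conv(x)
assert y.dtype == torch.float16  # conv outputs are cast to half under AMP

with autocast(enabled=False):
    y = conv(x)
assert y.dtype == torch.float32  # full precision when disabled
```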