Fix torch.amp.autocast('cuda') warnings (#14633)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
23ce08791f
commit
0d7bf447eb
7 changed files with 51 additions and 7 deletions
|
|
@ -641,6 +641,8 @@ def check_amp(model):
|
|||
Returns:
|
||||
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
||||
"""
|
||||
from ultralytics.utils.torch_utils import autocast
|
||||
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type in {"cpu", "mps"}:
|
||||
return False # AMP only used on CUDA devices
|
||||
|
|
@ -648,7 +650,7 @@ def check_amp(model):
|
|||
def amp_allclose(m, im):
    """
    Check that FP32 and AMP inference produce close results.

    Args:
        m: Model instance used for inference; deleted after use to release memory.
        im: Input image (path or array) passed to the model.

    Returns:
        (bool): True if AMP detection boxes match FP32 boxes in shape and values
            within 0.5 absolute tolerance, else False.

    Note:
        `device` is read from the enclosing check_amp() scope; `autocast` is the
        ultralytics.utils.torch_utils wrapper that replaces the deprecated
        torch.cuda.amp.autocast(True) call and silences its FutureWarning.
    """
    a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
    with autocast(enabled=True):
        b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
    del m  # free model memory before returning
    return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue