Fix torch.amp.autocast('cuda') warnings (#14633)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
parent 23ce08791f
commit 0d7bf447eb
7 changed files with 51 additions and 7 deletions
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 
 from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
-from ultralytics.utils.torch_utils import profile
+from ultralytics.utils.torch_utils import autocast, profile
 
 
 def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
@@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
         (int): Optimal batch size computed using the autobatch() function.
     """
 
-    with torch.cuda.amp.autocast(amp):
+    with autocast(enabled=amp):
         return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
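The replacement context manager comes from the new autocast helper now imported from ultralytics.utils.torch_utils. Below is a minimal sketch of what such a helper could look like, assuming it dispatches to the device-agnostic torch.amp.autocast API when the running PyTorch build provides it and falls back to the legacy torch.cuda.amp.autocast otherwise; the feature check and signature here are illustrative, not the exact Ultralytics implementation.

import torch


def autocast(enabled: bool, device: str = "cuda"):
    """Return an AMP autocast context manager without triggering deprecation warnings."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        # Newer PyTorch exposes a device-agnostic torch.amp.autocast(device_type, ...).
        return torch.amp.autocast(device, enabled=enabled)
    # Older releases only provide the CUDA-specific context manager.
    return torch.cuda.amp.autocast(enabled)

Call sites then read "with autocast(enabled=amp):", as in the hunk above, so the torch.cuda.amp.autocast deprecation warning is no longer emitted on recent PyTorch versions while older versions keep working.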