Fix torch.amp.autocast('cuda') warnings (#14633)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Authored by Glenn Jocher on 2024-07-23 21:58:39 +02:00, committed via GitHub
parent 23ce08791f
commit 0d7bf447eb
7 changed files with 51 additions and 7 deletions
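
The hunks below (from the BaseTrainer training code) swap the deprecated torch.cuda.amp.autocast and torch.cuda.amp.GradScaler calls, which emit warnings on newer PyTorch releases, for a version-aware autocast helper imported from ultralytics.utils.torch_utils plus a TORCH_1_13-gated GradScaler. The helper itself is not part of this excerpt; the sketch below is a hedged guess at its shape, with a simplified version check standing in for the real TORCH_1_13 flag:

# Hedged sketch of the imported `autocast` helper; the actual implementation
# in ultralytics/utils/torch_utils.py may differ.
import torch

# Simplified stand-in for the TORCH_1_13 flag imported in the diff below.
TORCH_1_13 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 13)


def autocast(enabled: bool, device: str = "cuda"):
    """Return an AMP autocast context manager, preferring the non-deprecated API."""
    if TORCH_1_13:
        return torch.amp.autocast(device, enabled=enabled)  # unified torch.amp API
    return torch.cuda.amp.autocast(enabled)  # legacy API for older torch versions

Because the helper returns the context manager rather than entering it, call sites such as `with autocast(self.amp):` stay identical whichever underlying API is used.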


@@ -41,8 +41,10 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
+    TORCH_1_13,
     EarlyStopping,
     ModelEMA,
+    autocast,
     convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
@@ -264,7 +266,11 @@ class BaseTrainer:
         if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp)
+            if TORCH_1_13
+            else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
         if world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
@@ -376,7 +382,7 @@
                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
                 # Forward
-                with torch.cuda.amp.autocast(self.amp):
+                with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     self.loss, self.loss_items = self.model(batch)
                     if RANK != -1:
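
The same pattern carries over to standalone training code. Below is a hedged usage sketch, not part of the commit: the tiny model, optimizer, and random batch are placeholders, and a hasattr check stands in for the TORCH_1_13 gate so the snippet stays self-contained.

# Illustrative AMP training step using the version-gated scaler/autocast pattern.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
new_amp = hasattr(torch.amp, "GradScaler")  # rough stand-in for the TORCH_1_13 flag

# Version-gated scaler, mirroring the diff above.
scaler = (
    torch.amp.GradScaler("cuda", enabled=use_amp)
    if new_amp
    else torch.cuda.amp.GradScaler(enabled=use_amp)
)

model = torch.nn.Linear(8, 1).to(device)   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch = torch.randn(4, 8, device=device)   # placeholder batch

# Forward under autocast, backward/step through the scaler.
ctx = torch.amp.autocast("cuda", enabled=use_amp) if new_amp else torch.cuda.amp.autocast(use_amp)
with ctx:
    loss = model(batch).mean()
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscale gradients, then optimizer step
scaler.update()                # adjust the loss scale for the next iteration
optimizer.zero_grad()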