Fix torch.cuda.amp.GradScaler warning (#15978)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
5a6db149e4
commit
7c7f456710
2 changed files with 5 additions and 1 deletions
|
|
@ -42,6 +42,7 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
|
||||||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
from ultralytics.utils.files import get_latest_run
|
from ultralytics.utils.files import get_latest_run
|
||||||
from ultralytics.utils.torch_utils import (
|
from ultralytics.utils.torch_utils import (
|
||||||
|
TORCH_2_4,
|
||||||
EarlyStopping,
|
EarlyStopping,
|
||||||
ModelEMA,
|
ModelEMA,
|
||||||
autocast,
|
autocast,
|
||||||
|
|
@ -265,7 +266,9 @@ class BaseTrainer:
|
||||||
if RANK > -1 and world_size > 1: # DDP
|
if RANK > -1 and world_size > 1: # DDP
|
||||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||||
self.amp = bool(self.amp) # as boolean
|
self.amp = bool(self.amp) # as boolean
|
||||||
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
|
self.scaler = (
|
||||||
|
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||||
|
)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ except ImportError:
|
||||||
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
||||||
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
|
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
|
||||||
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
||||||
|
TORCH_2_4 = check_version(torch.__version__, "2.4.0")
|
||||||
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
|
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
|
||||||
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
|
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
|
||||||
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
|
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue