Fix torch.amp has no attribute GradScaler (#14647)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
b7c90526c8
commit
1c351b5036
1 changed files with 1 additions and 6 deletions
|
|
@ -41,7 +41,6 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
|
|||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.utils.files import get_latest_run
|
||||
from ultralytics.utils.torch_utils import (
|
||||
TORCH_1_13,
|
||||
EarlyStopping,
|
||||
ModelEMA,
|
||||
autocast,
|
||||
|
|
@ -266,11 +265,7 @@ class BaseTrainer:
|
|||
if RANK > -1 and world_size > 1: # DDP
|
||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = (
|
||||
torch.amp.GradScaler("cuda", enabled=self.amp)
|
||||
if TORCH_1_13
|
||||
else torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||
)
|
||||
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue