Improve trainer DDP device handling (#15383)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
592f421663
commit
03e0b1033c
1 changed file with 3 additions and 1 deletion
|
|
@ -174,9 +174,11 @@ class BaseTrainer:
|
|||
world_size = len(self.args.device.split(","))
|
||||
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
||||
world_size = len(self.args.device)
|
||||
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
|
||||
world_size = 0
|
||||
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
||||
world_size = 1 # default to device 0
|
||||
else: # i.e. device='cpu' or 'mps'
|
||||
else: # i.e. device=None or device=''
|
||||
world_size = 0
|
||||
|
||||
# Run subprocess if DDP training, else train normally
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue