diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 58e8071d..9b280e38 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -168,9 +168,11 @@ class BaseTrainer: def train(self): """Allow device='', device=None on Multi-GPU systems to default to device=0.""" - if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3] - world_size = torch.cuda.device_count() - elif torch.cuda.is_available(): # i.e. device=None or device='' + if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3' + world_size = len(self.args.device.split(',')) + elif isinstance(self.args.device, tuple): # multi devices from cli is tuple type + world_size = len(self.args.device) + elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number world_size = 1 # default to device 0 else: # i.e. device='cpu' or 'mps' world_size = 0