From 03e0b1033cdd097651188321a16a33b315feb908 Mon Sep 17 00:00:00 2001
From: alanZee <92136487+alanZee@users.noreply.github.com>
Date: Tue, 13 Aug 2024 04:53:48 +0800
Subject: [PATCH] Improve trainer DDP device handling (#15383)

Co-authored-by: UltralyticsAssistant
Co-authored-by: Glenn Jocher
---
 ultralytics/engine/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 47063466..6ebe7536 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -174,9 +174,11 @@ class BaseTrainer:
             world_size = len(self.args.device.split(","))
         elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
             world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
             world_size = 1  # default to device 0
-        else:  # i.e. device='cpu' or 'mps'
+        else:  # i.e. device=None or device=''
             world_size = 0

         # Run subprocess if DDP training, else train normally