ultralytics 8.0.162 Multi-GPU DDP fix (#4544)

Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: andresinsitu <andres.rodriguez@ingenieriainsitu.com>
Glenn Jocher 2023-08-24 13:13:49 +02:00 committed by GitHub
parent 1db9afc2e5
commit 2bcee56e70
9 changed files with 24 additions and 14 deletions

@@ -184,7 +184,7 @@ class BaseTrainer:
             # Command
             cmd, file = generate_ddp_command(world_size, self)
             try:
-                LOGGER.info(f'DDP command: {cmd}')
+                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
                 subprocess.run(cmd, check=True)
             except Exception as e:
                 raise e
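
The updated log line prints a copy-pastable shell command (the list joined with spaces) instead of a Python list repr. A minimal sketch of that launch pattern, assuming a `torch.distributed.run` entry point and a hypothetical `trainer_file` script; this is not the exact command Ultralytics builds:

```python
import subprocess
import sys

def launch_ddp(trainer_file: str, world_size: int, port: int = 29500) -> None:
    """Build and run a multi-GPU DDP command, logging it as a single shell string."""
    cmd = [
        sys.executable, '-m', 'torch.distributed.run',
        '--nproc_per_node', str(world_size),
        '--master_port', str(port),
        trainer_file,
    ]
    print('DDP: debug command ' + ' '.join(cmd))  # joined string, as in the new LOGGER line
    subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit code
```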
@@ -197,7 +197,7 @@ class BaseTrainer:
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device('cuda', RANK)
-        LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
         dist.init_process_group(
             'nccl' if dist.is_nccl_available() else 'gloo',
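
For context, a minimal self-contained sketch of the process-group setup in this hunk, assuming `RANK` and `WORLD_SIZE` are provided as environment variables by the launcher (as `torch.distributed.run` does); the timeout value is illustrative, not necessarily what Ultralytics uses:

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist

def setup_ddp() -> None:
    """Initialize a DDP process group; RANK/WORLD_SIZE come from the launcher env."""
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    os.environ['NCCL_BLOCKING_WAIT'] = '1'  # make NCCL collectives respect the timeout below
    dist.init_process_group(
        backend='nccl' if dist.is_nccl_available() else 'gloo',
        timeout=timedelta(hours=3),  # illustrative timeout
        rank=rank,
        world_size=world_size,
    )
```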
@@ -299,8 +299,7 @@ class BaseTrainer:
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
-        nw = max(round(self.args.warmup_epochs *
-                       nb), 100) if self.args.warmup_epochs > 0 else -1  # number of warmup iterations
+        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
         last_opt_step = -1
         self.run_callbacks('on_train_start')
         LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
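
The reflowed line computes the same warmup-iteration count as before; only the line break and comment wording changed. A worked example of the formula `nw = max(round(warmup_epochs * nb), 100)`:

```python
def warmup_iterations(warmup_epochs: float, nb: int) -> int:
    """Number of warmup iterations: at least 100, or -1 when warmup is disabled."""
    return max(round(warmup_epochs * nb), 100) if warmup_epochs > 0 else -1

assert warmup_iterations(3.0, 200) == 600  # large dataset: 3 full epochs of warmup
assert warmup_iterations(3.0, 20) == 100   # small dataset: floor of 100 iterations
assert warmup_iterations(0.0, 200) == -1   # warmup disabled
```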
@@ -557,7 +556,7 @@ class BaseTrainer:
         n = len(metrics) + 1  # number of cols
         s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
         with open(self.csv, 'a') as f:
-            f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
+            f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')

     def plot_metrics(self):
         """Plot and display metrics visually."""