ultralytics 8.0.162 Multi-GPU DDP fix (#4544)
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: andresinsitu <andres.rodriguez@ingenieriainsitu.com>
This commit is contained in:
parent
1db9afc2e5
commit
2bcee56e70
9 changed files with 24 additions and 14 deletions
|
|
@@ -184,7 +184,7 @@ class BaseTrainer:
|
|||
# Command
|
||||
cmd, file = generate_ddp_command(world_size, self)
|
||||
try:
|
||||
- LOGGER.info(f'DDP command: {cmd}')
|
||||
+ LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
@@ -197,7 +197,7 @@ class BaseTrainer:
|
|||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device('cuda', RANK)
|
||||
- LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
+ # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
|
||||
dist.init_process_group(
|
||||
'nccl' if dist.is_nccl_available() else 'gloo',
|
||||
|
|
@@ -299,8 +299,7 @@ class BaseTrainer:
|
|||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
nb = len(self.train_loader) # number of batches
|
||||
- nw = max(round(self.args.warmup_epochs *
|
||||
- nb), 100) if self.args.warmup_epochs > 0 else -1 # number of warmup iterations
|
||||
+ nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
|
||||
last_opt_step = -1
|
||||
self.run_callbacks('on_train_start')
|
||||
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
|
|
@@ -557,7 +556,7 @@ class BaseTrainer:
|
|||
n = len(metrics) + 1 # number of cols
|
||||
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||
with open(self.csv, 'a') as f:
|
||||
- f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
|
||||
+ f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot and display metrics visually."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue