Fix HUB session with DDP training (#13103)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
parent 68720288d3
commit 169602442c

5 changed files with 11 additions and 5 deletions
@@ -48,6 +48,7 @@ from ultralytics.utils.torch_utils import (
     one_cycle,
     select_device,
     strip_optimizer,
+    torch_distributed_zero_first,
 )
@@ -127,7 +128,8 @@ class BaseTrainer:
         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        self.trainset, self.testset = self.get_dataset()
+        with torch_distributed_zero_first(RANK):  # avoid auto-downloading dataset multiple times
+            self.trainset, self.testset = self.get_dataset()
         self.ema = None

         # Optimization utils init
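Note: the hunk above wraps the dataset fetch in torch_distributed_zero_first (newly imported in the first hunk), so under DDP only the local master downloads or caches the dataset while the other ranks wait at a barrier and then read the cached copy. A rough usage sketch of that pattern; the LOCAL_RANK handling and the prepare_data() helper are illustrative, not the trainer's actual code:

import os

import torch.distributed as dist
from ultralytics.utils.torch_utils import torch_distributed_zero_first

LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # -1 means single-process (non-DDP) run

if LOCAL_RANK != -1:  # processes spawned by torchrun/DDP
    dist.init_process_group(backend="nccl")

def prepare_data():
    # illustrative placeholder for BaseTrainer.get_dataset(): download once, then read from cache
    return "coco8/"

with torch_distributed_zero_first(LOCAL_RANK):
    data_dir = prepare_data()  # rank 0 runs first; other ranks block until it finishes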
@@ -143,6 +145,9 @@ class BaseTrainer:
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]

+        # HUB
+        self.hub_session = None
+
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         if RANK in {-1, 0}:
@@ -72,7 +72,7 @@ class HUBTrainingSession:
         try:
             session = cls(identifier)
             assert session.client.authenticated, "HUB not authenticated"
-            if args:
+            if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
                 session.create_model(args)
                 assert session.model.id, "HUB model not loaded correctly"
             return session
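The tightened condition above means create_model() only runs when the identifier is not already a HUB model URL; passing an existing f"{HUB_WEB_ROOT}/models/..." link now reuses that model instead of registering a duplicate. A small illustration of the check in isolation; the HUB_WEB_ROOT value and model id are assumptions for the example:

HUB_WEB_ROOT = "https://hub.ultralytics.com"  # assumed value, for illustration only

for identifier in (f"{HUB_WEB_ROOT}/models/abc123", "yolov8n.pt"):
    if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
        print(identifier, "-> existing HUB model URL, skip create_model()")
    else:
        print(identifier, "-> local weights/config, create_model(args) registers a new HUB model")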
@@ -9,7 +9,7 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS


 def on_pretrain_routine_start(trainer):
     """Create a remote Ultralytics HUB session to log local model training."""
-    if RANK in {-1, 0} and SETTINGS["hub"] is True and not getattr(trainer, "hub_session", None):
+    if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None:
         trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)
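Because BaseTrainer now initializes self.hub_session = None (second hunk above), this callback can test trainer.hub_session directly instead of using getattr, and the extra SETTINGS["api_key"] check keeps unauthenticated runs from attempting a session. Only RANK -1 or 0 passes the guard, so DDP worker processes never open their own HUB sessions. A minimal sketch of the gating logic on its own; the settings values are made up:

def should_create_hub_session(rank, settings, hub_session):
    # mirrors the guard above: main process only, HUB enabled, API key present, no session yet
    return rank in {-1, 0} and settings.get("hub") is True and bool(settings.get("api_key")) and hub_session is None

settings = {"hub": True, "api_key": "example-key"}  # made-up values for illustration
print(should_create_hub_session(0, settings, None))  # True: rank 0 creates the session
print(should_create_hub_session(1, settings, None))  # False: DDP worker skips HUB entirely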
@@ -37,6 +37,7 @@ if __name__ == "__main__":
     cfg = DEFAULT_CFG_DICT.copy()
     cfg.update(save_dir='')  # handle the extra key 'save_dir'
     trainer = {name}(cfg=cfg, overrides=overrides)
+    trainer.args.model = "{getattr(trainer.hub_session, 'model_url', trainer.args.model)}"
     results = trainer.train()
 """
     (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
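The added template line writes the HUB model URL into the auto-generated DDP training script when a session exists (getattr falls back to the original model otherwise), so each spawned rank resumes against the same HUB model rather than re-registering one. A rough sketch of how that f-string line expands; the trainer stand-ins and the model URL are made up:

class _Session:  # stand-in for a live HUBTrainingSession
    model_url = "https://hub.ultralytics.com/models/abc123"  # made-up URL

class _Args:
    model = "yolov8n.pt"

class _Trainer:  # stand-in for a BaseTrainer with an active HUB session
    hub_session = _Session()
    args = _Args()

trainer = _Trainer()
print(f"trainer.args.model = \"{getattr(trainer.hub_session, 'model_url', trainer.args.model)}\"")
# -> trainer.args.model = "https://hub.ultralytics.com/models/abc123"
# with hub_session = None, getattr falls back to "yolov8n.pt"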
@@ -43,8 +43,8 @@ TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")

 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
-    """Decorator to make all processes in distributed training wait for each local_master to do something."""
-    initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
+    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
+    initialized = dist.is_available() and dist.is_initialized()
     if initialized and local_rank not in {-1, 0}:
         dist.barrier(device_ids=[local_rank])
     yield
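Beyond the docstring fix (it is a context manager, not a decorator) and the switch to the dist alias, the useful behavior is the barrier ordering: non-master ranks block before the yield, and the master releases them once its block finishes. The hunk is cut off at the yield, so here is a generic, self-contained sketch of the full zero-first pattern rather than the library's exact code:

from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def zero_first(local_rank: int):
    # generic sketch: workers wait before the block, master releases them afterwards
    initialized = dist.is_available() and dist.is_initialized()
    if initialized and local_rank not in {-1, 0}:
        dist.barrier(device_ids=[local_rank])  # workers park here while rank 0 does the work
    yield
    if initialized and local_rank == 0:
        dist.barrier(device_ids=[local_rank])  # rank 0 signals completion, releasing the workers

In a single-process run (or before init_process_group) both barriers are skipped, so wrapping code in the context manager is a no-op outside DDP.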