ultralytics 8.0.53 DDP AMP and Edge TPU fixes (#1362)
Co-authored-by: Richard Aljaste <richardaljasteabramson@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Vuong Kha Sieu <75152429+hotfur@users.noreply.github.com>
This commit is contained in:
parent
177a68b39f
commit
f921e1ac21
46 changed files with 1045 additions and 384 deletions
|
|
@ -95,9 +95,9 @@ class BaseTrainer:
|
|||
self.save_dir = Path(self.args.save_dir)
|
||||
else:
|
||||
self.save_dir = Path(
|
||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))
|
||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
if RANK in {-1, 0}:
|
||||
if RANK in (-1, 0):
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
|
||||
|
|
@ -144,7 +144,7 @@ class BaseTrainer:
|
|||
|
||||
# Callbacks
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
if RANK in {0, -1}:
|
||||
if RANK in (-1, 0):
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
|
|
@ -203,9 +203,14 @@ class BaseTrainer:
|
|||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
# Check AMP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as they are reset by check_amp()
|
||||
self.amp = check_amp(self.model)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
self.amp = torch.tensor(True).to(self.device)
|
||||
if RANK in (-1, 0): # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
if RANK > -1: # DDP
|
||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
|
|
@ -239,7 +244,7 @@ class BaseTrainer:
|
|||
# dataloaders
|
||||
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
|
||||
if rank in {0, -1}:
|
||||
if rank in (-1, 0):
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
|
||||
self.validator = self.get_validator()
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
||||
|
|
@ -286,7 +291,7 @@ class BaseTrainer:
|
|||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
if rank in {-1, 0}:
|
||||
if rank in (-1, 0):
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
|
||||
self.tloss = None
|
||||
|
|
@ -327,7 +332,7 @@ class BaseTrainer:
|
|||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if rank in {-1, 0}:
|
||||
if rank in (-1, 0):
|
||||
pbar.set_description(
|
||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
|
||||
|
|
@ -342,7 +347,7 @@ class BaseTrainer:
|
|||
self.scheduler.step()
|
||||
self.run_callbacks('on_train_epoch_end')
|
||||
|
||||
if rank in {-1, 0}:
|
||||
if rank in (-1, 0):
|
||||
|
||||
# Validation
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
|
|
@ -372,7 +377,7 @@ class BaseTrainer:
|
|||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
|
||||
if rank in {-1, 0}:
|
||||
if rank in (-1, 0):
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
|
|
@ -603,7 +608,20 @@ class BaseTrainer:
|
|||
|
||||
|
||||
def check_amp(model):
|
||||
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
||||
"""
|
||||
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
|
||||
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
|
||||
results, so AMP will be disabled during training.
|
||||
|
||||
Args:
|
||||
model (nn.Module): A YOLOv8 model instance.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
|
||||
"""
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type in ('cpu', 'mps'):
|
||||
return False # AMP only used on CUDA devices
|
||||
|
|
@ -613,18 +631,21 @@ def check_amp(model):
|
|||
a = m(im, device=device, verbose=False)[0].boxes.boxes # FP32 inference
|
||||
with torch.cuda.amp.autocast(True):
|
||||
b = m(im, device=device, verbose=False)[0].boxes.boxes # AMP inference
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), rtol=0.1) # close to 10% absolute tolerance
|
||||
del m
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
|
||||
|
||||
f = ROOT / 'assets/bus.jpg' # image to check
|
||||
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
|
||||
prefix = colorstr('AMP: ')
|
||||
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
|
||||
assert amp_allclose(YOLO('yolov8n.pt'), im)
|
||||
LOGGER.info(f'{prefix}checks passed ✅')
|
||||
return True
|
||||
except ConnectionError:
|
||||
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. Setting 'amp=True'.")
|
||||
except AssertionError:
|
||||
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
|
||||
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
|
||||
return False
|
||||
return True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue