ultralytics 8.0.50 AMP check and YOLOv5u YAMLs (#1263)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Troy <wudashuo@vip.qq.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com>
This commit is contained in:
parent
3861e6c82a
commit
f0d8e4718b
29 changed files with 440 additions and 83 deletions
|
|
@ -25,8 +25,8 @@ from tqdm import tqdm
|
|||
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
|
||||
colorstr, emojis, yaml_save)
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
|
||||
callbacks, colorstr, emojis, yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
|
|
@ -111,8 +111,6 @@ class BaseTrainer:
|
|||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
self.amp = self.device.type != 'cpu'
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
if self.device.type == 'cpu':
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
|
|
@ -126,7 +124,7 @@ class BaseTrainer:
|
|||
if 'yaml_file' in self.data:
|
||||
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
|
||||
raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
|
||||
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
self.ema = None
|
||||
|
|
@ -204,6 +202,8 @@ class BaseTrainer:
|
|||
ckpt = self.setup_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
self.amp = check_amp(self.model)
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
# Check imgsz
|
||||
|
|
@ -597,3 +597,31 @@ class BaseTrainer:
|
|||
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
||||
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
|
||||
return optimizer
|
||||
|
||||
|
||||
def check_amp(model):
|
||||
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type in ('cpu', 'mps'):
|
||||
return False # AMP only used on CUDA devices
|
||||
|
||||
def amp_allclose(m, im):
|
||||
# All close FP32 vs AMP results
|
||||
a = m(im, device=device, verbose=False)[0].boxes.boxes # FP32 inference
|
||||
with torch.cuda.amp.autocast(True):
|
||||
b = m(im, device=device, verbose=False)[0].boxes.boxes # AMP inference
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1) # close to 10% absolute tolerance
|
||||
|
||||
f = ROOT / 'assets/bus.jpg' # image to check
|
||||
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
|
||||
prefix = colorstr('AMP: ')
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
|
||||
assert amp_allclose(YOLO('yolov8n.pt'), im)
|
||||
LOGGER.info(f'{prefix}checks passed ✅')
|
||||
return True
|
||||
except AssertionError:
|
||||
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
|
||||
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
|
||||
return False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue