diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index 3a8201a5..fe858eb0 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -669,8 +669,22 @@ def check_amp(model):
     from ultralytics.utils.torch_utils import autocast
 
     device = next(model.parameters()).device  # get model device
+    prefix = colorstr("AMP: ")
     if device.type in {"cpu", "mps"}:
         return False  # AMP only used on CUDA devices
+    else:
+        # GPUs that have issues with AMP
+        pattern = re.compile(
+            r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
+        )
+
+        gpu = torch.cuda.get_device_name(device)
+        if bool(pattern.search(gpu)):
+            LOGGER.warning(
+                f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause "
+                f"NaN losses or zero-mAP results, so AMP will be disabled during training."
+            )
+            return False
 
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
@@ -683,7 +697,6 @@ def check_amp(model):
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
 
     im = ASSETS / "bus.jpg"  # image to check
-    prefix = colorstr("AMP: ")
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
     warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
     try:
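A minimal sketch (not part of the patch) of how the GPU-name pattern introduced above behaves. The device-name strings below are hypothetical examples for illustration only; in check_amp the name comes from torch.cuda.get_device_name(device).

import re

# Pattern copied from the diff above: matches GPU families reported to
# produce NaN losses or zero-mAP results under AMP.
pattern = re.compile(
    r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
)

# Hypothetical device names; real values come from torch.cuda.get_device_name(device).
examples = [
    "NVIDIA GeForce GTX 1660 Ti",  # expected: AMP disabled
    "Quadro T1000",                # expected: AMP disabled
    "NVIDIA GeForce RTX 3090",     # expected: AMP allowed
]

for name in examples:
    amp_ok = not bool(pattern.search(name))
    print(f"{name!r}: AMP {'enabled' if amp_ok else 'disabled'}")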