Hardcode failing GPUs in AMP checks (#17977)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1f5646634d
commit
5c84099a9d
1 changed files with 14 additions and 1 deletions
|
|
@ -669,8 +669,22 @@ def check_amp(model):
|
||||||
from ultralytics.utils.torch_utils import autocast
|
from ultralytics.utils.torch_utils import autocast
|
||||||
|
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
|
prefix = colorstr("AMP: ")
|
||||||
if device.type in {"cpu", "mps"}:
|
if device.type in {"cpu", "mps"}:
|
||||||
return False # AMP only used on CUDA devices
|
return False # AMP only used on CUDA devices
|
||||||
|
else:
|
||||||
|
# GPUs that have issues with AMP
|
||||||
|
pattern = re.compile(
|
||||||
|
r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
gpu = torch.cuda.get_device_name(device)
|
||||||
|
if bool(pattern.search(gpu)):
|
||||||
|
LOGGER.warning(
|
||||||
|
f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause "
|
||||||
|
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
def amp_allclose(m, im):
|
def amp_allclose(m, im):
|
||||||
"""All close FP32 vs AMP results."""
|
"""All close FP32 vs AMP results."""
|
||||||
|
|
@ -683,7 +697,6 @@ def check_amp(model):
|
||||||
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
|
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
|
||||||
|
|
||||||
im = ASSETS / "bus.jpg" # image to check
|
im = ASSETS / "bus.jpg" # image to check
|
||||||
prefix = colorstr("AMP: ")
|
|
||||||
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
|
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
|
||||||
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
|
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue