ultralytics 8.3.1 update AMP checks for YOLO11n (#16560)

This commit is contained in:
Glenn Jocher 2024-09-30 16:42:40 +02:00 committed by GitHub
parent 3e896eae13
commit 7a6c76d16c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 9 deletions

View file

@ -629,24 +629,24 @@ def collect_system_info():
def check_amp(model):
"""
Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks fail, it means
Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. If the checks fail, it means
there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
during training.
Args:
model (nn.Module): A YOLOv8 model instance.
model (nn.Module): A YOLO11 model instance.
Example:
```python
from ultralytics import YOLO
from ultralytics.utils.checks import check_amp
model = YOLO("yolov8n.pt").model.cuda()
model = YOLO("yolo11n.pt").model.cuda()
check_amp(model)
```
Returns:
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
(bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
"""
from ultralytics.utils.torch_utils import autocast
@ -665,19 +665,19 @@ def check_amp(model):
im = ASSETS / "bus.jpg" # image to check
prefix = colorstr("AMP: ")
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLO11n...")
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
assert amp_allclose(YOLO("yolov8n.pt"), im)
assert amp_allclose(YOLO("yolo11n.pt"), im)
LOGGER.info(f"{prefix}checks passed ✅")
except ConnectionError:
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLO11n. {warning_msg}")
except (AttributeError, ModuleNotFoundError):
LOGGER.warning(
f"{prefix}checks skipped ⚠️. "
f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
f"Unable to load YOLO11n due to possible Ultralytics package modifications. {warning_msg}"
)
except AssertionError:
LOGGER.warning(