Fix torch.amp.autocast('cuda') warnings (#14633)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-07-23 21:58:39 +02:00 committed by GitHub
parent 23ce08791f
commit 0d7bf447eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 51 additions and 7 deletions

View file

@ -68,6 +68,37 @@ def smart_inference_mode():
return decorate
def autocast(enabled: bool, device: str = "cuda"):
"""
Get the appropriate autocast context manager based on PyTorch version and AMP setting.
This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
Args:
enabled (bool): Whether to enable automatic mixed precision.
device (str, optional): The device to use for autocast. Defaults to 'cuda'.
Returns:
(torch.amp.autocast): The appropriate autocast context manager.
Note:
- For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
- For older versions, it uses `torch.cuda.autocast`.
Example:
```python
with autocast(amp=True):
# Your mixed precision operations here
pass
```
"""
if TORCH_1_13:
return torch.amp.autocast(device, enabled=enabled)
else:
return torch.cuda.amp.autocast(enabled)
def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo