Fix torch.amp.autocast('cuda') warnings (#14633)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
23ce08791f
commit
0d7bf447eb
7 changed files with 51 additions and 7 deletions
|
|
@ -68,6 +68,37 @@ def smart_inference_mode():
|
|||
return decorate
|
||||
|
||||
|
||||
def autocast(enabled: bool, device: str = "cuda"):
    """
    Get the appropriate autocast context manager based on PyTorch version and AMP setting.

    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.

    Args:
        enabled (bool): Whether to enable automatic mixed precision.
        device (str, optional): The device to use for autocast. Defaults to 'cuda'.

    Returns:
        (torch.amp.autocast): The appropriate autocast context manager.

    Note:
        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
        - For older versions, it uses `torch.cuda.amp.autocast`, which only supports the CUDA device.

    Example:
        ```python
        with autocast(enabled=True):
            # Your mixed precision operations here
            pass
        ```
    """
    if TORCH_1_13:
        # torch.amp.autocast(device_type, ...) is the current API; using it avoids the
        # FutureWarning emitted by torch.cuda.amp.autocast on newer PyTorch releases.
        return torch.amp.autocast(device, enabled=enabled)
    else:
        # Older PyTorch only provides the CUDA-specific context manager (ignores `device`).
        return torch.cuda.amp.autocast(enabled)
|
||||
|
||||
|
||||
def get_cpu_info():
|
||||
"""Return a string with system CPU information, i.e. 'Apple M2'."""
|
||||
import cpuinfo # pip install py-cpuinfo
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue