diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md
index 6a48fec7..dd4c364d 100644
--- a/docs/en/reference/utils/torch_utils.md
+++ b/docs/en/reference/utils/torch_utils.md
@@ -27,6 +27,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
+## ::: ultralytics.utils.torch_utils.autocast
+
+<br><br>
+
## ::: ultralytics.utils.torch_utils.get_cpu_info
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 3fb3e0b8..4415ba94 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -41,8 +41,10 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (
+ TORCH_1_13,
EarlyStopping,
ModelEMA,
+ autocast,
convert_optimizer_state_dict_to_fp16,
init_seeds,
one_cycle,
@@ -264,7 +266,11 @@ class BaseTrainer:
if RANK > -1 and world_size > 1: # DDP
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
self.amp = bool(self.amp) # as boolean
- self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+ self.scaler = (
+ torch.amp.GradScaler("cuda", enabled=self.amp)
+ if TORCH_1_13
+ else torch.cuda.amp.GradScaler(enabled=self.amp)
+ )
if world_size > 1:
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
@@ -376,7 +382,7 @@ class BaseTrainer:
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward
- with torch.cuda.amp.autocast(self.amp):
+ with autocast(self.amp):
batch = self.preprocess_batch(batch)
self.loss, self.loss_items = self.model(batch)
if RANK != -1:
diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py
index 4f66feef..64d10e36 100644
--- a/ultralytics/models/utils/ops.py
+++ b/ultralytics/models/utils/ops.py
@@ -133,7 +133,7 @@ class HungarianMatcher(nn.Module):
# sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
# tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
#
- # with torch.cuda.amp.autocast(False):
+ # with torch.amp.autocast("cuda", enabled=False):
# # binary cross entropy cost
# pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
# neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py
index 2f695df8..784210c5 100644
--- a/ultralytics/utils/autobatch.py
+++ b/ultralytics/utils/autobatch.py
@@ -7,7 +7,7 @@ import numpy as np
import torch
from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
-from ultralytics.utils.torch_utils import profile
+from ultralytics.utils.torch_utils import autocast, profile
def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
@@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
(int): Optimal batch size computed using the autobatch() function.
"""
- with torch.cuda.amp.autocast(amp):
+ with autocast(enabled=amp):
return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index dfd79228..d94e157f 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -641,6 +641,8 @@ def check_amp(model):
Returns:
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
"""
+ from ultralytics.utils.torch_utils import autocast
+
device = next(model.parameters()).device # get model device
if device.type in {"cpu", "mps"}:
return False # AMP only used on CUDA devices
@@ -648,7 +650,7 @@ def check_amp(model):
def amp_allclose(m, im):
"""All close FP32 vs AMP results."""
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
- with torch.cuda.amp.autocast(True):
+ with autocast(enabled=True):
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 3c3d3b71..15bf92f9 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import autocast
from .metrics import bbox_iou, probiou
from .tal import bbox2dist
@@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module):
def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
- with torch.cuda.amp.autocast(enabled=False):
+ with autocast(enabled=False):
loss = (
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
.mean(1)
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 21973d7e..fcecd148 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -68,6 +68,37 @@ def smart_inference_mode():
return decorate
+def autocast(enabled: bool, device: str = "cuda"):
+ """
+ Get the appropriate autocast context manager based on PyTorch version and AMP setting.
+
+ This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
+ older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
+
+ Args:
+ enabled (bool): Whether to enable automatic mixed precision.
+ device (str, optional): The device to use for autocast. Defaults to 'cuda'.
+
+ Returns:
+ (torch.amp.autocast): The appropriate autocast context manager.
+
+ Note:
+ - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
+    - For older versions, it uses `torch.cuda.amp.autocast`.
+
+ Example:
+ ```python
+    with autocast(enabled=True):
+        # Your mixed precision operations here
+        pass
+ ```
+ """
+ if TORCH_1_13:
+ return torch.amp.autocast(device, enabled=enabled)
+ else:
+ return torch.cuda.amp.autocast(enabled)
+
+
def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo
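
Usage note (not part of the patch): a minimal sketch of how the new `autocast()` helper and the version-gated `GradScaler` from this change are meant to be combined in a training step. The model, optimizer, and tensor shapes below are placeholders for illustration and assume a CUDA device is available.

```python
import torch
import torch.nn as nn

from ultralytics.utils.torch_utils import TORCH_1_13, autocast

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
amp = True

# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler on newer PyTorch,
# mirroring the gating used in ultralytics/engine/trainer.py above
scaler = torch.amp.GradScaler("cuda", enabled=amp) if TORCH_1_13 else torch.cuda.amp.GradScaler(enabled=amp)

x, y = torch.randn(4, 10).cuda(), torch.randn(4, 2).cuda()
with autocast(enabled=amp):  # selects torch.amp.autocast or torch.cuda.amp.autocast by PyTorch version
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)  # unscale gradients and step the optimizer
scaler.update()  # adjust the scale factor for the next iteration
```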