From ff1a04609ff4d88e7f540d9ba181b8ed47080ad1 Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Mon, 10 Feb 2025 07:17:08 +0800 Subject: [PATCH] Unset CUBLAS_WORKSPACE_CONFIG for non-deterministic training and inference (#19138) Co-authored-by: UltralyticsAssistant --- docs/en/reference/utils/torch_utils.md | 4 ++++ ultralytics/engine/model.py | 1 + ultralytics/engine/trainer.py | 2 ++ ultralytics/utils/__init__.py | 1 - ultralytics/utils/torch_utils.py | 11 +++++++++-- 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md index 8242b70a..5487ba6f 100644 --- a/docs/en/reference/utils/torch_utils.md +++ b/docs/en/reference/utils/torch_utils.md @@ -119,6 +119,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere



+## ::: ultralytics.utils.torch_utils.unset_deterministic + +



+ ## ::: ultralytics.utils.torch_utils.strip_optimizer



diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index e4153837..940b9a6f 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -140,6 +140,7 @@ class Model(torch.nn.Module): return # Load or create new YOLO model + __import__("os").environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # to avoid deterministic warnings if Path(model).suffix in {".yaml", ".yml"}: self._new(model, task=task, verbose=verbose) else: diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 8c3c9d77..dbf0e3cf 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -52,6 +52,7 @@ from ultralytics.utils.torch_utils import ( select_device, strip_optimizer, torch_distributed_zero_first, + unset_deterministic, ) @@ -471,6 +472,7 @@ class BaseTrainer: self.plot_metrics() self.run_callbacks("on_train_end") self._clear_memory() + unset_deterministic() self.run_callbacks("teardown") def auto_batch(self, max_num_obj=0): diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index 3afa07a9..f54348e8 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -128,7 +128,6 @@ torch.set_printoptions(linewidth=320, precision=4, profile="default") np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5 cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training to avoid CUDA warning os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warnings in Colab os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output when computing FLOPs diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 1f87ec79..305dda50 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -488,8 +488,15 @@ def init_seeds(seed=0, deterministic=False): else: LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.") else: - torch.use_deterministic_algorithms(False) - torch.backends.cudnn.deterministic = False + unset_deterministic() + + +def unset_deterministic(): + """Unsets all the configurations applied for deterministic training.""" + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.deterministic = False + os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None) + os.environ.pop("PYTHONHASHSEED", None) class ModelEMA: