Patch torch.load(..., weights_only=False) to reduce warnings (#14638)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-07-23 20:35:58 +02:00 committed by GitHub
parent 72466b9648
commit fb20867262
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 3 deletions

View file

@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry()
# Apply monkey patches
from ultralytics.utils.patches import imread, imshow, imwrite, torch_save
from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save
torch.load = torch_load
torch.save = torch_save
if WINDOWS:
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths

View file

@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray):
# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save # copy to avoid recursion errors
_torch_load = torch.load # copy to avoid recursion errors
_torch_save = torch.save
def torch_load(*args, **kwargs):
"""
Load a PyTorch model with updated arguments to avoid warnings.
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
Args:
*args (Any): Variable length argument list to pass to torch.load.
**kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
Returns:
(Any): The loaded PyTorch object.
Note:
For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
if the argument is not provided, to avoid deprecation warnings.
"""
from ultralytics.utils.torch_utils import TORCH_1_13
if TORCH_1_13 and "weights_only" not in kwargs:
kwargs["weights_only"] = False
return _torch_load(*args, **kwargs)
def torch_save(*args, use_dill=True, **kwargs):
@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs):
Args:
*args (tuple): Positional arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (any): Keyword arguments to pass to torch.save.
**kwargs (Any): Keyword arguments to pass to torch.save.
"""
try:
assert use_dill