Patch torch.load(..., weights_only=False) to reduce warnings (#14638)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
72466b9648
commit
fb20867262
3 changed files with 34 additions and 3 deletions
|
|
@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running()
|
|||
set_sentry()
|
||||
|
||||
# Apply monkey patches
|
||||
from ultralytics.utils.patches import imread, imshow, imwrite, torch_save
|
||||
from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save
|
||||
|
||||
torch.load = torch_load
|
||||
torch.save = torch_save
|
||||
if WINDOWS:
|
||||
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths
|
||||
|
|
|
|||
|
|
@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray):
|
|||
|
||||
|
||||
# PyTorch functions ----------------------------------------------------------------------------------------------------
|
||||
_torch_save = torch.save # copy to avoid recursion errors
|
||||
_torch_load = torch.load # copy to avoid recursion errors
|
||||
_torch_save = torch.save
|
||||
|
||||
|
||||
def torch_load(*args, **kwargs):
|
||||
"""
|
||||
Load a PyTorch model with updated arguments to avoid warnings.
|
||||
|
||||
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
|
||||
|
||||
Args:
|
||||
*args (Any): Variable length argument list to pass to torch.load.
|
||||
**kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
|
||||
|
||||
Returns:
|
||||
(Any): The loaded PyTorch object.
|
||||
|
||||
Note:
|
||||
For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
|
||||
if the argument is not provided, to avoid deprecation warnings.
|
||||
"""
|
||||
from ultralytics.utils.torch_utils import TORCH_1_13
|
||||
|
||||
if TORCH_1_13 and "weights_only" not in kwargs:
|
||||
kwargs["weights_only"] = False
|
||||
|
||||
return _torch_load(*args, **kwargs)
|
||||
|
||||
|
||||
def torch_save(*args, use_dill=True, **kwargs):
|
||||
|
|
@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs):
|
|||
Args:
|
||||
*args (tuple): Positional arguments to pass to torch.save.
|
||||
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
|
||||
**kwargs (any): Keyword arguments to pass to torch.save.
|
||||
**kwargs (Any): Keyword arguments to pass to torch.save.
|
||||
"""
|
||||
try:
|
||||
assert use_dill
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue