Patch torch.load(..., weights_only=False) to reduce warnings (#14638)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
72466b9648
commit
fb20867262
3 changed files with 34 additions and 3 deletions
|
|
@ -23,6 +23,10 @@ keywords: Ultralytics, utils, patches, imread, imwrite, imshow, torch_save, Open
|
||||||
|
|
||||||
<br><br><hr><br>
|
<br><br><hr><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.utils.patches.torch_load
|
||||||
|
|
||||||
|
<br><br><hr><br>
|
||||||
|
|
||||||
## ::: ultralytics.utils.patches.torch_save
|
## ::: ultralytics.utils.patches.torch_save
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
|
||||||
|
|
@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running()
|
||||||
set_sentry()
|
set_sentry()
|
||||||
|
|
||||||
# Apply monkey patches
|
# Apply monkey patches
|
||||||
from ultralytics.utils.patches import imread, imshow, imwrite, torch_save
|
from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save
|
||||||
|
|
||||||
|
torch.load = torch_load
|
||||||
torch.save = torch_save
|
torch.save = torch_save
|
||||||
if WINDOWS:
|
if WINDOWS:
|
||||||
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths
|
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray):
|
||||||
|
|
||||||
|
|
||||||
# PyTorch functions ----------------------------------------------------------------------------------------------------
|
# PyTorch functions ----------------------------------------------------------------------------------------------------
|
||||||
_torch_save = torch.save # copy to avoid recursion errors
|
_torch_load = torch.load # copy to avoid recursion errors
|
||||||
|
_torch_save = torch.save
|
||||||
|
|
||||||
|
|
||||||
|
def torch_load(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Load a PyTorch model with updated arguments to avoid warnings.
|
||||||
|
|
||||||
|
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args (Any): Variable length argument list to pass to torch.load.
|
||||||
|
**kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Any): The loaded PyTorch object.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
|
||||||
|
if the argument is not provided, to avoid deprecation warnings.
|
||||||
|
"""
|
||||||
|
from ultralytics.utils.torch_utils import TORCH_1_13
|
||||||
|
|
||||||
|
if TORCH_1_13 and "weights_only" not in kwargs:
|
||||||
|
kwargs["weights_only"] = False
|
||||||
|
|
||||||
|
return _torch_load(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def torch_save(*args, use_dill=True, **kwargs):
|
def torch_save(*args, use_dill=True, **kwargs):
|
||||||
|
|
@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs):
|
||||||
Args:
|
Args:
|
||||||
*args (tuple): Positional arguments to pass to torch.save.
|
*args (tuple): Positional arguments to pass to torch.save.
|
||||||
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
|
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
|
||||||
**kwargs (any): Keyword arguments to pass to torch.save.
|
**kwargs (Any): Keyword arguments to pass to torch.save.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
assert use_dill
|
assert use_dill
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue