ultralytics 8.2.93 new SafeClass and SafeUnpickler classes (#16269)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e309b6efab
commit
c2068df9d9
2 changed files with 52 additions and 5 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.92"
|
__version__ = "8.2.93"
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import pickle
|
||||||
|
import types
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -750,7 +752,35 @@ def temporary_modules(modules=None, attributes=None):
|
||||||
del sys.modules[old]
|
del sys.modules[old]
|
||||||
|
|
||||||
|
|
||||||
def torch_safe_load(weight):
|
class SafeClass:
|
||||||
|
"""A placeholder class to replace unknown classes during unpickling."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""Initialize SafeClass instance, ignoring all arguments."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SafeUnpickler(pickle.Unpickler):
|
||||||
|
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
||||||
|
|
||||||
|
def find_class(self, module, name):
|
||||||
|
"""Attempt to find a class, returning SafeClass if not among safe modules."""
|
||||||
|
safe_modules = (
|
||||||
|
"torch",
|
||||||
|
"collections",
|
||||||
|
"collections.abc",
|
||||||
|
"builtins",
|
||||||
|
"math",
|
||||||
|
"numpy",
|
||||||
|
# Add other modules considered safe
|
||||||
|
)
|
||||||
|
if module in safe_modules:
|
||||||
|
return super().find_class(module, name)
|
||||||
|
else:
|
||||||
|
return SafeClass
|
||||||
|
|
||||||
|
|
||||||
|
def torch_safe_load(weight, safe_only=False):
|
||||||
"""
|
"""
|
||||||
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
|
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
|
||||||
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
|
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
|
||||||
|
|
@ -758,9 +788,18 @@ def torch_safe_load(weight):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
weight (str): The file path of the PyTorch model.
|
weight (str): The file path of the PyTorch model.
|
||||||
|
safe_only (bool): If True, replace unknown classes with SafeClass during loading.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.nn.tasks import torch_safe_load
|
||||||
|
|
||||||
|
ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
|
||||||
|
```
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(dict): The loaded PyTorch model.
|
ckpt (dict): The loaded model checkpoint.
|
||||||
|
file (str): The loaded filename
|
||||||
"""
|
"""
|
||||||
from ultralytics.utils.downloads import attempt_download_asset
|
from ultralytics.utils.downloads import attempt_download_asset
|
||||||
|
|
||||||
|
|
@ -779,6 +818,14 @@ def torch_safe_load(weight):
|
||||||
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
|
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
|
||||||
},
|
},
|
||||||
):
|
):
|
||||||
|
if safe_only:
|
||||||
|
# Load via custom pickle module
|
||||||
|
safe_pickle = types.ModuleType("safe_pickle")
|
||||||
|
safe_pickle.Unpickler = SafeUnpickler
|
||||||
|
safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
|
||||||
|
with open(file, "rb") as f:
|
||||||
|
ckpt = torch.load(f, pickle_module=safe_pickle)
|
||||||
|
else:
|
||||||
ckpt = torch.load(file, map_location="cpu")
|
ckpt = torch.load(file, map_location="cpu")
|
||||||
|
|
||||||
except ModuleNotFoundError as e: # e.name is missing module name
|
except ModuleNotFoundError as e: # e.name is missing module name
|
||||||
|
|
@ -809,7 +856,7 @@ def torch_safe_load(weight):
|
||||||
)
|
)
|
||||||
ckpt = {"model": ckpt.model}
|
ckpt = {"model": ckpt.model}
|
||||||
|
|
||||||
return ckpt, file # load
|
return ckpt, file
|
||||||
|
|
||||||
|
|
||||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue