ultralytics 8.2.93 new SafeClass and SafeUnpickler classes (#16269)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e309b6efab
commit
c2068df9d9
2 changed files with 52 additions and 5 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.92"
|
||||
__version__ = "8.2.93"
|
||||
|
||||
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import pickle
|
||||
import types
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -750,7 +752,35 @@ def temporary_modules(modules=None, attributes=None):
|
|||
del sys.modules[old]
|
||||
|
||||
|
||||
def torch_safe_load(weight):
|
||||
class SafeClass:
|
||||
"""A placeholder class to replace unknown classes during unpickling."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize SafeClass instance, ignoring all arguments."""
|
||||
pass
|
||||
|
||||
|
||||
class SafeUnpickler(pickle.Unpickler):
|
||||
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
||||
|
||||
def find_class(self, module, name):
|
||||
"""Attempt to find a class, returning SafeClass if not among safe modules."""
|
||||
safe_modules = (
|
||||
"torch",
|
||||
"collections",
|
||||
"collections.abc",
|
||||
"builtins",
|
||||
"math",
|
||||
"numpy",
|
||||
# Add other modules considered safe
|
||||
)
|
||||
if module in safe_modules:
|
||||
return super().find_class(module, name)
|
||||
else:
|
||||
return SafeClass
|
||||
|
||||
|
||||
def torch_safe_load(weight, safe_only=False):
|
||||
"""
|
||||
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
|
||||
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
|
||||
|
|
@ -758,9 +788,18 @@ def torch_safe_load(weight):
|
|||
|
||||
Args:
|
||||
weight (str): The file path of the PyTorch model.
|
||||
safe_only (bool): If True, replace unknown classes with SafeClass during loading.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.nn.tasks import torch_safe_load
|
||||
|
||||
ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
|
||||
```
|
||||
|
||||
Returns:
|
||||
(dict): The loaded PyTorch model.
|
||||
ckpt (dict): The loaded model checkpoint.
|
||||
file (str): The loaded filename
|
||||
"""
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
|
|
@ -779,6 +818,14 @@ def torch_safe_load(weight):
|
|||
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
|
||||
},
|
||||
):
|
||||
if safe_only:
|
||||
# Load via custom pickle module
|
||||
safe_pickle = types.ModuleType("safe_pickle")
|
||||
safe_pickle.Unpickler = SafeUnpickler
|
||||
safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
|
||||
with open(file, "rb") as f:
|
||||
ckpt = torch.load(f, pickle_module=safe_pickle)
|
||||
else:
|
||||
ckpt = torch.load(file, map_location="cpu")
|
||||
|
||||
except ModuleNotFoundError as e: # e.name is missing module name
|
||||
|
|
@ -809,7 +856,7 @@ def torch_safe_load(weight):
|
|||
)
|
||||
ckpt = {"model": ckpt.model}
|
||||
|
||||
return ckpt, file # load
|
||||
return ckpt, file
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue