Skip to content

Commit b7b5657

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Allow user to manually pass module.name associated with global in {add}_safe_global (pytorch#142153)
Fixes pytorch#142144 A global x is saved in checkpoint as `GLOBAL x.__module__ x.__name__`. So , after allowlisting a GLOBAL it is expected to match any GLOBAL instruction of the form `GLOBAL x.__module__ x.__name__` but there are edge cases when for the same API from the same module, what `__module__` gives changes between versions which prevents users from allowlisting the global. In this case, in numpy < 2.1 ``` torch.save("bla", np_array) # checkpoint has GLOBAL "np.core.multiarray" "_reconstruct" ``` In np version 2.1 ``` with safe_globals([np.core.multiarray._reconstruct]): torch.load("bla") ``` np.core.multiarray._reconstruct.__module__ gives "np._core.multiarray" (note the extra _ before core) and see what was done [here](https://github.com/numpy/numpy/blob/main/numpy/core/multiarray.py) Since the dictionary to access safe globals is keyed on "{foo.__module__}.{foo.__name__}", __module__, __name__ will no longer match that in the checkpoint so "np.core.multiarray._reconstruct" can no longer be properly allowlisted (instead np._core.multiarray._reconstruct is a key in the dict). We allow `add_safe_globals/safe_globals` to optionally take tuples of (global, str of module.name) to workaround such (odd/edge case) situations. Pull Request resolved: pytorch#142153 Approved by: https://github.com/albanD
1 parent 1a7da6e commit b7b5657

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

test/test_cpp_extensions_open_device_registration.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,6 @@ def test_open_device_tensorlist_type_fallback(self):
539539
np.__version__ < "1.25",
540540
"versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy",
541541
)
542-
@unittest.skipIf(
543-
np.__version__ >= "2.1",
544-
"weights_only failure on numpy >= 2.1",
545-
)
546542
def test_open_device_numpy_serialization(self):
547543
"""
548544
This tests the legacy _rebuild_device_tensor_from_numpy serialization path
@@ -601,7 +597,9 @@ def test_open_device_numpy_serialization(self):
601597

602598
with safe_globals(
603599
[
604-
np.core.multiarray._reconstruct,
600+
(np.core.multiarray._reconstruct, "numpy.core.multiarray._reconstruct")
601+
if np.__version__ >= "2.1"
602+
else np.core.multiarray._reconstruct,
605603
np.ndarray,
606604
np.dtype,
607605
_codecs.encode,

torch/_weights_only_unpickler.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
)
6969
from struct import unpack
7070
from sys import maxsize
71-
from typing import Any, Callable, Dict, List, Set, Tuple
71+
from typing import Any, Callable, Dict, List, Set, Tuple, Union
7272

7373
import torch
7474
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
@@ -83,15 +83,15 @@
8383
"nt",
8484
]
8585

86-
_marked_safe_globals_set: Set[Any] = set()
86+
_marked_safe_globals_set: Set[Union[Callable, Tuple[Callable, str]]] = set()
8787

8888

89-
def _add_safe_globals(safe_globals: List[Any]):
89+
def _add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
9090
global _marked_safe_globals_set
9191
_marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
9292

9393

94-
def _get_safe_globals() -> List[Any]:
94+
def _get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
9595
global _marked_safe_globals_set
9696
return list(_marked_safe_globals_set)
9797

@@ -101,13 +101,15 @@ def _clear_safe_globals():
101101
_marked_safe_globals_set = set()
102102

103103

104-
def _remove_safe_globals(globals_to_remove: List[Any]):
104+
def _remove_safe_globals(
105+
globals_to_remove: List[Union[Callable, Tuple[Callable, str]]],
106+
):
105107
global _marked_safe_globals_set
106108
_marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
107109

108110

109111
class _safe_globals:
110-
def __init__(self, safe_globals: List[Any]):
112+
def __init__(self, safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
111113
self.safe_globals = safe_globals
112114

113115
def __enter__(self):
@@ -127,8 +129,20 @@ def __exit__(self, type, value, tb):
127129
def _get_user_allowed_globals():
128130
rc: Dict[str, Any] = {}
129131
for f in _marked_safe_globals_set:
130-
module, name = f.__module__, f.__name__
131-
rc[f"{module}.{name}"] = f
132+
if isinstance(f, tuple):
133+
if len(f) != 2:
134+
raise ValueError(
135+
f"Expected tuple of length 2 (global, str of callable full path), but got tuple of length: {len(f)}"
136+
)
137+
if type(f[1]) is not str:
138+
raise TypeError(
139+
f"Expected second item in tuple to be str of callable full path, but got: {type(f[1])}"
140+
)
141+
f, name = f
142+
rc[name] = f
143+
else:
144+
module, name = f.__module__, f.__name__
145+
rc[f"{module}.{name}"] = f
132146
return rc
133147

134148

torch/serialization.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,29 @@ def clear_safe_globals() -> None:
261261
_weights_only_unpickler._clear_safe_globals()
262262

263263

264-
def get_safe_globals() -> List[Any]:
264+
def get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
265265
"""
266266
Returns the list of user-added globals that are safe for ``weights_only`` load.
267267
"""
268268
return _weights_only_unpickler._get_safe_globals()
269269

270270

271-
def add_safe_globals(safe_globals: List[Any]) -> None:
271+
def add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]) -> None:
272272
"""
273273
Marks the given globals as safe for ``weights_only`` load. For example, functions
274274
added to this list can be called during unpickling, classes could be instantiated
275275
and have state set.
276276
277+
Each item in the list can either be a function/class or a tuple of the form
278+
(function/class, string) where string is the full path of the function/class.
279+
280+
Within the serialized format, each function is identified with its full
281+
path as ``{__module__}.{__name__}``. When calling this API, you can provide this
282+
full path that should match the one in the checkpoint otherwise the default
283+
``{fn.__module__}.{fn.__name__}`` will be used.
284+
277285
Args:
278-
safe_globals (List[Any]): list of globals to mark as safe
286+
safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe
279287
280288
Example:
281289
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")

0 commit comments

Comments
 (0)