Skip to content

Commit e947649

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
[BE] Change _marked_safe_globals_list to set (pytorch#139303)
Prevent same global from being added multiple times Pull Request resolved: pytorch#139303 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#138936, pytorch#139221, pytorch#139433, pytorch#139541, pytorch#137602
1 parent 1565eba commit e947649

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

torch/_weights_only_unpickler.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,27 @@
8383
"nt",
8484
]
8585

86-
_marked_safe_globals_list: List[Any] = []
86+
_marked_safe_globals_set: Set[Any] = set()
8787

8888

8989
def _add_safe_globals(safe_globals: List[Any]):
90-
global _marked_safe_globals_list
91-
_marked_safe_globals_list += safe_globals
90+
global _marked_safe_globals_set
91+
_marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
9292

9393

9494
def _get_safe_globals() -> List[Any]:
95-
global _marked_safe_globals_list
96-
return _marked_safe_globals_list
95+
global _marked_safe_globals_set
96+
return list(_marked_safe_globals_set)
9797

9898

9999
def _clear_safe_globals():
100-
global _marked_safe_globals_list
101-
_marked_safe_globals_list = []
100+
global _marked_safe_globals_set
101+
_marked_safe_globals_set = set()
102102

103103

104104
def _remove_safe_globals(globals_to_remove: List[Any]):
105-
global _marked_safe_globals_list
106-
_marked_safe_globals_list = list(
107-
set(_marked_safe_globals_list) - set(globals_to_remove)
108-
)
105+
global _marked_safe_globals_set
106+
_marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
109107

110108

111109
class _safe_globals:
@@ -128,7 +126,7 @@ def __exit__(self, type, value, tb):
128126
# _get_allowed_globals due to the lru_cache
129127
def _get_user_allowed_globals():
130128
rc: Dict[str, Any] = {}
131-
for f in _marked_safe_globals_list:
129+
for f in _marked_safe_globals_set:
132130
module, name = f.__module__, f.__name__
133131
rc[f"{module}.{name}"] = f
134132
return rc

0 commit comments

Comments
 (0)