Skip to content

Commit cc4b0ec

Browse files
Red4RuZeroIntensitypicnixz
authored andcommitted
gh-104745: Limit starting a patcher more than once without stopping it (#126649)
Previously, this would cause an `AttributeError` if the patch stopped more than once after this, and would also disrupt the original patched object. --------- Co-authored-by: Peter Bierma <[email protected]> Co-authored-by: Bénédikt Tran <[email protected]> Backports: 1e40c5ba47780ddd91868abb3aa064f5ba3015e4 Signed-off-by: Chris Withers <[email protected]>
1 parent 7a018f1 commit cc4b0ec

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Limit starting a patcher (from :func:`unittest.mock.patch` or
2+
:func:`unittest.mock.patch.object`) more than
3+
once without stopping it

mock/mock.py

+9
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,7 @@ def __init__(
14011401
self.autospec = autospec
14021402
self.kwargs = kwargs
14031403
self.additional_patchers = []
1404+
self.is_started = False
14041405

14051406

14061407
def copy(self):
@@ -1513,6 +1514,9 @@ def get_original(self):
15131514

15141515
def __enter__(self):
15151516
"""Perform the patch."""
1517+
if self.is_started:
1518+
raise RuntimeError("Patch is already started")
1519+
15161520
new, spec, spec_set = self.new, self.spec, self.spec_set
15171521
autospec, kwargs = self.autospec, self.kwargs
15181522
new_callable = self.new_callable
@@ -1644,6 +1648,7 @@ def __enter__(self):
16441648
self.temp_original = original
16451649
self.is_local = local
16461650
self._exit_stack = contextlib.ExitStack()
1651+
self.is_started = True
16471652
try:
16481653
setattr(self.target, self.attribute, new_attr)
16491654
if self.attribute_name is not None:
@@ -1663,6 +1668,9 @@ def __enter__(self):
16631668

16641669
def __exit__(self, *exc_info):
16651670
"""Undo the patch."""
1671+
if not self.is_started:
1672+
return
1673+
16661674
if self.is_local and self.temp_original is not DEFAULT:
16671675
setattr(self.target, self.attribute, self.temp_original)
16681676
else:
@@ -1679,6 +1687,7 @@ def __exit__(self, *exc_info):
16791687
del self.target
16801688
exit_stack = self._exit_stack
16811689
del self._exit_stack
1690+
self.is_started = False
16821691
return exit_stack.__exit__(*exc_info)
16831692

16841693

mock/tests/testpatch.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,54 @@ def test_stop_idempotent(self):
743743
self.assertIsNone(patcher.stop())
744744

745745

746+
def test_exit_idempotent(self):
747+
patcher = patch(foo_name, 'bar', 3)
748+
with patcher:
749+
patcher.stop()
750+
751+
752+
def test_second_start_failure(self):
753+
patcher = patch(foo_name, 'bar', 3)
754+
patcher.start()
755+
try:
756+
self.assertRaises(RuntimeError, patcher.start)
757+
finally:
758+
patcher.stop()
759+
760+
761+
def test_second_enter_failure(self):
762+
patcher = patch(foo_name, 'bar', 3)
763+
with patcher:
764+
self.assertRaises(RuntimeError, patcher.start)
765+
766+
767+
def test_second_start_after_stop(self):
768+
patcher = patch(foo_name, 'bar', 3)
769+
patcher.start()
770+
patcher.stop()
771+
patcher.start()
772+
patcher.stop()
773+
774+
775+
def test_property_setters(self):
776+
mock_object = Mock()
777+
mock_bar = mock_object.bar
778+
patcher = patch.object(mock_object, 'bar', 'x')
779+
with patcher:
780+
self.assertEqual(patcher.is_local, False)
781+
self.assertIs(patcher.target, mock_object)
782+
self.assertEqual(patcher.temp_original, mock_bar)
783+
patcher.is_local = True
784+
patcher.target = mock_bar
785+
patcher.temp_original = mock_object
786+
self.assertEqual(patcher.is_local, True)
787+
self.assertIs(patcher.target, mock_bar)
788+
self.assertEqual(patcher.temp_original, mock_object)
789+
# if changes are left intact, they may lead to disruption as shown below (it might be what someone needs though)
790+
self.assertEqual(mock_bar.bar, mock_object)
791+
self.assertEqual(mock_object.bar, 'x')
792+
793+
746794
def test_patchobject_start_stop(self):
747795
original = something
748796
patcher = patch.object(PTModule, 'something', 'foo')
@@ -1096,7 +1144,7 @@ def test_new_callable_patch(self):
10961144

10971145
self.assertIsNot(m1, m2)
10981146
for mock in m1, m2:
1099-
self.assertNotCallable(m1)
1147+
self.assertNotCallable(mock)
11001148

11011149

11021150
def test_new_callable_patch_object(self):
@@ -1109,7 +1157,7 @@ def test_new_callable_patch_object(self):
11091157

11101158
self.assertIsNot(m1, m2)
11111159
for mock in m1, m2:
1112-
self.assertNotCallable(m1)
1160+
self.assertNotCallable(mock)
11131161

11141162

11151163
def test_new_callable_keyword_arguments(self):

0 commit comments

Comments
 (0)