Skip to content

Commit 0b773d0

Browse files
committed
ENH: Roll own util.patch() for monkey-patching objects
1 parent 236889b commit 0b773d0

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

backtesting/_util.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from contextlib import contextmanager
45
from numbers import Number
56
from typing import Dict, List, Optional, Sequence, Union, cast
67

@@ -15,6 +16,20 @@ def try_(lazy_func, default=None, exception=Exception):
1516
return default
1617

1718

19+
@contextmanager
20+
def patch(obj, attr, newvalue):
21+
had_attr = hasattr(obj, attr)
22+
orig_value = getattr(obj, attr, None)
23+
setattr(obj, attr, newvalue)
24+
try:
25+
yield
26+
finally:
27+
if had_attr:
28+
setattr(obj, attr, orig_value)
29+
else:
30+
delattr(obj, attr)
31+
32+
1833
def _as_str(value) -> str:
1934
if isinstance(value, (Number, str)):
2035
return str(value)

backtesting/test/_test.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import multiprocessing
23
import os
34
import sys
45
import time
@@ -10,15 +11,14 @@
1011
from runpy import run_path
1112
from tempfile import NamedTemporaryFile, gettempdir
1213
from unittest import TestCase
13-
from unittest.mock import patch
1414

1515
import numpy as np
1616
import pandas as pd
1717
from pandas.testing import assert_frame_equal
1818

1919
from backtesting import Backtest, Strategy
2020
from backtesting._stats import compute_drawdown_duration_peaks
21-
from backtesting._util import _Array, _as_str, _Indicator, try_
21+
from backtesting._util import _Array, _as_str, _Indicator, patch, try_
2222
from backtesting.lib import (
2323
FractionalBacktest, OHLCV_AGG,
2424
SignalStrategy,
@@ -626,7 +626,7 @@ def test_multiprocessing_windows_spawn(self):
626626
kw = {'fast': [10]}
627627

628628
stats1 = Backtest(df, SmaCross).optimize(**kw)
629-
with patch('multiprocessing.get_start_method', lambda **_: 'spawn'):
629+
with patch(multiprocessing, 'get_start_method', lambda **_: 'spawn'):
630630
with self.assertWarns(UserWarning) as cm:
631631
stats2 = Backtest(df, SmaCross).optimize(**kw)
632632

@@ -776,7 +776,7 @@ def init(self):
776776
bt.run()
777777
import backtesting._plotting
778778
with _tempfile() as f, \
779-
patch.object(backtesting._plotting, '_MAX_CANDLES', 10), \
779+
patch(backtesting._plotting, '_MAX_CANDLES', 10), \
780780
self.assertWarns(UserWarning):
781781
bt.plot(filename=f, resample=True)
782782
# Give browser time to open before tempfile is removed
@@ -976,6 +976,15 @@ def __call__(self):
976976
for s in ('Open', 'High', 'Low', 'Close', 'Volume'):
977977
self.assertEqual(_as_str(_Array([1], name=s)), s[0])
978978

979+
def test_patch(self):
980+
class Object:
981+
pass
982+
o = Object()
983+
o.attr = False
984+
with patch(o, 'attr', True):
985+
self.assertTrue(o.attr)
986+
self.assertFalse(o.attr)
987+
979988
def test_pandas_accessors(self):
980989
class S(Strategy):
981990
def init(self):

0 commit comments

Comments
 (0)