diff --git a/backtesting/__init__.py b/backtesting/__init__.py
index c670d7e4..aa8615af 100644
--- a/backtesting/__init__.py
+++ b/backtesting/__init__.py
@@ -68,3 +68,23 @@
 from . import lib  # noqa: F401
 from ._plotting import set_bokeh_output  # noqa: F401
 from .backtesting import Backtest, Strategy  # noqa: F401
+
+
+# Add overridable backtesting.Pool used for parallel optimization
+def Pool(processes=None, initializer=None, initargs=()):
+    import multiprocessing as mp
+    if mp.get_start_method() == 'spawn':
+        import warnings
+        warnings.warn(
+            "If you want to use multi-process optimization with "
+            "`multiprocessing.get_start_method() == 'spawn'` (e.g. on Windows), "
+            "set `backtesting.Pool = multiprocessing.Pool` (or that of the desired context) "
+            "and hide the `bt.optimize()` call behind an `if __name__ == '__main__'` guard. "
+            "Currently using thread-based parallelism, "
+            "which might be slightly slower for non-numpy / non-GIL-releasing code. "
+            "See https://github.com/kernc/backtesting.py/issues/1256",
+            category=RuntimeWarning, stacklevel=3)
+        from multiprocessing.dummy import Pool
+        return Pool(processes, initializer, initargs)
+    else:
+        return mp.Pool(processes, initializer, initargs)
diff --git a/backtesting/backtesting.py b/backtesting/backtesting.py
index fa53a91a..a85f5c9d 100644
--- a/backtesting/backtesting.py
+++ b/backtesting/backtesting.py
@@ -8,7 +8,6 @@
 
 from __future__ import annotations
 
-import multiprocessing as mp
 import sys
 import warnings
 from abc import ABCMeta, abstractmethod
@@ -1309,7 +1308,8 @@ def run(self, **kwargs) -> pd.Series:
         # np.nan >= 3 is not invalid; it's False.
         with np.errstate(invalid='ignore'):
 
-            for i in _tqdm(range(start, len(self._data)), desc=self.run.__qualname__):
+            for i in _tqdm(range(start, len(self._data)), desc=self.run.__qualname__,
+                           unit='bar', mininterval=2, miniters=100):
                 # Prepare data and indicators for `next` call
                 data._set_length(i + 1)
                 for attr, indicator in indicator_attrs:
@@ -1501,9 +1501,9 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
                                     [p.values() for p in param_combos],
                                     names=next(iter(param_combos)).keys()))
 
-            with mp.Pool() as pool, \
+            from . import Pool
+            with Pool() as pool, \
                     SharedMemoryManager() as smm:
-
                 with patch(self, '_data', None):
                     bt = copy(self)  # bt._data will be reassigned in _mp_task worker
                 results = _tqdm(
@@ -1565,7 +1565,8 @@ def memoized_run(tup):
                 stats = self.run(**dict(tup))
                 return -maximize(stats)
 
-            progress = iter(_tqdm(repeat(None), total=max_tries, leave=False, desc='Backtest.optimize'))
+            progress = iter(_tqdm(repeat(None), total=max_tries, leave=False,
+                                  desc=self.optimize.__qualname__, mininterval=2))
             _names = tuple(kwargs.keys())
 
             def objective_function(x):
diff --git a/backtesting/lib.py b/backtesting/lib.py
index ec6b8359..c584f567 100644
--- a/backtesting/lib.py
+++ b/backtesting/lib.py
@@ -13,7 +13,6 @@
 
 from __future__ import annotations
 
-import multiprocessing as mp
 import warnings
 from collections import OrderedDict
 from inspect import currentframe
@@ -569,7 +568,8 @@ def run(self, **kwargs):
         Wraps `backtesting.backtesting.Backtest.run`. Returns `pd.DataFrame` with
         currency indexes in columns.
         """
-        with mp.Pool() as pool, \
+        from . import Pool
+        with Pool() as pool, \
                 SharedMemoryManager() as smm:
             shm = [smm.df2shm(df) for df in self._dfs]
             results = _tqdm(
@@ -577,7 +577,8 @@ def run(self, **kwargs):
                           ((df_batch, self._strategy, self._bt_kwargs, kwargs)
                            for df_batch in _batch(shm))),
                 total=len(shm),
-                desc=self.__class__.__name__,
+                desc=self.run.__qualname__,
+                mininterval=2
             )
             df = pd.DataFrame(list(chain(*results))).transpose()
         return df
@@ -605,7 +606,7 @@ def optimize(self, **kwargs) -> pd.DataFrame:
         """
         heatmaps = []
         # Simple loop since bt.optimize already does its own multiprocessing
-        for df in _tqdm(self._dfs, desc=self.__class__.__name__):
+        for df in _tqdm(self._dfs, desc=self.__class__.__name__, mininterval=2):
             bt = Backtest(df, self._strategy, **self._bt_kwargs)
             _best_stats, heatmap = bt.optimize(  # type: ignore
                 return_heatmap=True, return_optimization=False, **kwargs)
diff --git a/backtesting/test/_test.py b/backtesting/test/_test.py
index 47dc393b..3b4c5c5c 100644
--- a/backtesting/test/_test.py
+++ b/backtesting/test/_test.py
@@ -1,4 +1,5 @@
 import inspect
+import multiprocessing as mp
 import os
 import sys
 import time
@@ -629,7 +630,8 @@ def test_optimize_speed(self):
         bt.optimize(fast=range(2, 20, 2), slow=range(10, 40, 2))
         end = time.process_time()
         print(end - start)
-        self.assertLess(end - start, .3)
+        handicap = 5 if 'win' in sys.platform else .1
+        self.assertLess(end - start, .3 + handicap)
 
 
 class TestPlot(TestCase):
@@ -934,13 +936,20 @@ def test_FractionalBacktest(self):
         self.assertEqual(stats['# Trades'], 41)
 
     def test_MultiBacktest(self):
-        btm = MultiBacktest([GOOG, EURUSD, BTCUSD], SmaCross, cash=100_000)
-        res = btm.run(fast=2)
-        self.assertIsInstance(res, pd.DataFrame)
-        self.assertEqual(res.columns.tolist(), [0, 1, 2])
-        heatmap = btm.optimize(fast=[2, 4], slow=[10, 20])
-        self.assertIsInstance(heatmap, pd.DataFrame)
-        self.assertEqual(heatmap.columns.tolist(), [0, 1, 2])
+        import backtesting
+        assert callable(getattr(backtesting, 'Pool', None)), backtesting.__dict__
+        for start_method in mp.get_all_start_methods():
+            with self.subTest(start_method=start_method), \
+                    patch(backtesting, 'Pool', mp.get_context(start_method).Pool):
+                start_time = time.monotonic()
+                btm = MultiBacktest([GOOG, EURUSD, BTCUSD], SmaCross, cash=100_000)
+                res = btm.run(fast=2)
+                self.assertIsInstance(res, pd.DataFrame)
+                self.assertEqual(res.columns.tolist(), [0, 1, 2])
+                heatmap = btm.optimize(fast=[2, 4], slow=[10, 20])
+                self.assertIsInstance(heatmap, pd.DataFrame)
+                self.assertEqual(heatmap.columns.tolist(), [0, 1, 2])
+                print(start_method, time.monotonic() - start_time)
         plot_heatmaps(heatmap.mean(axis=1), open_browser=False)
@@ -1001,7 +1010,6 @@ def test_indicators_picklable(self):
 
 class TestDocs(TestCase):
     DOCS_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'doc')
 
-    @unittest.skipIf('win' in sys.platform, "Locks up with `ModuleNotFoundError: No module named ''`")
     @unittest.skipUnless(os.path.isdir(DOCS_DIR), "docs dir doesn't exist")
     def test_examples(self):
         examples = glob(os.path.join(self.DOCS_DIR, 'examples', '*.py'))
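
A minimal usage sketch of the overridable backtesting.Pool hook added above, following the advice in the new RuntimeWarning: assign a process-based pool and keep the bt.optimize() call behind an `if __name__ == '__main__'` guard on 'spawn' platforms such as Windows. The SmaCross strategy and its parameter ranges below are only a documentation-style example, not part of this diff; GOOG and SMA are the library's bundled demo data and indicator from backtesting.test.

import multiprocessing as mp

import backtesting
from backtesting import Backtest, Strategy
from backtesting.lib import crossover
from backtesting.test import GOOG, SMA


class SmaCross(Strategy):
    # Plain SMA-crossover demo strategy, same shape as the docs example
    n1, n2 = 10, 20

    def init(self):
        self.sma1 = self.I(SMA, self.data.Close, self.n1)
        self.sma2 = self.I(SMA, self.data.Close, self.n2)

    def next(self):
        if crossover(self.sma1, self.sma2):
            self.buy()
        elif crossover(self.sma2, self.sma1):
            self.sell()


if __name__ == '__main__':
    # Opt back in to process-based parallelism on 'spawn' platforms;
    # without this override, the new backtesting.Pool falls back to threads.
    backtesting.Pool = mp.get_context('spawn').Pool

    bt = Backtest(GOOG, SmaCross, cash=10_000)
    stats = bt.optimize(n1=range(5, 30, 5), n2=range(10, 70, 5),
                        maximize='Equity Final [$]')
    print(stats)

Because the Pool import is deferred (`from . import Pool` inside Backtest.optimize() and MultiBacktest.run()), monkey-patching backtesting.Pool at runtime, as shown, is all that is needed to switch executors.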