Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mp.Pool hang with multiprocessing start_method "spawn" #1258

Merged
merged 4 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions backtesting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,23 @@
from . import lib # noqa: F401
from ._plotting import set_bokeh_output # noqa: F401
from .backtesting import Backtest, Strategy # noqa: F401


# Add overridable backtesting.Pool used for parallel optimization
def Pool(processes=None, initializer=None, initargs=()):
    """
    Return a pool for parallel optimization, overridable by the user.

    When the multiprocessing start method is 'spawn' (e.g. the default on
    Windows/macOS), a plain `multiprocessing.Pool` created at import/call
    time can hang unless guarded by `if __name__ == '__main__'`, so this
    falls back to the thread-based `multiprocessing.dummy.Pool` and emits
    a `RuntimeWarning` explaining how to opt back into process-based
    parallelism (set `backtesting.Pool = multiprocessing.Pool`).

    Parameters mirror `multiprocessing.Pool(processes, initializer, initargs)`.
    """
    import multiprocessing as mp
    if mp.get_start_method() == 'spawn':
        import warnings
        # stacklevel=3 points the warning at the user's optimize() call site,
        # not at this wrapper or its immediate caller.
        warnings.warn(
            "If you want to use multi-process optimization with "
            "`multiprocessing.get_start_method() == 'spawn'` (e.g. on Windows), "
            "set `backtesting.Pool = multiprocessing.Pool` (or of the desired context) "
            "and hide `bt.optimize()` call behind a `if __name__ == '__main__'` guard. "
            "Currently using thread-based parallelism, "
            "which might be slightly slower for non-numpy / non-GIL-releasing code. "
            "See https://github.com/kernc/backtesting.py/issues/1256",
            category=RuntimeWarning, stacklevel=3)
        # Thread-based drop-in replacement; same API, no pickling/spawn issues.
        from multiprocessing.dummy import Pool
        return Pool(processes, initializer, initargs)
    else:
        return mp.Pool(processes, initializer, initargs)
11 changes: 6 additions & 5 deletions backtesting/backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from __future__ import annotations

import multiprocessing as mp
import sys
import warnings
from abc import ABCMeta, abstractmethod
Expand Down Expand Up @@ -1309,7 +1308,8 @@ def run(self, **kwargs) -> pd.Series:
# np.nan >= 3 is not invalid; it's False.
with np.errstate(invalid='ignore'):

for i in _tqdm(range(start, len(self._data)), desc=self.run.__qualname__):
for i in _tqdm(range(start, len(self._data)), desc=self.run.__qualname__,
unit='bar', mininterval=2, miniters=100):
# Prepare data and indicators for `next` call
data._set_length(i + 1)
for attr, indicator in indicator_attrs:
Expand Down Expand Up @@ -1501,9 +1501,9 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
[p.values() for p in param_combos],
names=next(iter(param_combos)).keys()))

with mp.Pool() as pool, \
from . import Pool
with Pool() as pool, \
SharedMemoryManager() as smm:

with patch(self, '_data', None):
bt = copy(self) # bt._data will be reassigned in _mp_task worker
results = _tqdm(
Expand Down Expand Up @@ -1565,7 +1565,8 @@ def memoized_run(tup):
stats = self.run(**dict(tup))
return -maximize(stats)

progress = iter(_tqdm(repeat(None), total=max_tries, leave=False, desc='Backtest.optimize'))
progress = iter(_tqdm(repeat(None), total=max_tries, leave=False,
desc=self.optimize.__qualname__, mininterval=2))
_names = tuple(kwargs.keys())

def objective_function(x):
Expand Down
9 changes: 5 additions & 4 deletions backtesting/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from __future__ import annotations

import multiprocessing as mp
import warnings
from collections import OrderedDict
from inspect import currentframe
Expand Down Expand Up @@ -569,15 +568,17 @@ def run(self, **kwargs):
Wraps `backtesting.backtesting.Backtest.run`. Returns `pd.DataFrame` with
currency indexes in columns.
"""
with mp.Pool() as pool, \
from . import Pool
with Pool() as pool, \
SharedMemoryManager() as smm:
shm = [smm.df2shm(df) for df in self._dfs]
results = _tqdm(
pool.imap(self._mp_task_run,
((df_batch, self._strategy, self._bt_kwargs, kwargs)
for df_batch in _batch(shm))),
total=len(shm),
desc=self.__class__.__name__,
desc=self.run.__qualname__,
mininterval=2
)
df = pd.DataFrame(list(chain(*results))).transpose()
return df
Expand Down Expand Up @@ -605,7 +606,7 @@ def optimize(self, **kwargs) -> pd.DataFrame:
"""
heatmaps = []
# Simple loop since bt.optimize already does its own multiprocessing
for df in _tqdm(self._dfs, desc=self.__class__.__name__):
for df in _tqdm(self._dfs, desc=self.__class__.__name__, mininterval=2):
bt = Backtest(df, self._strategy, **self._bt_kwargs)
_best_stats, heatmap = bt.optimize( # type: ignore
return_heatmap=True, return_optimization=False, **kwargs)
Expand Down
26 changes: 17 additions & 9 deletions backtesting/test/_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import multiprocessing as mp
import os
import sys
import time
Expand Down Expand Up @@ -629,7 +630,8 @@ def test_optimize_speed(self):
bt.optimize(fast=range(2, 20, 2), slow=range(10, 40, 2))
end = time.process_time()
print(end - start)
self.assertLess(end - start, .3)
handicap = 5 if 'win' in sys.platform else .1
self.assertLess(end - start, .3 + handicap)


class TestPlot(TestCase):
Expand Down Expand Up @@ -934,13 +936,20 @@ def test_FractionalBacktest(self):
self.assertEqual(stats['# Trades'], 41)

def test_MultiBacktest(self):
btm = MultiBacktest([GOOG, EURUSD, BTCUSD], SmaCross, cash=100_000)
res = btm.run(fast=2)
self.assertIsInstance(res, pd.DataFrame)
self.assertEqual(res.columns.tolist(), [0, 1, 2])
heatmap = btm.optimize(fast=[2, 4], slow=[10, 20])
self.assertIsInstance(heatmap, pd.DataFrame)
self.assertEqual(heatmap.columns.tolist(), [0, 1, 2])
import backtesting
assert callable(getattr(backtesting, 'Pool', None)), backtesting.__dict__
for start_method in mp.get_all_start_methods():
with self.subTest(start_method=start_method), \
patch(backtesting, 'Pool', mp.get_context(start_method).Pool):
start_time = time.monotonic()
btm = MultiBacktest([GOOG, EURUSD, BTCUSD], SmaCross, cash=100_000)
res = btm.run(fast=2)
self.assertIsInstance(res, pd.DataFrame)
self.assertEqual(res.columns.tolist(), [0, 1, 2])
heatmap = btm.optimize(fast=[2, 4], slow=[10, 20])
self.assertIsInstance(heatmap, pd.DataFrame)
self.assertEqual(heatmap.columns.tolist(), [0, 1, 2])
print(start_method, time.monotonic() - start_time)
plot_heatmaps(heatmap.mean(axis=1), open_browser=False)


Expand Down Expand Up @@ -1001,7 +1010,6 @@ def test_indicators_picklable(self):
class TestDocs(TestCase):
DOCS_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'doc')

@unittest.skipIf('win' in sys.platform, "Locks up with `ModuleNotFoundError: No module named '<run_path>'`")
@unittest.skipUnless(os.path.isdir(DOCS_DIR), "docs dir doesn't exist")
def test_examples(self):
examples = glob(os.path.join(self.DOCS_DIR, 'examples', '*.py'))
Expand Down