Skip to content

Commit 48b0d58

Browse files
committed
BUG: Fix grid optimization with tz-aware datetime index
Fixes #1252
1 parent 89eee8e commit 48b0d58

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

Diff for: backtesting/_util.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,11 @@ def arr2shm(self, vals):
305305
"""Array to shared memory. Returns (shm_name, shape, dtype) used for restore."""
306306
assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
307307
shm = self.SharedMemory(size=vals.nbytes, create=True)
308-
buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
309-
buf[:] = vals[:] # Copy into shared memory
308+
# np.array can't handle pandas' tz-aware datetimes
309+
# https://github.com/numpy/numpy/issues/18279
310+
buf = np.ndarray(vals.shape, dtype=vals.dtype.base, buffer=shm.buf)
311+
has_tz = getattr(vals.dtype, 'tz', None)
312+
buf[:] = vals.tz_localize(None) if has_tz else vals # Copy into shared memory
310313
return shm.name, vals.shape, vals.dtype
311314

312315
def df2shm(self, df):
@@ -316,18 +319,18 @@ def df2shm(self, df):
316319
))
317320

318321
@staticmethod
319-
def shm2arr(shm, shape, dtype):
320-
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
322+
def shm2s(shm, shape, dtype) -> pd.Series:
323+
arr = np.ndarray(shape, dtype=dtype.base, buffer=shm.buf)
321324
arr.setflags(write=False)
322-
return arr
325+
return pd.Series(arr, dtype=dtype)
323326

324327
_DF_INDEX_COL = '__bt_index'
325328

326329
@staticmethod
327330
def shm2df(data_shm):
328331
shm = [SharedMemory(name=name, create=False, track=False) for _, name, _, _ in data_shm]
329332
df = pd.DataFrame({
330-
col: SharedMemoryManager.shm2arr(shm, shape, dtype)
333+
col: SharedMemoryManager.shm2s(shm, shape, dtype)
331334
for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
332335
df.set_index(SharedMemoryManager._DF_INDEX_COL, drop=True, inplace=True)
333336
df.index.name = None

Diff for: backtesting/test/_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -1119,3 +1119,9 @@ def next(self):
11191119
trades = Backtest(SHORT_DATA, S).run()._trades
11201120
self.assertEqual(trades['ExitBar'].iloc[0], 3)
11211121
self.assertEqual(trades['ExitPrice'].iloc[0], 105)
1122+
1123+
def test_optimize_datetime_index_with_timezone(self):
1124+
data: pd.DataFrame = GOOG.iloc[:100]
1125+
data.index = data.index.tz_localize('Asia/Kolkata')
1126+
res = Backtest(data, SmaCross).optimize(fast=range(2, 3), slow=range(4, 5))
1127+
self.assertGreater(res['# Trades'], 0)

0 commit comments

Comments
 (0)