@@ -450,15 +450,16 @@ def next(self):
450
450
451
451
452
452
class TestStrategy (TestCase ):
453
- def _Backtest (self , strategy_coroutine , ** kwargs ):
453
+ @staticmethod
454
+ def _Backtest (strategy_coroutine , data = SHORT_DATA , ** kwargs ):
454
455
class S (Strategy ):
455
456
def init (self ):
456
457
self .step = strategy_coroutine (self )
457
458
458
459
def next (self ):
459
460
try_ (self .step .__next__ , None , StopIteration )
460
461
461
- return Backtest (SHORT_DATA , S , ** kwargs )
462
+ return Backtest (data , S , ** kwargs )
462
463
463
464
def test_position (self ):
464
465
def coroutine (self ):
@@ -1032,12 +1033,8 @@ def next(self):
1032
1033
if self .data .Close [- 1 ] == 100 :
1033
1034
self .buy (size = 1 , sl = 90 )
1034
1035
1035
- df = pd .DataFrame ({
1036
- 'Open' : [100 , 100 , 100 , 50 , 50 ],
1037
- 'High' : [100 , 100 , 100 , 50 , 50 ],
1038
- 'Low' : [100 , 100 , 100 , 50 , 50 ],
1039
- 'Close' : [100 , 100 , 100 , 50 , 50 ],
1040
- })
1036
+ arr = np .r_ [100 , 100 , 100 , 50 , 50 ]
1037
+ df = pd .DataFrame ({'Open' : arr , 'High' : arr , 'Low' : arr , 'Close' : arr })
1041
1038
with self .assertWarnsRegex (UserWarning , 'index is not datetime' ):
1042
1039
bt = Backtest (df , S , cash = 100 , trade_on_close = True )
1043
1040
self .assertEqual (bt .run ()._trades ['ExitPrice' ][0 ], 50 )
@@ -1059,3 +1056,44 @@ def next(self):
1059
1056
order .cancel ()
1060
1057
1061
1058
Backtest (SHORT_DATA , S ).run ()
1059
+
1060
+ def test_trade_on_close_closes_trades_on_close (self ):
1061
+ def coro (strat ):
1062
+ yield strat .buy (size = 1 , sl = 90 ) and strat .buy (size = 1 , sl = 80 )
1063
+ assert len (strat .trades ) == 2
1064
+ yield strat .trades [0 ].close ()
1065
+ yield
1066
+
1067
+ arr = np .r_ [100 , 101 , 102 , 50 , 51 ]
1068
+ df = pd .DataFrame ({
1069
+ 'Open' : arr - 10 ,
1070
+ 'Close' : arr , 'High' : arr , 'Low' : arr })
1071
+ with self .assertWarnsRegex (UserWarning , 'index is not datetime' ):
1072
+ trades = TestStrategy ._Backtest (coro , df , cash = 250 , trade_on_close = True ).run ()._trades
1073
+ # trades = Backtest(df, S, cash=250, trade_on_close=True).run()._trades
1074
+ self .assertEqual (trades ['EntryBar' ][0 ], 1 )
1075
+ self .assertEqual (trades ['ExitBar' ][0 ], 2 )
1076
+ self .assertEqual (trades ['EntryPrice' ][0 ], 101 )
1077
+ self .assertEqual (trades ['ExitPrice' ][0 ], 102 )
1078
+ self .assertEqual (trades ['EntryBar' ][1 ], 1 )
1079
+ self .assertEqual (trades ['ExitBar' ][1 ], 3 )
1080
+ self .assertEqual (trades ['EntryPrice' ][1 ], 101 )
1081
+ self .assertEqual (trades ['ExitPrice' ][1 ], 40 )
1082
+
1083
+ with self .assertWarnsRegex (UserWarning , 'index is not datetime' ):
1084
+ trades = TestStrategy ._Backtest (coro , df , cash = 250 , trade_on_close = False ).run ()._trades
1085
+ # trades = Backtest(df, S, cash=250, trade_on_close=False).run()._trades
1086
+ self .assertEqual (trades ['EntryBar' ][0 ], 2 )
1087
+ self .assertEqual (trades ['ExitBar' ][0 ], 3 )
1088
+ self .assertEqual (trades ['EntryPrice' ][0 ], 92 )
1089
+ self .assertEqual (trades ['ExitPrice' ][0 ], 40 )
1090
+ self .assertEqual (trades ['EntryBar' ][1 ], 2 )
1091
+ self .assertEqual (trades ['ExitBar' ][1 ], 3 )
1092
+ self .assertEqual (trades ['EntryPrice' ][1 ], 92 )
1093
+ self .assertEqual (trades ['ExitPrice' ][1 ], 40 )
1094
+
1095
+ def test_trades_dates_match_prices (self ):
1096
+ bt = Backtest (EURUSD , SmaCross , trade_on_close = True )
1097
+ trades = bt .run ()._trades
1098
+ self .assertEqual (EURUSD .Close [trades ['ExitTime' ]].tolist (),
1099
+ trades ['ExitPrice' ].tolist ())
0 commit comments