Skip to content

Commit 07adb02

Browse files
add tests;place mav color cycle into config
1 parent bd18e03 commit 07adb02

File tree

4 files changed

+80
-14
lines changed

4 files changed

+80
-14
lines changed

src/mplfinance/plotting.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def _valid_plot_kwargs():
124124
'Description' : 'Exponential Moving Average window size(s); (int or tuple of ints)',
125125
'Validator' : _mav_validator },
126126

127+
'mavcolors' : { 'Default' : None,
128+
'Description' : 'color cycle for moving averages (list or tuple of colors)'+
129+
'(overrides mpf style mavcolors).',
130+
'Validator' : lambda value: isinstance(value,(list,tuple)) and
131+
all([mcolors.is_color_like(v) for v in value]) },
132+
127133
'renko_params' : { 'Default' : dict(),
128134
'Description' : 'dict of renko parameters; call `mpf.kwarg_help("renko_params")`',
129135
'Validator' : lambda value: isinstance(value,dict) },
@@ -454,6 +460,13 @@ def plot( data, **kwargs ):
454460
else:
455461
raise TypeError('style should be a `dict`; why is it not?')
456462

463+
if config['mavcolors'] is not None:
464+
config['_ma_color_cycle'] = cycle(config['mavcolors'])
465+
elif style['mavcolors'] is not None:
466+
config['_ma_color_cycle'] = cycle(style['mavcolors'])
467+
else:
468+
config['_ma_color_cycle'] = None
469+
457470
if not external_axes_mode:
458471
fig = plt.figure()
459472
_adjust_figsize(fig,config)
@@ -1142,10 +1155,7 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
11421155
if len(mavgs) > 7:
11431156
mavgs = mavgs[0:7] # take at most 7
11441157

1145-
if style['mavcolors'] is not None:
1146-
mavc = cycle(style['mavcolors'])
1147-
else:
1148-
mavc = None
1158+
mavc = config['_ma_color_cycle']
11491159

11501160
for idx,mav in enumerate(mavgs):
11511161
mean = pd.Series(prices).rolling(mav).mean()
@@ -1178,11 +1188,8 @@ def _plot_ema(ax,config,xdates,prices,apmav=None,apwidth=None):
11781188
mavgs = mavgs, # convert to tuple
11791189
if len(mavgs) > 7:
11801190
mavgs = mavgs[0:7] # take at most 7
1181-
1182-
if style['mavcolors'] is not None:
1183-
mavc = cycle(style['mavcolors'])
1184-
else:
1185-
mavc = None
1191+
1192+
mavc = config['_ma_color_cycle']
11861193

11871194
for idx,mav in enumerate(mavgs):
11881195
# mean = pd.Series(prices).rolling(mav).mean()

tests/reference_images/ema02.png

72 KB
Loading

tests/reference_images/ema03.png

78.2 KB
Loading

tests/test_ema.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,29 @@
1414
tdir = os.path.join('tests','test_images')
1515
refd = os.path.join('tests','reference_images')
1616

17+
globpattern = os.path.join(tdir,base+'*.png')
18+
oldtestfiles = glob.glob(globpattern)
19+
for fn in oldtestfiles:
20+
try:
21+
os.remove(fn)
22+
except:
23+
print('Error removing file "'+fn+'"')
24+
1725
IMGCOMP_TOLERANCE = 10.0 # this works fine for linux
1826
# IMGCOMP_TOLERANCE = 11.0 # required for a windows pass. (really 10.25 may do it).
1927

20-
def create_ema_image(tname):
28+
_df = pd.DataFrame()
29+
def get_ema_data():
30+
global _df
31+
if len(_df) == 0:
32+
_df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv',
33+
index_col='Date',parse_dates=True)
34+
return _df
35+
2136

22-
df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv', parse_dates=True)
23-
df.index = pd.DatetimeIndex(df['Date'])
37+
def create_ema_image(tname):
2438

39+
df = get_ema_data()
2540
df = df[-50:] # show last 50 data points only
2641

2742
ema25 = df['Close'].ewm(span=25.0, adjust=False).mean()
@@ -42,7 +57,7 @@ def create_ema_image(tname):
4257
)
4358

4459

45-
def test_ema():
60+
def test_ema01():
4661

4762
fname = base+'01.png'
4863
tname = os.path.join(tdir,fname)
@@ -61,5 +76,49 @@ def test_ema():
6176
print('result=',result)
6277
assert result is None
6378

79+
def test_ema02():
80+
fname = base+'02.png'
81+
tname = os.path.join(tdir,fname)
82+
rname = os.path.join(refd,fname)
83+
84+
df = get_ema_data()
85+
df = df[-125:-35]
86+
87+
mpf.plot(df, type='candle', ema=(5,15,25), mav=(5,15,25), savefig=tname)
88+
89+
tsize = os.path.getsize(tname)
90+
print(glob.glob(tname),'[',tsize,'bytes',']')
91+
92+
rsize = os.path.getsize(rname)
93+
print(glob.glob(rname),'[',rsize,'bytes',']')
94+
95+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
96+
if result is not None:
97+
print('result=',result)
98+
assert result is None
99+
100+
def test_ema03():
101+
fname = base+'03.png'
102+
tname = os.path.join(tdir,fname)
103+
rname = os.path.join(refd,fname)
104+
105+
df = get_ema_data()
106+
df = df[-125:-35]
107+
108+
mac = ['red','orange','yellow','green','blue','purple']
109+
110+
mpf.plot(df, type='candle', ema=(5,10,15,25), mav=(5,15,25),
111+
mavcolors=mac, savefig=tname)
112+
113+
114+
tsize = os.path.getsize(tname)
115+
print(glob.glob(tname),'[',tsize,'bytes',']')
116+
117+
rsize = os.path.getsize(rname)
118+
print(glob.glob(rname),'[',rsize,'bytes',']')
119+
120+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
121+
if result is not None:
122+
print('result=',result)
123+
assert result is None
64124

65-
test_ema()

0 commit comments

Comments
 (0)