Skip to content

Commit a05dd0d

Browse files
Merge branch 'master' into master
2 parents ac8aef2 + 4db565c commit a05dd0d

File tree

5 files changed

+187
-4
lines changed

5 files changed

+187
-4
lines changed

src/mplfinance/plotting.py

+63-4
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,17 @@ def _valid_plot_kwargs():
119119
'mav' : { 'Default' : None,
120120
'Description' : 'Moving Average window size(s); (int or tuple of ints)',
121121
'Validator' : _mav_validator },
122+
123+
'ema' : { 'Default' : None,
124+
'Description' : 'Exponential Moving Average window size(s); (int or tuple of ints)',
125+
'Validator' : _mav_validator },
122126

127+
'mavcolors' : { 'Default' : None,
128+
'Description' : 'color cycle for moving averages (list or tuple of colors)'+
129+
'(overrides mpf style mavcolors).',
130+
'Validator' : lambda value: isinstance(value,(list,tuple)) and
131+
all([mcolors.is_color_like(v) for v in value]) },
132+
123133
'renko_params' : { 'Default' : dict(),
124134
'Description' : 'dict of renko parameters; call `mpf.kwarg_help("renko_params")`',
125135
'Validator' : lambda value: isinstance(value,dict) },
@@ -454,6 +464,13 @@ def plot( data, **kwargs ):
454464
else:
455465
raise TypeError('style should be a `dict`; why is it not?')
456466

467+
if config['mavcolors'] is not None:
468+
config['_ma_color_cycle'] = cycle(config['mavcolors'])
469+
elif style['mavcolors'] is not None:
470+
config['_ma_color_cycle'] = cycle(style['mavcolors'])
471+
else:
472+
config['_ma_color_cycle'] = None
473+
457474
if not external_axes_mode:
458475
fig = plt.figure()
459476
_adjust_figsize(fig,config)
@@ -532,8 +549,10 @@ def plot( data, **kwargs ):
532549

533550
if ptype in VALID_PMOVE_TYPES:
534551
mavprices = _plot_mav(axA1,config,xdates,pmove_avgvals)
552+
emaprices = _plot_ema(axA1, config, xdates, pmove_avgvals)
535553
else:
536554
mavprices = _plot_mav(axA1,config,xdates,closes)
555+
emaprices = _plot_ema(axA1, config, xdates, closes)
537556

538557
avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates))
539558
if not config['tight_layout']:
@@ -599,6 +618,13 @@ def plot( data, **kwargs ):
599618
else:
600619
for jj in range(0,len(mav)):
601620
retdict['mav' + str(mav[jj])] = mavprices[jj]
621+
if config['ema'] is not None:
622+
ema = config['ema']
623+
if len(ema) != len(emaprices):
624+
warnings.warn('len(ema)='+str(len(ema))+' BUT len(emaprices)='+str(len(emaprices)))
625+
else:
626+
for jj in range(0, len(ema)):
627+
retdict['ema' + str(ema[jj])] = emaprices[jj]
602628
retdict['minx'] = minx
603629
retdict['maxx'] = maxx
604630
retdict['miny'] = miny
@@ -1140,10 +1166,7 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
11401166
if len(mavgs) > 7:
11411167
mavgs = mavgs[0:7] # take at most 7
11421168

1143-
if style['mavcolors'] is not None:
1144-
mavc = cycle(style['mavcolors'])
1145-
else:
1146-
mavc = None
1169+
mavc = config['_ma_color_cycle']
11471170

11481171
for idx,mav in enumerate(mavgs):
11491172
mean = pd.Series(prices).rolling(mav).mean()
@@ -1158,6 +1181,42 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
11581181
mavp_list.append(mavprices)
11591182
return mavp_list
11601183

1184+
1185+
def _plot_ema(ax,config,xdates,prices,apmav=None,apwidth=None):
1186+
'''ema: exponential moving average'''
1187+
style = config['style']
1188+
if apmav is not None:
1189+
mavgs = apmav
1190+
else:
1191+
mavgs = config['ema']
1192+
mavp_list = []
1193+
if mavgs is not None:
1194+
shift = None
1195+
if isinstance(mavgs,dict):
1196+
shift = mavgs['shift']
1197+
mavgs = mavgs['period']
1198+
if isinstance(mavgs,int):
1199+
mavgs = mavgs, # convert to tuple
1200+
if len(mavgs) > 7:
1201+
mavgs = mavgs[0:7] # take at most 7
1202+
1203+
mavc = config['_ma_color_cycle']
1204+
1205+
for idx,mav in enumerate(mavgs):
1206+
# mean = pd.Series(prices).rolling(mav).mean()
1207+
mean = pd.Series(prices).ewm(span=mav,adjust=False).mean()
1208+
if shift is not None:
1209+
mean = mean.shift(periods=shift[idx])
1210+
emaprices = mean.values
1211+
lw = config['_width_config']['line_width']
1212+
if mavc:
1213+
ax.plot(xdates, emaprices, linewidth=lw, color=next(mavc))
1214+
else:
1215+
ax.plot(xdates, emaprices, linewidth=lw)
1216+
mavp_list.append(emaprices)
1217+
return mavp_list
1218+
1219+
11611220
def _auto_secondary_y( panels, panid, ylo, yhi ):
11621221
# If mag(nitude) for this panel is not yet set, then set it
11631222
# here, as this is the first ydata to be plotted on this panel:

tests/reference_images/ema01.png

48.3 KB
Loading

tests/reference_images/ema02.png

72 KB
Loading

tests/reference_images/ema03.png

78.2 KB
Loading

tests/test_ema.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import os
2+
import os.path
3+
import glob
4+
import mplfinance as mpf
5+
import pandas as pd
6+
import matplotlib.pyplot as plt
7+
from matplotlib.testing.compare import compare_images
8+
9+
print('mpf.__version__ =',mpf.__version__) # for the record
10+
print('mpf.__file__ =',mpf.__file__) # for the record
11+
print("plt.rcParams['backend'] =",plt.rcParams['backend']) # for the record
12+
13+
base='ema'
14+
tdir = os.path.join('tests','test_images')
15+
refd = os.path.join('tests','reference_images')
16+
17+
globpattern = os.path.join(tdir,base+'*.png')
18+
oldtestfiles = glob.glob(globpattern)
19+
for fn in oldtestfiles:
20+
try:
21+
os.remove(fn)
22+
except:
23+
print('Error removing file "'+fn+'"')
24+
25+
IMGCOMP_TOLERANCE = 10.0 # this works fine for linux
26+
# IMGCOMP_TOLERANCE = 11.0 # required for a windows pass. (really 10.25 may do it).
27+
28+
_df = pd.DataFrame()
29+
def get_ema_data():
30+
global _df
31+
if len(_df) == 0:
32+
_df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv',
33+
index_col='Date',parse_dates=True)
34+
return _df
35+
36+
37+
def create_ema_image(tname):
38+
39+
df = get_ema_data()
40+
df = df[-50:] # show last 50 data points only
41+
42+
ema25 = df['Close'].ewm(span=25.0, adjust=False).mean()
43+
mav25 = df['Close'].rolling(window=25).mean()
44+
45+
ap = [
46+
mpf.make_addplot(df, panel=1, type='ohlc', color='c',
47+
ylabel='mpf mav', mav=25, secondary_y=False),
48+
mpf.make_addplot(ema25, panel=2, type='line', width=2, color='c',
49+
ylabel='calculated', secondary_y=False),
50+
mpf.make_addplot(mav25, panel=2, type='line', width=2, color='blue',
51+
ylabel='calculated', secondary_y=False)
52+
]
53+
54+
# plot and save in `tname` path
55+
mpf.plot(df, ylabel="mpf ema", type='ohlc',
56+
ema=25, addplot=ap, panel_ratios=(1, 1), savefig=tname
57+
)
58+
59+
60+
def test_ema01():
61+
62+
fname = base+'01.png'
63+
tname = os.path.join(tdir,fname)
64+
rname = os.path.join(refd,fname)
65+
66+
create_ema_image(tname)
67+
68+
tsize = os.path.getsize(tname)
69+
print(glob.glob(tname),'[',tsize,'bytes',']')
70+
71+
rsize = os.path.getsize(rname)
72+
print(glob.glob(rname),'[',rsize,'bytes',']')
73+
74+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
75+
if result is not None:
76+
print('result=',result)
77+
assert result is None
78+
79+
def test_ema02():
80+
fname = base+'02.png'
81+
tname = os.path.join(tdir,fname)
82+
rname = os.path.join(refd,fname)
83+
84+
df = get_ema_data()
85+
df = df[-125:-35]
86+
87+
mpf.plot(df, type='candle', ema=(5,15,25), mav=(5,15,25), savefig=tname)
88+
89+
tsize = os.path.getsize(tname)
90+
print(glob.glob(tname),'[',tsize,'bytes',']')
91+
92+
rsize = os.path.getsize(rname)
93+
print(glob.glob(rname),'[',rsize,'bytes',']')
94+
95+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
96+
if result is not None:
97+
print('result=',result)
98+
assert result is None
99+
100+
def test_ema03():
101+
fname = base+'03.png'
102+
tname = os.path.join(tdir,fname)
103+
rname = os.path.join(refd,fname)
104+
105+
df = get_ema_data()
106+
df = df[-125:-35]
107+
108+
mac = ['red','orange','yellow','green','blue','purple']
109+
110+
mpf.plot(df, type='candle', ema=(5,10,15,25), mav=(5,15,25),
111+
mavcolors=mac, savefig=tname)
112+
113+
114+
tsize = os.path.getsize(tname)
115+
print(glob.glob(tname),'[',tsize,'bytes',']')
116+
117+
rsize = os.path.getsize(rname)
118+
print(glob.glob(rname),'[',rsize,'bytes',']')
119+
120+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
121+
if result is not None:
122+
print('result=',result)
123+
assert result is None
124+

0 commit comments

Comments
 (0)