Skip to content

Commit 4db565c

Browse files
Merge pull request #563 from andrewrgarcia/master
Add Exponential Moving Average kwarg `ema=<int or <list|tuple> of ints>`
2 parents bf1a603 + f784363 commit 4db565c

File tree

6 files changed

+188
-5
lines changed

6 files changed

+188
-5
lines changed

src/mplfinance/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version_info = (0, 12, 9, 'beta', 2)
1+
version_info = (0, 12, 9, 'beta', 3)
22

33
_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}
44

src/mplfinance/plotting.py

+63-4
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,17 @@ def _valid_plot_kwargs():
119119
'mav' : { 'Default' : None,
120120
'Description' : 'Moving Average window size(s); (int or tuple of ints)',
121121
'Validator' : _mav_validator },
122+
123+
'ema' : { 'Default' : None,
124+
'Description' : 'Exponential Moving Average window size(s); (int or tuple of ints)',
125+
'Validator' : _mav_validator },
122126

127+
'mavcolors' : { 'Default' : None,
128+
'Description' : 'color cycle for moving averages (list or tuple of colors)'+
129+
'(overrides mpf style mavcolors).',
130+
'Validator' : lambda value: isinstance(value,(list,tuple)) and
131+
all([mcolors.is_color_like(v) for v in value]) },
132+
123133
'renko_params' : { 'Default' : dict(),
124134
'Description' : 'dict of renko parameters; call `mpf.kwarg_help("renko_params")`',
125135
'Validator' : lambda value: isinstance(value,dict) },
@@ -450,6 +460,13 @@ def plot( data, **kwargs ):
450460
else:
451461
raise TypeError('style should be a `dict`; why is it not?')
452462

463+
if config['mavcolors'] is not None:
464+
config['_ma_color_cycle'] = cycle(config['mavcolors'])
465+
elif style['mavcolors'] is not None:
466+
config['_ma_color_cycle'] = cycle(style['mavcolors'])
467+
else:
468+
config['_ma_color_cycle'] = None
469+
453470
if not external_axes_mode:
454471
fig = plt.figure()
455472
_adjust_figsize(fig,config)
@@ -528,8 +545,10 @@ def plot( data, **kwargs ):
528545

529546
if ptype in VALID_PMOVE_TYPES:
530547
mavprices = _plot_mav(axA1,config,xdates,pmove_avgvals)
548+
emaprices = _plot_ema(axA1, config, xdates, pmove_avgvals)
531549
else:
532550
mavprices = _plot_mav(axA1,config,xdates,closes)
551+
emaprices = _plot_ema(axA1, config, xdates, closes)
533552

534553
avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates))
535554
if not config['tight_layout']:
@@ -595,6 +614,13 @@ def plot( data, **kwargs ):
595614
else:
596615
for jj in range(0,len(mav)):
597616
retdict['mav' + str(mav[jj])] = mavprices[jj]
617+
if config['ema'] is not None:
618+
ema = config['ema']
619+
if len(ema) != len(emaprices):
620+
warnings.warn('len(ema)='+str(len(ema))+' BUT len(emaprices)='+str(len(emaprices)))
621+
else:
622+
for jj in range(0, len(ema)):
623+
retdict['ema' + str(ema[jj])] = emaprices[jj]
598624
retdict['minx'] = minx
599625
retdict['maxx'] = maxx
600626
retdict['miny'] = miny
@@ -1129,10 +1155,7 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
11291155
if len(mavgs) > 7:
11301156
mavgs = mavgs[0:7] # take at most 7
11311157

1132-
if style['mavcolors'] is not None:
1133-
mavc = cycle(style['mavcolors'])
1134-
else:
1135-
mavc = None
1158+
mavc = config['_ma_color_cycle']
11361159

11371160
for idx,mav in enumerate(mavgs):
11381161
mean = pd.Series(prices).rolling(mav).mean()
@@ -1147,6 +1170,42 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
11471170
mavp_list.append(mavprices)
11481171
return mavp_list
11491172

1173+
1174+
def _plot_ema(ax,config,xdates,prices,apmav=None,apwidth=None):
1175+
'''ema: exponential moving average'''
1176+
style = config['style']
1177+
if apmav is not None:
1178+
mavgs = apmav
1179+
else:
1180+
mavgs = config['ema']
1181+
mavp_list = []
1182+
if mavgs is not None:
1183+
shift = None
1184+
if isinstance(mavgs,dict):
1185+
shift = mavgs['shift']
1186+
mavgs = mavgs['period']
1187+
if isinstance(mavgs,int):
1188+
mavgs = mavgs, # convert to tuple
1189+
if len(mavgs) > 7:
1190+
mavgs = mavgs[0:7] # take at most 7
1191+
1192+
mavc = config['_ma_color_cycle']
1193+
1194+
for idx,mav in enumerate(mavgs):
1195+
# mean = pd.Series(prices).rolling(mav).mean()
1196+
mean = pd.Series(prices).ewm(span=mav,adjust=False).mean()
1197+
if shift is not None:
1198+
mean = mean.shift(periods=shift[idx])
1199+
emaprices = mean.values
1200+
lw = config['_width_config']['line_width']
1201+
if mavc:
1202+
ax.plot(xdates, emaprices, linewidth=lw, color=next(mavc))
1203+
else:
1204+
ax.plot(xdates, emaprices, linewidth=lw)
1205+
mavp_list.append(emaprices)
1206+
return mavp_list
1207+
1208+
11501209
def _auto_secondary_y( panels, panid, ylo, yhi ):
11511210
# If mag(nitude) for this panel is not yet set, then set it
11521211
# here, as this is the first ydata to be plotted on this panel:

tests/reference_images/ema01.png

48.3 KB
Loading

tests/reference_images/ema02.png

72 KB
Loading

tests/reference_images/ema03.png

78.2 KB
Loading

tests/test_ema.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import os
2+
import os.path
3+
import glob
4+
import mplfinance as mpf
5+
import pandas as pd
6+
import matplotlib.pyplot as plt
7+
from matplotlib.testing.compare import compare_images
8+
9+
print('mpf.__version__ =',mpf.__version__) # for the record
10+
print('mpf.__file__ =',mpf.__file__) # for the record
11+
print("plt.rcParams['backend'] =",plt.rcParams['backend']) # for the record
12+
13+
base='ema'
14+
tdir = os.path.join('tests','test_images')
15+
refd = os.path.join('tests','reference_images')
16+
17+
globpattern = os.path.join(tdir,base+'*.png')
18+
oldtestfiles = glob.glob(globpattern)
19+
for fn in oldtestfiles:
20+
try:
21+
os.remove(fn)
22+
except:
23+
print('Error removing file "'+fn+'"')
24+
25+
IMGCOMP_TOLERANCE = 10.0 # this works fine for linux
26+
# IMGCOMP_TOLERANCE = 11.0 # required for a windows pass. (really 10.25 may do it).
27+
28+
_df = pd.DataFrame()
29+
def get_ema_data():
30+
global _df
31+
if len(_df) == 0:
32+
_df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv',
33+
index_col='Date',parse_dates=True)
34+
return _df
35+
36+
37+
def create_ema_image(tname):
38+
39+
df = get_ema_data()
40+
df = df[-50:] # show last 50 data points only
41+
42+
ema25 = df['Close'].ewm(span=25.0, adjust=False).mean()
43+
mav25 = df['Close'].rolling(window=25).mean()
44+
45+
ap = [
46+
mpf.make_addplot(df, panel=1, type='ohlc', color='c',
47+
ylabel='mpf mav', mav=25, secondary_y=False),
48+
mpf.make_addplot(ema25, panel=2, type='line', width=2, color='c',
49+
ylabel='calculated', secondary_y=False),
50+
mpf.make_addplot(mav25, panel=2, type='line', width=2, color='blue',
51+
ylabel='calculated', secondary_y=False)
52+
]
53+
54+
# plot and save in `tname` path
55+
mpf.plot(df, ylabel="mpf ema", type='ohlc',
56+
ema=25, addplot=ap, panel_ratios=(1, 1), savefig=tname
57+
)
58+
59+
60+
def test_ema01():
61+
62+
fname = base+'01.png'
63+
tname = os.path.join(tdir,fname)
64+
rname = os.path.join(refd,fname)
65+
66+
create_ema_image(tname)
67+
68+
tsize = os.path.getsize(tname)
69+
print(glob.glob(tname),'[',tsize,'bytes',']')
70+
71+
rsize = os.path.getsize(rname)
72+
print(glob.glob(rname),'[',rsize,'bytes',']')
73+
74+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
75+
if result is not None:
76+
print('result=',result)
77+
assert result is None
78+
79+
def test_ema02():
80+
fname = base+'02.png'
81+
tname = os.path.join(tdir,fname)
82+
rname = os.path.join(refd,fname)
83+
84+
df = get_ema_data()
85+
df = df[-125:-35]
86+
87+
mpf.plot(df, type='candle', ema=(5,15,25), mav=(5,15,25), savefig=tname)
88+
89+
tsize = os.path.getsize(tname)
90+
print(glob.glob(tname),'[',tsize,'bytes',']')
91+
92+
rsize = os.path.getsize(rname)
93+
print(glob.glob(rname),'[',rsize,'bytes',']')
94+
95+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
96+
if result is not None:
97+
print('result=',result)
98+
assert result is None
99+
100+
def test_ema03():
101+
fname = base+'03.png'
102+
tname = os.path.join(tdir,fname)
103+
rname = os.path.join(refd,fname)
104+
105+
df = get_ema_data()
106+
df = df[-125:-35]
107+
108+
mac = ['red','orange','yellow','green','blue','purple']
109+
110+
mpf.plot(df, type='candle', ema=(5,10,15,25), mav=(5,15,25),
111+
mavcolors=mac, savefig=tname)
112+
113+
114+
tsize = os.path.getsize(tname)
115+
print(glob.glob(tname),'[',tsize,'bytes',']')
116+
117+
rsize = os.path.getsize(rname)
118+
print(glob.glob(rname),'[',rsize,'bytes',']')
119+
120+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
121+
if result is not None:
122+
print('result=',result)
123+
assert result is None
124+

0 commit comments

Comments
 (0)