Skip to content

Commit a59197d

Browse files
author
Ted
committed
add ability to shift moving average on plots
1 parent 3c02aea commit a59197d

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

src/mplfinance/_arg_validators.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,36 @@ def _mav_validator(mav_value):
110110
'''
111111
if isinstance(mav_value,int) and mav_value > 1:
112112
return True
113-
elif not isinstance(mav_value,tuple) and not isinstance(mav_value,list):
113+
elif not isinstance(mav_value,tuple) and not isinstance(mav_value,list) and not isinstance(mav_value,dict):
114114
return False
115115

116-
if not len(mav_value) < 8:
117-
return False
118-
for num in mav_value:
119-
if not isinstance(num,int) and num > 1:
116+
if isinstance(mav_value,dict):
117+
if 'scale' not in mav_value or not (isinstance(mav_value['scale'],tuple) or\
118+
isinstance(mav_value['scale'],int) or isinstance(mav_value['scale'], list)):
119+
return False
120+
if 'shift' in mav_value:
121+
if not (isinstance(mav_value['shift'],tuple) or isinstance(mav_value['shift'],int) or\
122+
isinstance(mav_value['shift'], list)):
123+
return False
124+
if isinstance(mav_value['scale'], int) and isinstance(mav_value['shift'], int):
125+
return True
126+
elif isinstance(mav_value['scale'], int) or isinstance(mav_value['shift'], int):
127+
return False
128+
if len(mav_value['scale']) == len(mav_value['shift']):
129+
for num in mav_value['scale']:
130+
if not isinstance(num, int) and num > 1:
131+
return False
132+
for num in mav_value['shift']:
133+
if not isinstance(num, int) and num > 1:
134+
return False
135+
return True
120136
return False
137+
elif not len(mav_value) < 8:
138+
return False
139+
else:
140+
for num in mav_value:
141+
if not isinstance(num,int) and num > 1:
142+
return False
121143
return True
122144

123145
def _hlines_validator(value):

src/mplfinance/plotting.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -978,8 +978,12 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
978978
mavgs = config['mav']
979979
mavp_list = []
980980
if mavgs is not None:
981+
shift = None
982+
if isinstance(mavgs,dict):
983+
shift = mavgs['shift']
984+
mavgs = mavgs['scale']
981985
if isinstance(mavgs,int):
982-
mavgs = mavgs, # convert to tuple
986+
mavgs = mavgs, # convert to tuple
983987
if len(mavgs) > 7:
984988
mavgs = mavgs[0:7] # take at most 7
985989

@@ -988,8 +992,11 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
988992
else:
989993
mavc = None
990994

991-
for mav in mavgs:
992-
mavprices = pd.Series(prices).rolling(mav).mean().values
995+
for idx,mav in enumerate(mavgs):
996+
mean = pd.Series(prices).rolling(mav).mean()
997+
if shift is not None:
998+
mean = mean.shift(periods=shift[idx])
999+
mavprices = mean.values
9931000
lw = config['_width_config']['line_width']
9941001
if mavc:
9951002
ax.plot(xdates, mavprices, linewidth=lw, color=next(mavc))

tests/reference_images/addplot12.png

60.2 KB
Loading

tests/test_addplot.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,23 @@ def test_addplot11(bolldata):
354354
print('result=',result)
355355
assert result is None
356356

357+
def test_addplot12(bolldata):
358+
359+
df = bolldata
360+
361+
fname = base+'12.png'
362+
tname = os.path.join(tdir,fname)
363+
rname = os.path.join(refd,fname)
364+
365+
mpf.plot(df,type='candle',volume=True,savefig=tname,mav={'scale':(20,40,60), 'shift': [5,10,20]})
366+
367+
tsize = os.path.getsize(tname)
368+
print(glob.glob(tname),'[',tsize,'bytes',']')
369+
370+
rsize = os.path.getsize(rname)
371+
print(glob.glob(rname),'[',rsize,'bytes',']')
372+
373+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
374+
if result is not None:
375+
print('result=',result)
376+
assert result is None

0 commit comments

Comments
 (0)