Skip to content

Commit d777f43

Browse files
Merge pull request #391 from deltreey/mav_shift
add ability to shift moving average on plots
2 parents 3c02aea + d22a8cc commit d777f43

File tree

5 files changed

+64
-14
lines changed

5 files changed

+64
-14
lines changed

src/mplfinance/_arg_validators.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,45 @@ def _get_valid_plot_types(plottype=None):
103103

104104

105105
def _mav_validator(mav_value):
106-
'''
106+
'''
107107
Value for mav (moving average) keyword may be:
108-
scalar int greater than 1, or tuple of ints, or list of ints (greater than 1).
109-
tuple or list limited to length of 7 moving averages (to keep the plot clean).
108+
scalar int greater than 1, or tuple of ints, or list of ints (each greater than 1)
109+
or a dict of `period` and `shift` each of which may be:
110+
scalar int, or tuple of ints, or list of ints: each `period` int must be greater than 1
110111
'''
111-
if isinstance(mav_value,int) and mav_value > 1:
112+
def _valid_mav(value, is_period=True):
113+
if not isinstance(value,(tuple,list,int)):
114+
return False
115+
if isinstance(value,int):
116+
return (value >= 2 or not is_period)
117+
# Must be a tuple or list here:
118+
for num in value:
119+
if not isinstance(num,int) or (is_period and num < 2):
120+
return False
112121
return True
113-
elif not isinstance(mav_value,tuple) and not isinstance(mav_value,list):
122+
123+
if not isinstance(mav_value,(tuple,list,int,dict)):
114124
return False
115125

116-
if not len(mav_value) < 8:
126+
if not isinstance(mav_value,dict):
127+
return _valid_mav(mav_value)
128+
129+
else: #isinstance(mav_value,dict)
130+
if 'period' not in mav_value: return False
131+
132+
period = mav_value['period']
133+
if not _valid_mav(period): return False
134+
135+
if 'shift' not in mav_value: return True
136+
137+
shift = mav_value['shift']
138+
if not _valid_mav(shift, False): return False
139+
if isinstance(period,int) and isinstance(shift,int): return True
140+
if isinstance(period,(tuple,list)) and isinstance(shift,(tuple,list)):
141+
if len(period) != len(shift): return False
142+
return True
117143
return False
118-
for num in mav_value:
119-
if not isinstance(num,int) and num > 1:
120-
return False
121-
return True
144+
122145

123146
def _hlines_validator(value):
124147
if isinstance(value,dict):

src/mplfinance/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
version_info = (0, 12, 7, 'alpha', 17)
2+
version_info = (0, 12, 7, 'alpha', 18)
33

44
_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}
55

src/mplfinance/plotting.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -978,8 +978,12 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
978978
mavgs = config['mav']
979979
mavp_list = []
980980
if mavgs is not None:
981+
shift = None
982+
if isinstance(mavgs,dict):
983+
shift = mavgs['shift']
984+
mavgs = mavgs['period']
981985
if isinstance(mavgs,int):
982-
mavgs = mavgs, # convert to tuple
986+
mavgs = mavgs, # convert to tuple
983987
if len(mavgs) > 7:
984988
mavgs = mavgs[0:7] # take at most 7
985989

@@ -988,8 +992,11 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
988992
else:
989993
mavc = None
990994

991-
for mav in mavgs:
992-
mavprices = pd.Series(prices).rolling(mav).mean().values
995+
for idx,mav in enumerate(mavgs):
996+
mean = pd.Series(prices).rolling(mav).mean()
997+
if shift is not None:
998+
mean = mean.shift(periods=shift[idx])
999+
mavprices = mean.values
9931000
lw = config['_width_config']['line_width']
9941001
if mavc:
9951002
ax.plot(xdates, mavprices, linewidth=lw, color=next(mavc))

tests/reference_images/addplot12.png

60.2 KB
Loading

tests/test_addplot.py

+20
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,23 @@ def test_addplot11(bolldata):
354354
print('result=',result)
355355
assert result is None
356356

357+
def test_addplot12(bolldata):
358+
359+
df = bolldata
360+
361+
fname = base+'12.png'
362+
tname = os.path.join(tdir,fname)
363+
rname = os.path.join(refd,fname)
364+
365+
mpf.plot(df,type='candle',volume=True,savefig=tname,mav={'period':(20,40,60), 'shift': [5,10,20]})
366+
367+
tsize = os.path.getsize(tname)
368+
print(glob.glob(tname),'[',tsize,'bytes',']')
369+
370+
rsize = os.path.getsize(rname)
371+
print(glob.glob(rname),'[',rsize,'bytes',']')
372+
373+
result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
374+
if result is not None:
375+
print('result=',result)
376+
assert result is None

0 commit comments

Comments
 (0)