Skip to content

Commit 2aae88b

Browse files
author
Ted
committed
simpler code, and don't check for minimum size in shifts
1 parent fb8bccb commit 2aae88b

File tree

1 file changed

+34
-30
lines changed

1 file changed

+34
-30
lines changed

src/mplfinance/_arg_validators.py

+34-30
Original file line numberDiff line numberDiff line change
@@ -103,42 +103,46 @@ def _get_valid_plot_types(plottype=None):
103103

104104

105105
def _mav_validator(mav_value):
106-
'''
106+
'''
107107
Value for mav (moving average) keyword may be:
108-
scalar int greater than 1, or tuple of ints, or list of ints (greater than 1).
109-
tuple or list limited to length of 7 moving averages (to keep the plot clean).
108+
scalar int greater than 1, or tuple of ints, or list of ints (each greater than 1)
109+
or a dict of `period` and `shift` each of which may be:
110+
scalar int, or tuple of ints, or list of ints. `period` must be greater than 1
111+
while `shift` must be greater than 0.
110112
'''
111-
if isinstance(mav_value,int) and mav_value > 1:
113+
def _valid_mav(value, is_period=True):
114+
if not isinstance(value,(tuple,list,int)):
115+
return False
116+
if isinstance(value,int):
117+
return isinstance(value,int) and (value >= 2 or not is_period)
118+
# Must be a tuple or list here:
119+
for num in value:
120+
if not isinstance(num,int) or (is_period and num < 2):
121+
return False
112122
return True
113-
elif not isinstance(mav_value,(tuple,list,dict)):
123+
124+
if not isinstance(mav_value,(tuple,list,int,dict)):
114125
return False
115126

116-
if isinstance(mav_value,dict):
117-
if 'period' not in mav_value or not isinstance(mav_value['period'],(tuple,list,int)):
118-
return False
119-
if 'shift' in mav_value:
120-
if not isinstance(mav_value['shift'],(tuple,list,int)):
121-
return False
122-
if isinstance(mav_value['period'], int) and isinstance(mav_value['shift'], int):
123-
return True
124-
elif isinstance(mav_value['period'], int) or isinstance(mav_value['shift'], int):
125-
return False
126-
if len(mav_value['period']) == len(mav_value['shift']):
127-
for num in mav_value['period']:
128-
if not isinstance(num, int) and num > 1:
129-
return False
130-
for num in mav_value['shift']:
131-
if not isinstance(num, int) and num > 1:
132-
return False
133-
return True
134-
return False
135-
elif not len(mav_value) < 8:
127+
if not isinstance(mav_value,dict):
128+
return _valid_mav(mav_value)
129+
130+
else: #isinstance(mav_value,dict)
131+
if 'period' not in mav_value: return False
132+
133+
period = mav_value['period']
134+
if not _valid_mav(period): return False
135+
136+
if 'shift' not in mav_value: return True
137+
138+
shift = mav_value['shift']
139+
if not _valid_mav(shift, False): return False
140+
if isinstance(period,int) and isinstance(shift,int): return True
141+
if isinstance(period,(tuple,list)) and isinstance(shift,(tuple,list)):
142+
if len(period) != len(shift): return False
143+
return True
136144
return False
137-
else:
138-
for num in mav_value:
139-
if not isinstance(num,int) and num > 1:
140-
return False
141-
return True
145+
142146

143147
def _hlines_validator(value):
144148
if isinstance(value,dict):

0 commit comments

Comments
 (0)