Skip to content

Commit a488847

Browse files
committed
Allow Optional lists!
1 parent ede221d commit a488847

File tree

2 files changed

+151
-106
lines changed

2 files changed

+151
-106
lines changed

flask_parameter_validation/parameter_validation.py

Lines changed: 150 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
from flask import request
77

88

9+
# Mock flask request class, for creating custom requests
10+
class MockRequest:
11+
def __init__(self, default, name, annotation):
12+
self.default = default
13+
self.name = name
14+
self.annotation = annotation
15+
16+
17+
# Main validation class
918
class ValidateParameters:
1019

1120
def default_error_function(self, error_message):
@@ -20,10 +29,8 @@ def __init__(self, error_function=None):
2029
self.error_function = self.default_error_function
2130

2231
def __call__(self, f):
23-
decorator_self = self
2432

2533
def nested_func(**kwargs):
26-
parsed_inputs = {}
2734
# Get all request inputs as dicts
2835
request_inputs = {
2936
Route: kwargs.copy(),
@@ -34,113 +41,151 @@ def nested_func(**kwargs):
3441
}
3542
# Get function arguments
3643
function_args = signature(f).parameters
37-
# Iterate through all
38-
for arg in function_args.values():
39-
param_type = arg.default # ie. Route(), Json()
40-
param_name = arg.name # ie. id, username
41-
param_annotation = arg.annotation or typing.Any # ie. str, int
42-
is_list_or_union = hasattr(param_annotation, "__args__")
43-
# Ensure param type is valid
44-
if param_type.__class__ not in request_inputs.keys():
45-
return self.error_function("Invalid parameter type.")
46-
47-
# Get user input for given param type and name
48-
user_input = request_inputs[param_type.__class__].get(
49-
param_name
50-
)
5144

52-
# If default is given, set it
53-
if user_input is None and param_type.default is not None:
54-
user_input = param_type.default
45+
parsed_inputs = self.validate_parameters(
46+
request_inputs, function_args
47+
)
48+
if isinstance(parsed_inputs, dict):
49+
return f(**parsed_inputs)
50+
else:
51+
return parsed_inputs
5552

56-
# If no default and no input, error
57-
elif user_input is None:
58-
ptype = param_type.name
59-
error_response = self.error_function(
60-
f"Required {ptype} parameter '{param_name}' not given."
61-
)
62-
# If "None" is allowed in Union, then continue
63-
if is_list_or_union:
64-
if type(None) not in param_annotation.__args__:
65-
return error_response
66-
else:
67-
return error_response
68-
69-
# If typing's Any or ClassVar, don't validate type
70-
if isinstance(param_annotation, typing._SpecialForm):
71-
valid = True
72-
allowed_types = ["all"]
53+
nested_func.__name__ = f.__name__
54+
return nested_func
7355

74-
# Otherwise, validate type
75-
else:
76-
allowed_types = []
77-
# If List or Union, get all "inner" types
78-
if is_list_or_union:
79-
allowed_types = param_annotation.__args__
80-
else:
81-
allowed_types = (param_annotation,)
82-
83-
# If query parameter, try converting to match
84-
if param_type.__class__ == Query and type(user_input) == str: # noqa: E501
85-
# int conversion
86-
if int in allowed_types:
87-
try:
88-
user_input = int(user_input)
89-
except ValueError:
90-
pass
91-
# float conversion
92-
if float in allowed_types:
93-
try:
94-
user_input = float(user_input)
95-
except ValueError:
96-
pass
97-
# bool conversion
98-
elif bool in allowed_types:
99-
if user_input.lower() == "true":
100-
user_input = True
101-
elif user_input.lower() == "false":
102-
user_input = False
103-
104-
# Check if type matches annotation
105-
annotation_is_list = False
106-
if hasattr(param_annotation, "_name"):
107-
annotation_is_list = param_annotation._name == "List"
108-
if type(user_input) == list and annotation_is_list:
109-
# If input is a list, validate all items
110-
valid = all(
111-
isinstance(i, allowed_types) for i in user_input
112-
)
113-
elif type(user_input) != list and annotation_is_list:
114-
allowed_types = [list]
115-
valid = False
56+
def validate_parameters(self, request_inputs, function_args, repeat=False):
57+
# Iterate through all
58+
parsed_inputs = {}
59+
for arg in function_args.values():
60+
param_type = arg.default # ie. Route(), Json()
61+
param_name = arg.name # ie. id, username
62+
param_annotation = arg.annotation or typing.Any # ie. str, int
63+
is_list_or_union = hasattr(param_annotation, "__args__")
64+
# Ensure param type is valid
65+
66+
if param_type.__class__ not in request_inputs.keys():
67+
return self.error_function("Invalid parameter type.")
68+
69+
# Get user input for given param type and name
70+
user_input = request_inputs[param_type.__class__].get(
71+
param_name
72+
)
73+
74+
# If default is given, set it
75+
if user_input is None and param_type.default is not None:
76+
user_input = param_type.default
77+
78+
# If no default and no input, error
79+
elif user_input is None:
80+
ptype = param_type.name
81+
error_response = self.error_function(
82+
f"Required {ptype} parameter '{param_name}' not given."
83+
)
84+
# If "None" is allowed in Union, then continue
85+
if is_list_or_union:
86+
if type(None) not in param_annotation.__args__:
87+
return error_response
11688
else:
117-
# If not list, just validate singular data type
118-
valid = isinstance(user_input, allowed_types)
119-
120-
# Continue or error depending on validity
121-
if valid:
122-
try:
123-
param_type.validate(user_input)
124-
parsed_inputs[param_name] = user_input
125-
except Exception as e:
126-
return self.error_function(
127-
f"Parameter '{param_name}' {e}"
128-
)
89+
continue
12990
else:
130-
if type(None) in allowed_types:
131-
allowed_types = list(allowed_types)
132-
allowed_types.remove(type(None))
133-
startphrase = "Optional parameter"
134-
else:
135-
startphrase = "Parameter"
136-
types = "/".join(t.__name__ for t in allowed_types)
137-
if annotation_is_list:
138-
types = "List[" + types + "]"
139-
return decorator_self.error_function(
140-
f"{startphrase} '{param_name}' should be type {types}."
91+
return error_response
92+
93+
# If typing's Any or ClassVar, don't validate type
94+
if isinstance(param_annotation, typing._SpecialForm):
95+
valid = True
96+
allowed_types = ["all"]
97+
98+
# Otherwise, validate type
99+
else:
100+
allowed_types = []
101+
# If List or Union, get all "inner" types
102+
if is_list_or_union:
103+
allowed_types = param_annotation.__args__
104+
# Validate any embedded lists
105+
can_skip = False
106+
for allowed_type in allowed_types:
107+
# If is list or union...
108+
if hasattr(allowed_type, "__args__"):
109+
# Run function recursively
110+
parsed = self.validate_parameters({
111+
param_type.__class__: {param_name: user_input},
112+
}, {
113+
param_name: MockRequest(
114+
param_type,
115+
param_name,
116+
allowed_type
117+
)
118+
}, repeat=True)
119+
# Return error if present
120+
if not isinstance(parsed, dict):
121+
return parsed
122+
# Update and continue to next
123+
parsed_inputs.update(parsed)
124+
can_skip = True
125+
break
126+
if can_skip:
127+
continue
128+
else:
129+
allowed_types = (param_annotation,)
130+
131+
# If query parameter, try converting to match
132+
if param_type.__class__ == Query and type(user_input) == str: # noqa: E501
133+
# int conversion
134+
if int in allowed_types:
135+
try:
136+
user_input = int(user_input)
137+
except ValueError:
138+
pass
139+
# float conversion
140+
if float in allowed_types:
141+
try:
142+
user_input = float(user_input)
143+
except ValueError:
144+
pass
145+
# bool conversion
146+
elif bool in allowed_types:
147+
if user_input.lower() == "true":
148+
user_input = True
149+
elif user_input.lower() == "false":
150+
user_input = False
151+
152+
# Check if type matches annotation
153+
annotation_is_list = False
154+
if hasattr(param_annotation, "_name"):
155+
annotation_is_list = param_annotation._name == "List"
156+
if type(user_input) == list and annotation_is_list:
157+
# If input is a list, validate all items
158+
valid = all(
159+
isinstance(i, allowed_types) for i in user_input
141160
)
161+
elif type(user_input) != list and annotation_is_list:
162+
allowed_types = [list]
163+
valid = False
164+
else:
165+
# If not list, just validate singular data type
166+
valid = isinstance(user_input, allowed_types)
167+
168+
# Continue or error depending on validity
169+
if valid:
170+
try:
171+
param_type.validate(user_input)
172+
parsed_inputs[param_name] = user_input
173+
except Exception as e:
174+
return self.error_function(
175+
f"Parameter '{param_name}' {e}"
176+
)
177+
else:
178+
if type(None) in allowed_types:
179+
allowed_types = list(allowed_types)
180+
allowed_types.remove(type(None))
181+
startphrase = "Optional parameter"
182+
else:
183+
startphrase = "Parameter"
184+
types = "/".join(t.__name__ for t in allowed_types)
185+
if annotation_is_list and allowed_types[0] is not list:
186+
types = "List[" + types + "]"
187+
return self.error_function(
188+
f"{startphrase} '{param_name}' should be type {types}."
189+
)
142190

143-
return f(**parsed_inputs)
144-
145-
nested_func.__name__ = f.__name__
146-
return nested_func
191+
return parsed_inputs

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
setup(
1313
name='Flask-Parameter-Validation',
14-
version='1.0.21',
14+
version='1.0.22',
1515
url='https://github.com/Ge0rg3/Flask-Parameter-Validation',
1616
license='MIT',
1717
author='George Omnet',

0 commit comments

Comments
 (0)