Skip to content

Commit 7b99414

Browse files
authored
Merge pull request #41 from maxwrimlinger/upstream_decorators
Fix Async View Handling
2 parents 1fd539b + 1e6e6b5 commit 7b99414

19 files changed

+1163
-18
lines changed

flask_parameter_validation/parameter_validation.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import re
55
from inspect import signature
6-
from flask import request
6+
from flask import request, Response
77
from werkzeug.datastructures import ImmutableMultiDict
88
from werkzeug.exceptions import BadRequest
99
from .exceptions import (InvalidParameterTypeError, MissingInputError,
@@ -40,8 +40,13 @@ def __call__(self, f):
4040
}
4141
fn_list[fsig] = fdocs
4242

43-
@functools.wraps(f)
44-
async def nested_func(**kwargs):
43+
def nested_func_helper(**kwargs):
44+
"""
45+
Validates the inputs of a Flask route or returns an error. Returns
46+
are wrapped in a dictionary with a flag to let nested_func() know
47+
if it should unpack the resulting dictionary of inputs as kwargs,
48+
or just return the error message.
49+
"""
4550
# Step 1 - Get expected input details as dict
4651
expected_inputs = signature(f).parameters
4752

@@ -54,7 +59,7 @@ async def nested_func(**kwargs):
5459
try:
5560
json_input = request.json
5661
except BadRequest:
57-
return {"error": "Could not parse JSON."}, 400
62+
return {"error": ({"error": "Could not parse JSON."}, 400), "validated": False}
5863

5964
# Step 3 - Extract list of parameters expected to be lists (otherwise all values are converted to lists)
6065
expected_list_params = []
@@ -79,18 +84,32 @@ async def nested_func(**kwargs):
7984
try:
8085
new_input = self.validate(expected, request_inputs)
8186
except (MissingInputError, ValidationError) as e:
82-
return {"error": str(e)}, 400
87+
return {"error": ({"error": str(e)}, 400), "validated": False}
8388
else:
8489
try:
8590
new_input = self.validate(expected, request_inputs)
8691
except Exception as e:
87-
return self.custom_error_handler(e)
92+
return {"error": self.custom_error_handler(e), "validated": False}
8893
validated_inputs[expected.name] = new_input
8994

90-
if asyncio.iscoroutinefunction(f):
91-
return await f(**validated_inputs)
92-
else:
93-
return f(**validated_inputs)
95+
return {"inputs": validated_inputs, "validated": True}
96+
97+
if asyncio.iscoroutinefunction(f):
98+
# If the view function is async, return and await a coroutine
99+
@functools.wraps(f)
100+
async def nested_func(**kwargs):
101+
validated_inputs = nested_func_helper(**kwargs)
102+
if validated_inputs["validated"]:
103+
return await f(**validated_inputs["inputs"])
104+
return validated_inputs["error"]
105+
else:
106+
# If the view function is not async, return a function
107+
@functools.wraps(f)
108+
def nested_func(**kwargs):
109+
validated_inputs = nested_func_helper(**kwargs)
110+
if validated_inputs["validated"]:
111+
return f(**validated_inputs["inputs"])
112+
return validated_inputs["error"]
94113

95114
nested_func.__name__ = f.__name__
96115
return nested_func

flask_parameter_validation/test/test_file_params.py

+22
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@ def test_required_file(client):
1717
assert "error" in r.json
1818

1919

20+
def test_required_file_decorator(client):
21+
url = "/file/decorator/required"
22+
# Test that we receive a success response if a file is provided
23+
r = client.post(url, data={"v": (resources / "test.json").open("rb")})
24+
assert "success" in r.json
25+
assert r.json["success"]
26+
# Test that we receive an error if a file is not provided
27+
r = client.post(url)
28+
assert "error" in r.json
29+
30+
31+
def test_required_file_async_decorator(client):
32+
url = "/file/async_decorator/required"
33+
# Test that we receive a success response if a file is provided
34+
r = client.post(url, data={"v": (resources / "test.json").open("rb")})
35+
assert "success" in r.json
36+
assert r.json["success"]
37+
# Test that we receive an error if a file is not provided
38+
r = client.post(url)
39+
assert "error" in r.json
40+
41+
2042
def test_optional_file(client):
2143
url = "/file/optional"
2244
# Test that we receive a success response if a file is provided

flask_parameter_validation/test/test_form_params.py

+22
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,28 @@ def test_required_str(client):
2525
assert "error" in r.json
2626

2727

28+
def test_required_str_decorator(client):
29+
url = "/form/str/decorator/required"
30+
# Test that present input yields input value
31+
r = client.post(url, data={"v": "v"})
32+
assert "v" in r.json
33+
assert r.json["v"] == "v"
34+
# Test that missing input yields error
35+
r = client.post(url)
36+
assert "error" in r.json
37+
38+
39+
def test_required_str_async_decorator(client):
40+
url = "/form/str/async_decorator/required"
41+
# Test that present input yields input value
42+
r = client.post(url, data={"v": "v"})
43+
assert "v" in r.json
44+
assert r.json["v"] == "v"
45+
# Test that missing input yields error
46+
r = client.post(url)
47+
assert "error" in r.json
48+
49+
2850
def test_optional_str(client):
2951
url = "/form/str/optional"
3052
# Test that missing input yields None

flask_parameter_validation/test/test_json_params.py

+48
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,54 @@ def test_dict_default(client):
952952
assert opt == r.json["opt"]
953953

954954

955+
def test_dict_default_decorator(client):
956+
url = "/json/dict/decorator/default"
957+
# Test that present dict yields input values
958+
n_opt = {"e": "f"}
959+
opt = {"g": "h"}
960+
r = client.post(url, json={"n_opt": n_opt, "opt": opt})
961+
assert "n_opt" in r.json
962+
assert "opt" in r.json
963+
assert type(r.json["n_opt"]) is dict
964+
assert type(r.json["opt"]) is dict
965+
assert n_opt == r.json["n_opt"]
966+
assert opt == r.json["opt"]
967+
# Test that missing dict yields default values
968+
n_opt = {"a": "b"}
969+
opt = {"c": "d"}
970+
r = client.post(url)
971+
assert "n_opt" in r.json
972+
assert "opt" in r.json
973+
assert type(r.json["n_opt"]) is dict
974+
assert type(r.json["opt"]) is dict
975+
assert n_opt == r.json["n_opt"]
976+
assert opt == r.json["opt"]
977+
978+
979+
def test_dict_default_async_decorator(client):
980+
url = "/json/dict/async_decorator/default"
981+
# Test that present dict yields input values
982+
n_opt = {"e": "f"}
983+
opt = {"g": "h"}
984+
r = client.post(url, json={"n_opt": n_opt, "opt": opt})
985+
assert "n_opt" in r.json
986+
assert "opt" in r.json
987+
assert type(r.json["n_opt"]) is dict
988+
assert type(r.json["opt"]) is dict
989+
assert n_opt == r.json["n_opt"]
990+
assert opt == r.json["opt"]
991+
# Test that missing dict yields default values
992+
n_opt = {"a": "b"}
993+
opt = {"c": "d"}
994+
r = client.post(url)
995+
assert "n_opt" in r.json
996+
assert "opt" in r.json
997+
assert type(r.json["n_opt"]) is dict
998+
assert type(r.json["opt"]) is dict
999+
assert n_opt == r.json["n_opt"]
1000+
assert opt == r.json["opt"]
1001+
1002+
9551003
def test_dict_func(client):
9561004
url = "/json/dict/func"
9571005
# Test that dict passing func yields input value

0 commit comments

Comments
 (0)