Skip to content

Commit b3cfb01

Browse files
committed
Fix awaiting of async view function.
1 parent 42769af commit b3cfb01

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

flask_parameter_validation/parameter_validation.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import re
55
from inspect import signature
6-
from flask import request
6+
from flask import request, Response
77
from werkzeug.datastructures import ImmutableMultiDict
88
from werkzeug.exceptions import BadRequest
99
from .exceptions import (InvalidParameterTypeError, MissingInputError,
@@ -40,8 +40,13 @@ def __call__(self, f):
4040
}
4141
fn_list[fsig] = fdocs
4242

43-
@functools.wraps(f)
44-
async def nested_func(**kwargs):
43+
def nested_func_helper(**kwargs):
44+
"""
45+
Validates the inputs of a Flask route or returns an error. Returns
46+
are wrapped in a dictionary with a flag to let nested_func() know
47+
if it should unpack the resulting dictionary of inputs as kwargs,
48+
or just return the error message.
49+
"""
4550
# Step 1 - Get expected input details as dict
4651
expected_inputs = signature(f).parameters
4752

@@ -54,7 +59,7 @@ async def nested_func(**kwargs):
5459
try:
5560
json_input = request.json
5661
except BadRequest:
57-
return {"error": "Could not parse JSON."}, 400
62+
return {"error": ({"error": "Could not parse JSON."}, 400), "validated": False}
5863

5964
# Step 3 - Extract list of parameters expected to be lists (otherwise all values are converted to lists)
6065
expected_list_params = []
@@ -79,18 +84,32 @@ async def nested_func(**kwargs):
7984
try:
8085
new_input = self.validate(expected, request_inputs)
8186
except (MissingInputError, ValidationError) as e:
82-
return {"error": str(e)}, 400
87+
return {"error": ({"error": str(e)}, 400), "validated": False}
8388
else:
8489
try:
8590
new_input = self.validate(expected, request_inputs)
8691
except Exception as e:
87-
return self.custom_error_handler(e)
92+
return {"error": self.custom_error_handler(e), "validated": False}
8893
validated_inputs[expected.name] = new_input
8994

90-
if asyncio.iscoroutinefunction(f):
91-
return await f(**validated_inputs)
92-
else:
93-
return f(**validated_inputs)
95+
return {"inputs": validated_inputs, "validated": True}
96+
97+
if asyncio.iscoroutinefunction(f):
98+
# If the view function is async, return and await a coroutine
99+
@functools.wraps(f)
100+
async def nested_func(**kwargs):
101+
validated_inputs = nested_func_helper(**kwargs)
102+
if validated_inputs["validated"]:
103+
return await f(**validated_inputs["inputs"])
104+
return validated_inputs["error"]
105+
else:
106+
# If the view function is not async, return a function
107+
@functools.wraps(f)
108+
def nested_func(**kwargs):
109+
validated_inputs = nested_func_helper(**kwargs)
110+
if validated_inputs["validated"]:
111+
return f(**validated_inputs["inputs"])
112+
return validated_inputs["error"]
94113

95114
nested_func.__name__ = f.__name__
96115
return nested_func

0 commit comments

Comments
 (0)