Skip to content

Commit c601321

Browse files
committed
Internal: improve validation of the aggregates endpoint (#341)
Use a Pydantic model instead of manual validation.
1 parent 6ea6663 commit c601321

File tree

3 files changed

+67
-15
lines changed

3 files changed

+67
-15
lines changed

src/aleph/web/controllers/aggregates.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
1+
from typing import List, Optional
2+
13
from aiohttp import web
4+
from pydantic import BaseModel, validator, ValidationError
25

36
from aleph.model.messages import get_computed_address_aggregates
7+
from .utils import LIST_FIELD_SEPARATOR
8+
9+
10+
DEFAULT_LIMIT = 1000
11+
12+
13+
class AggregatesQueryParams(BaseModel):
14+
keys: Optional[List[str]] = None
15+
limit: int = DEFAULT_LIMIT
16+
17+
@validator(
18+
"keys",
19+
pre=True,
20+
)
21+
def split_str(cls, v):
22+
if isinstance(v, str):
23+
return v.split(LIST_FIELD_SEPARATOR)
24+
return v
425

526

627
async def address_aggregate(request):
@@ -10,15 +31,15 @@ async def address_aggregate(request):
1031

1132
address = request.match_info["address"]
1233

13-
keys = request.query.get("keys", None)
14-
if keys is not None:
15-
keys = keys.split(",")
16-
17-
limit = request.query.get("limit", "1000")
18-
limit = int(limit)
34+
try:
35+
query_params = AggregatesQueryParams.parse_obj(request.query)
36+
except ValidationError as e:
37+
raise web.HTTPUnprocessableEntity(
38+
text=e.json(), content_type="application/json"
39+
)
1940

2041
aggregates = await get_computed_address_aggregates(
21-
address_list=[address], key_list=keys, limit=limit
42+
address_list=[address], key_list=query_params.keys, limit=query_params.limit
2243
)
2344

2445
if not aggregates.get(address):

src/aleph/web/controllers/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
PER_PAGE = 20
99
PER_PAGE_SUMMARY = 50
10+
LIST_FIELD_SEPARATOR = ","
1011

1112

1213
class Pagination(object):

tests/api/test_aggregates.py

+38-8
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def merge_content(_messages: List[Dict]) -> Dict:
4343
aggregates = []
4444

4545
for key, group in itertools.groupby(
46-
sorted(messages, key=lambda msg: msg["content"]["key"]),
47-
lambda msg: msg["content"]["key"],
46+
sorted(messages, key=lambda msg: msg["content"]["key"]),
47+
lambda msg: msg["content"]["key"],
4848
):
4949
sorted_messages = sorted(group, key=lambda msg: msg["time"])
5050
aggregates.append(merge_content(sorted_messages))
@@ -99,22 +99,30 @@ async def test_get_aggregates(ccn_api_client, fixture_aggregates: List[Dict]):
9999

100100

101101
@pytest.mark.asyncio
102-
async def test_get_aggregates_filter_by_key(ccn_api_client, fixture_aggregates: List[Dict]):
102+
async def test_get_aggregates_filter_by_key(
103+
ccn_api_client, fixture_aggregates: List[Dict]
104+
):
103105
"""
104106
Tests the 'keys' query parameter.
105107
"""
106108

107109
address, key = ADDRESS_1, "test_target"
108-
aggregates = await get_aggregates_expect_success(ccn_api_client, address=address, keys=key)
110+
aggregates = await get_aggregates_expect_success(
111+
ccn_api_client, address=address, keys=key
112+
)
109113
assert aggregates["address"] == address
110114
assert aggregates["data"][key] == EXPECTED_AGGREGATES[address][key]
111115

112116
# Multiple keys
113117
address, keys = ADDRESS_1, ["test_target", "test_reference"]
114-
aggregates = await get_aggregates_expect_success(ccn_api_client, address=address, keys=",".join(keys))
118+
aggregates = await get_aggregates_expect_success(
119+
ccn_api_client, address=address, keys=",".join(keys)
120+
)
115121
assert aggregates["address"] == address
116122
for key in keys:
117-
assert aggregates["data"][key] == EXPECTED_AGGREGATES[address][key], f"Key {key} does not match"
123+
assert (
124+
aggregates["data"][key] == EXPECTED_AGGREGATES[address][key]
125+
), f"Key {key} does not match"
118126

119127

120128
@pytest.mark.asyncio
@@ -124,13 +132,17 @@ async def test_get_aggregates_limit(ccn_api_client, fixture_aggregates: List[Dic
124132
"""
125133

126134
address, key = ADDRESS_1, "test_reference"
127-
aggregates = await get_aggregates_expect_success(ccn_api_client, address=address, keys=key, limit=1)
135+
aggregates = await get_aggregates_expect_success(
136+
ccn_api_client, address=address, keys=key, limit=1
137+
)
128138
assert aggregates["address"] == address
129139
assert aggregates["data"][key] == {"c": 3, "d": 4}
130140

131141

132142
@pytest.mark.asyncio
133-
async def test_get_aggregates_invalid_address(ccn_api_client, fixture_aggregates: List[Dict]):
143+
async def test_get_aggregates_invalid_address(
144+
ccn_api_client, fixture_aggregates: List[Dict]
145+
):
134146
"""
135147
Pass an unknown address.
136148
"""
@@ -139,3 +151,21 @@ async def test_get_aggregates_invalid_address(ccn_api_client, fixture_aggregates
139151

140152
response = await get_aggregates(ccn_api_client, invalid_address)
141153
assert response.status == 404
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_get_aggregates_invalid_params(
158+
ccn_api_client, fixture_aggregates: List[Dict]
159+
):
160+
"""
161+
Tests that passing invalid parameters returns a 422 error.
162+
"""
163+
164+
# A string as limit
165+
response = await get_aggregates(ccn_api_client, ADDRESS_1, limit="abc")
166+
assert response.status == 422
167+
assert response.content_type == "application/json"
168+
169+
errors = await response.json()
170+
assert len(errors) == 1
171+
assert errors[0]["loc"] == ["limit"], errors

0 commit comments

Comments
 (0)