Skip to content

Commit d9fca73

Browse files
committed
Feature: return errors for invalid post API queries (#344)
Problem: the posts.json endpoint is too permissive and allows users to specify invalid hashes, time filters, pagination, etc.

Solution: detect these cases and return a 422 error. Replaced the validation code with a Pydantic model.

Breaking changes:
* The "endDate" field is now treated as exclusive.

Moreover, a 422 error code will now be returned in the following situations, where the previous implementation would simply return a 200:
* if an invalid item hash (i.e. not a hexadecimal sha256, CIDv0 or CIDv1) is specified in the "hashes" or "contentHashes" field.
* if the "endDate" field is lower than the "startDate" field.
* if "endDate" or "startDate" are negative.
* if pagination parameters ("page" and "pagination") are negative.
1 parent 0067b93 commit d9fca73

File tree

3 files changed

+137
-69
lines changed

3 files changed

+137
-69
lines changed

src/aleph/web/controllers/messages.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from aleph.model.messages import CappedMessage, Message
1313
from aleph.web.controllers.utils import (
14+
DEFAULT_MESSAGES_PER_PAGE,
15+
DEFAULT_PAGE,
1416
LIST_FIELD_SEPARATOR,
1517
Pagination,
1618
cond_output,
@@ -20,8 +22,6 @@
2022
LOGGER = logging.getLogger(__name__)
2123

2224

23-
DEFAULT_MESSAGES_PER_PAGE = 20
24-
DEFAULT_PAGE = 1
2525
DEFAULT_WS_HISTORY = 10
2626

2727

src/aleph/web/controllers/posts.py

+132-64
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,129 @@
1+
from typing import Optional, List, Mapping, Any
2+
3+
from aiohttp import web
4+
from aleph_message.models import ItemHash
5+
from pydantic import BaseModel, Field, root_validator, validator, ValidationError
6+
17
from aleph.model.messages import Message, get_merged_posts
2-
from aleph.web.controllers.utils import Pagination, cond_output, prepare_date_filters
8+
from aleph.web.controllers.utils import (
9+
DEFAULT_MESSAGES_PER_PAGE,
10+
DEFAULT_PAGE,
11+
LIST_FIELD_SEPARATOR,
12+
Pagination,
13+
cond_output,
14+
make_date_filters,
15+
)
16+
17+
18+
class PostQueryParams(BaseModel):
    """Validated query parameters for the posts listing endpoint.

    List-valued parameters may be supplied as a single string whose items are
    separated by ``LIST_FIELD_SEPARATOR``; such strings are split into lists
    before field validation. Invalid input raises ``pydantic.ValidationError``.
    """

    addresses: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.address' field."
    )
    hashes: Optional[List[ItemHash]] = Field(
        default=None, description="Accepted values for the 'item_hash' field."
    )
    refs: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.ref' field."
    )
    post_types: Optional[List[str]] = Field(
        default=None,
        # The endpoint's query parameter is named "types"; without this alias
        # the parameter is silently dropped and no type filter is applied.
        alias="types",
        description="Accepted values for the 'content.type' field.",
    )
    tags: Optional[List[str]] = Field(
        default=None,
        description="Accepted values for the 'content.content.tags' field.",
    )
    channels: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'channel' field."
    )
    start_date: float = Field(
        default=0,
        ge=0,
        alias="startDate",
        description="Start date timestamp. If specified, only messages with "
        "a time field greater or equal to this value will be returned.",
    )
    end_date: float = Field(
        default=0,
        ge=0,
        alias="endDate",
        description="End date timestamp. If specified, only messages with "
        "a time field lower than this value will be returned.",
    )
    pagination: int = Field(
        default=DEFAULT_MESSAGES_PER_PAGE,
        ge=0,
        description="Maximum number of messages to return. Specifying 0 removes this limit.",
    )
    page: int = Field(
        default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
    )

    class Config:
        # Aliases are the names exposed as query parameters. Keep accepting the
        # Python field names as well so that code building this model with
        # keyword arguments (e.g. ``post_types=[...]``) continues to work.
        allow_population_by_field_name = True

    @root_validator
    def validate_field_dependencies(cls, values):
        """Reject a date range whose end precedes its start.

        A value of 0 (the default) means "not specified", so the check only
        applies when both bounds are set to a non-zero value.
        """
        start_date = values.get("start_date")
        end_date = values.get("end_date")
        if start_date and end_date and (end_date < start_date):
            raise ValueError("end date cannot be lower than start date.")
        return values

    @validator(
        "addresses",
        "hashes",
        "refs",
        "post_types",
        "channels",
        "tags",
        pre=True,
    )
    def split_str(cls, v):
        """Split a separator-delimited string into a list; pass other values through."""
        if isinstance(v, str):
            return v.split(LIST_FIELD_SEPARATOR)
        return v

    def to_filter_list(self) -> List[Mapping[str, Any]]:
        """Build one MongoDB filter clause per query parameter that was specified."""

        filters: List[Mapping[str, Any]] = []

        if self.addresses is not None:
            filters.append(
                {"content.address": {"$in": self.addresses}},
            )
        if self.post_types is not None:
            filters.append({"content.type": {"$in": self.post_types}})
        if self.refs is not None:
            filters.append({"content.ref": {"$in": self.refs}})
        if self.tags is not None:
            filters.append({"content.content.tags": {"$elemMatch": {"$in": self.tags}}})
        if self.hashes is not None:
            # A requested hash may match either the item hash or the tx hash.
            filters.append(
                {
                    "$or": [
                        {"item_hash": {"$in": self.hashes}},
                        {"tx_hash": {"$in": self.hashes}},
                    ]
                }
            )
        if self.channels is not None:
            filters.append({"channel": {"$in": self.channels}})

        date_filters = make_date_filters(
            start=self.start_date, end=self.end_date, filter_key="time"
        )
        # Only add the date clause when the helper returned a non-empty filter.
        if date_filters:
            filters.append(date_filters)

        return filters

    def to_mongodb_filters(self) -> Mapping[str, Any]:
        """Return a single MongoDB filter document combining all clauses."""
        filters = self.to_filter_list()
        return self._make_and_filter(filters)

    @staticmethod
    def _make_and_filter(filters: List[Mapping[str, Any]]) -> Mapping[str, Any]:
        """Combine clauses with ``$and``.

        Returns an empty mapping for no clauses, the clause itself when there
        is exactly one, and ``{"$and": filters}`` otherwise.
        """
        and_filter: Mapping[str, Any] = {}
        if filters:
            and_filter = {"$and": filters} if len(filters) > 1 else filters[0]

        return and_filter
3127

4128

5129
async def view_posts_list(request):
@@ -8,72 +132,16 @@ async def view_posts_list(request):
8132
"""
9133

10134
find_filters = {}
11-
filters = [
12-
# {'type': request.query.get('msgType', 'POST')}
13-
]
14-
15135
query_string = request.query_string
16-
addresses = request.query.get("addresses", None)
17-
if addresses is not None:
18-
addresses = addresses.split(",")
19-
20-
refs = request.query.get("refs", None)
21-
if refs is not None:
22-
refs = refs.split(",")
23-
24-
post_types = request.query.get("types", None)
25-
if post_types is not None:
26-
post_types = post_types.split(",")
27-
28-
tags = request.query.get("tags", None)
29-
if tags is not None:
30-
tags = tags.split(",")
31-
32-
hashes = request.query.get("hashes", None)
33-
if hashes is not None:
34-
hashes = hashes.split(",")
35-
36-
channels = request.query.get("channels", None)
37-
if channels is not None:
38-
channels = channels.split(",")
39-
40-
date_filters = prepare_date_filters(request, "time")
41-
42-
if addresses is not None:
43-
filters.append({"content.address": {"$in": addresses}})
44-
45-
if post_types is not None:
46-
filters.append({"content.type": {"$in": post_types}})
47-
48-
if refs is not None:
49-
filters.append({"content.ref": {"$in": refs}})
50-
51-
if tags is not None:
52-
filters.append({"content.content.tags": {"$elemMatch": {"$in": tags}}})
53-
54-
if hashes is not None:
55-
filters.append(
56-
{"$or": [{"item_hash": {"$in": hashes}}, {"tx_hash": {"$in": hashes}}]}
57-
)
58-
59-
if channels is not None:
60-
filters.append({"channel": {"$in": channels}})
61-
62-
if date_filters is not None:
63-
filters.append(date_filters)
64136

65-
if len(filters) > 0:
66-
find_filters = {"$and": filters} if len(filters) > 1 else filters[0]
137+
try:
138+
query_params = PostQueryParams.parse_obj(request.query)
139+
except ValidationError as e:
140+
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))
67141

68-
(
69-
pagination_page,
70-
pagination_per_page,
71-
pagination_skip,
72-
) = Pagination.get_pagination_params(request)
73-
if pagination_per_page is None:
74-
pagination_per_page = 0
75-
if pagination_skip is None:
76-
pagination_skip = 0
142+
pagination_page = query_params.page
143+
pagination_per_page = query_params.pagination
144+
pagination_skip = (query_params.page - 1) * query_params.pagination
77145

78146
posts = [
79147
msg

src/aleph/web/controllers/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from aiohttp import web
66
from bson import json_util
77

8-
PER_PAGE = 20
9-
PER_PAGE_SUMMARY = 50
8+
DEFAULT_MESSAGES_PER_PAGE = 20
9+
DEFAULT_PAGE = 1
1010
LIST_FIELD_SEPARATOR = ","
1111

1212

@@ -15,7 +15,7 @@ class Pagination(object):
1515
def get_pagination_params(request):
1616
pagination_page = int(request.match_info.get("page", "1"))
1717
pagination_page = int(request.query.get("page", pagination_page))
18-
pagination_param = int(request.query.get("pagination", PER_PAGE))
18+
pagination_param = int(request.query.get("pagination", DEFAULT_MESSAGES_PER_PAGE))
1919
with_pagination = pagination_param != 0
2020

2121
if pagination_page < 1:

0 commit comments

Comments
 (0)