
Commit 8cbd9a9

Feature: return errors for invalid message API queries (#335)
Problem: the messages.json endpoint was too permissive, allowing users to specify invalid chains or message types, combine incoherent pagination and time filters, etc.

Solution: detect these cases and return a 422 error. The endpoint's custom validation code is replaced by a Pydantic model, and tests were added for the start/end date parameters and pagination.

Breaking changes:
* The "endDate" field is now exclusive: only messages with a time field strictly lower than this value are returned, instead of lower or equal as before.
* The endpoint now returns a 422 error instead of a 400 when incorrect parameters are passed. Moreover, a 422 error code is now returned in the following situations, where the previous implementation would simply return a 200:
  * an invalid/unknown chain appears in the "chains" field;
  * an invalid/unknown message type is specified in "msgType";
  * an invalid item hash (i.e. not a hexadecimal SHA-256 hash, a CIDv0 or a CIDv1) is specified in the "hashes" or "contentHashes" field;
  * the "endDate" field is lower than the "startDate" field;
  * "endDate" or "startDate" is negative;
  * the pagination parameters ("page" and "pagination") are negative.
1 parent d42fbca commit 8cbd9a9
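The breaking changes above are easiest to see from a client's perspective. The sketch below is illustrative only and not part of the commit: the base URL, port and endpoint path are assumptions (adjust them to your node), while the status codes and the exclusive "endDate" bound follow the commit description.

import requests  # assumed available; any HTTP client works

NODE_URL = "http://localhost:4024"  # assumption: a locally running node
MESSAGES_URL = f"{NODE_URL}/api/v0/messages.json"  # assumed endpoint path

# Coherent filters: 200 OK, returning messages with startDate <= time < endDate
# (endDate is now an exclusive bound).
ok = requests.get(MESSAGES_URL, params={"chains": "ETH", "startDate": 100, "endDate": 200})
print(ok.status_code)  # 200

# Incoherent dates (endDate < startDate): rejected with a 422 instead of a 200.
bad_dates = requests.get(MESSAGES_URL, params={"startDate": 200, "endDate": 100})
print(bad_dates.status_code)  # 422

# Unknown chain: also 422, with a JSON body describing the validation error.
bad_chain = requests.get(MESSAGES_URL, params={"chains": "NOT_A_CHAIN"})
print(bad_chain.status_code)  # 422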

File tree: 4 files changed, +428 −134 lines changed

src/aleph/web/controllers/messages.py (+182 −112)
@@ -1,129 +1,200 @@
-from typing import Any, Dict, List, Optional, Set
+import asyncio
+import logging
+from enum import IntEnum
+from typing import Any, Dict, List, Optional, Mapping
 
-from aleph.model.messages import CappedMessage, Message
 from aiohttp import web
-from aiohttp.web_exceptions import HTTPBadRequest
-import asyncio
-from pymongo.cursor import CursorType
+from aleph_message.models import MessageType, ItemHash, Chain
 from bson.objectid import ObjectId
-from aleph.web.controllers.utils import Pagination, cond_output, prepare_date_filters
-import logging
+from pydantic import BaseModel, Field, validator, ValidationError, root_validator
+from pymongo.cursor import CursorType
 
-LOGGER = logging.getLogger("MESSAGES")
-
-KNOWN_QUERY_FIELDS = {
-    "sort_order",
-    "msgType",
-    "addresses",
-    "refs",
-    "contentHashes",
-    "contentKeys",
-    "contentTypes",
-    "chains",
-    "channels",
-    "tags",
-    "hashes",
-    "history",
-    "pagination",
-    "page",  # page is handled in Pagination.get_pagination_params
-    "startDate",
-    "endDate",
-}
-
-
-async def get_filters(request: web.Request):
-    def get_query_list_field(field: str, separator=",") -> Optional[List[str]]:
-        field_str = request.query.get(field, None)
-        return field_str.split(separator) if field_str is not None else None
-
-    unknown_query_fields: Set[str] = set(request.query.keys()).difference(
-        KNOWN_QUERY_FIELDS
-    )
-    if unknown_query_fields:
-        raise ValueError(f"Unknown query fields: {unknown_query_fields}")
-
-    find_filters: Dict[str, Any] = {}
-
-    msg_type = request.query.get("msgType", None)
-
-    filters: List[Dict[str, Any]] = []
-    addresses = get_query_list_field("addresses")
-    refs = get_query_list_field("refs")
-    content_keys = get_query_list_field("contentKeys")
-    content_hashes = get_query_list_field("contentHashes")
-    content_types = get_query_list_field("contentTypes")
-    chains = get_query_list_field("chains")
-    channels = get_query_list_field("channels")
-    tags = get_query_list_field("tags")
-    hashes = get_query_list_field("hashes")
-
-    date_filters = prepare_date_filters(request, "time")
-
-    if msg_type is not None:
-        filters.append({"type": msg_type})
-
-    if addresses is not None:
-        filters.append(
-            {
-                "$or": [
-                    {"content.address": {"$in": addresses}},
-                    {"sender": {"$in": addresses}},
-                ]
-            }
-        )
+from aleph.model.messages import CappedMessage, Message
+from aleph.web.controllers.utils import (
+    LIST_FIELD_SEPARATOR,
+    Pagination,
+    cond_output,
+    make_date_filters,
+)
 
-    if content_hashes is not None:
-        filters.append({"content.item_hash": {"$in": content_hashes}})
+LOGGER = logging.getLogger(__name__)
 
-    if content_keys is not None:
-        filters.append({"content.key": {"$in": content_keys}})
 
-    if content_types is not None:
-        filters.append({"content.type": {"$in": content_types}})
+DEFAULT_MESSAGES_PER_PAGE = 20
+DEFAULT_PAGE = 1
+DEFAULT_WS_HISTORY = 10
 
-    if refs is not None:
-        filters.append({"content.ref": {"$in": refs}})
 
-    if tags is not None:
-        filters.append({"content.content.tags": {"$elemMatch": {"$in": tags}}})
+class SortOrder(IntEnum):
+    ASCENDING = 1
+    DESCENDING = -1
 
-    if chains is not None:
-        filters.append({"chain": {"$in": chains}})
 
-    if channels is not None:
-        filters.append({"channel": {"$in": channels}})
+class MessageQueryParams(BaseModel):
+    sort_order: SortOrder = Field(
+        default=SortOrder.DESCENDING,
+        description="Order in which messages should be listed: "
+        "-1 means most recent messages first, 1 means older messages first.",
+    )
+    message_type: Optional[MessageType] = Field(
+        default=None, alias="msgType", description="Message type."
+    )
+    addresses: Optional[List[str]] = Field(
+        default=None, description="Accepted values for the 'sender' field."
+    )
+    refs: Optional[List[str]] = Field(
+        default=None, description="Accepted values for the 'content.ref' field."
+    )
+    content_hashes: Optional[List[ItemHash]] = Field(
+        default=None,
+        alias="contentHashes",
+        description="Accepted values for the 'content.item_hash' field.",
+    )
+    content_keys: Optional[List[ItemHash]] = Field(
+        default=None,
+        alias="contentKeys",
+        description="Accepted values for the 'content.keys' field.",
+    )
+    content_types: Optional[List[ItemHash]] = Field(
+        default=None,
+        alias="contentTypes",
+        description="Accepted values for the 'content.type' field.",
+    )
+    chains: Optional[List[Chain]] = Field(
+        default=None, description="Accepted values for the 'chain' field."
+    )
+    channels: Optional[List[str]] = Field(
+        default=None, description="Accepted values for the 'channel' field."
+    )
+    tags: Optional[List[str]] = Field(
+        default=None, description="Accepted values for the 'content.content.tag' field."
+    )
+    hashes: Optional[List[ItemHash]] = Field(
+        default=None, description="Accepted values for the 'item_hash' field."
+    )
+    history: Optional[int] = Field(
+        DEFAULT_WS_HISTORY,
+        ge=10,
+        lt=200,
+        description="Accepted values for the 'item_hash' field.",
+    )
+    pagination: int = Field(
+        default=DEFAULT_MESSAGES_PER_PAGE,
+        ge=0,
+        description="Maximum number of messages to return. Specifying 0 removes this limit.",
+    )
+    page: int = Field(
+        default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
+    )
+    start_date: float = Field(
+        default=0,
+        ge=0,
+        alias="startDate",
+        description="Start date timestamp. If specified, only messages with "
+        "a time field greater or equal to this value will be returned.",
+    )
+    end_date: float = Field(
+        default=0,
+        ge=0,
+        alias="endDate",
+        description="End date timestamp. If specified, only messages with "
+        "a time field lower than this value will be returned.",
+    )
 
-    if hashes is not None:
-        filters.append(
-            {"$or": [{"item_hash": {"$in": hashes}}, {"tx_hash": {"$in": hashes}}]}
-        )
+    @root_validator
+    def validate_field_dependencies(cls, values):
+        start_date = values.get("start_date")
+        end_date = values.get("end_date")
+        if start_date and end_date and (end_date < start_date):
+            raise ValueError("end date cannot be lower than start date.")
+        return values
+
+    @validator(
+        "addresses",
+        "content_hashes",
+        "content_keys",
+        "content_types",
+        "chains",
+        "channels",
+        "tags",
+        pre=True,
+    )
+    def split_str(cls, v):
+        if isinstance(v, str):
+            return v.split(LIST_FIELD_SEPARATOR)
+        return v
+
+    def to_mongodb_filters(self) -> Mapping[str, Any]:
+        filters: List[Dict[str, Any]] = []
+
+        if self.message_type is not None:
+            filters.append({"type": self.message_type})
+
+        if self.addresses is not None:
+            filters.append(
+                {
+                    "$or": [
+                        {"content.address": {"$in": self.addresses}},
+                        {"sender": {"$in": self.addresses}},
+                    ]
+                }
+            )
 
-    if date_filters is not None:
-        filters.append(date_filters)
+        if self.content_hashes is not None:
+            filters.append({"content.item_hash": {"$in": self.content_hashes}})
+        if self.content_keys is not None:
+            filters.append({"content.key": {"$in": self.content_keys}})
+        if self.content_types is not None:
+            filters.append({"content.type": {"$in": self.content_types}})
+        if self.refs is not None:
+            filters.append({"content.ref": {"$in": self.refs}})
+        if self.tags is not None:
+            filters.append({"content.content.tags": {"$elemMatch": {"$in": self.tags}}})
+        if self.chains is not None:
+            filters.append({"chain": {"$in": self.chains}})
+        if self.channels is not None:
+            filters.append({"channel": {"$in": self.channels}})
+        if self.hashes is not None:
+            filters.append(
+                {
+                    "$or": [
+                        {"item_hash": {"$in": self.hashes}},
+                        {"tx_hash": {"$in": self.hashes}},
+                    ]
+                }
+            )
+
+        date_filters = make_date_filters(
+            start=self.start_date, end=self.end_date, filter_key="time"
+        )
+        if date_filters:
+            filters.append(date_filters)
 
-    if len(filters) > 0:
-        find_filters = {"$and": filters} if len(filters) > 1 else filters[0]
+        and_filter = {}
+        if filters:
+            and_filter = {"$and": filters} if len(filters) > 1 else filters[0]
 
-    return find_filters
+        return and_filter
 
 
 async def view_messages_list(request):
     """Messages list view with filters"""
 
     try:
-        find_filters = await get_filters(request)
-    except ValueError as error:
-        raise HTTPBadRequest(body=error.args[0])
-
-    (
-        pagination_page,
-        pagination_per_page,
-        pagination_skip,
-    ) = Pagination.get_pagination_params(request)
-    if pagination_per_page is None:
-        pagination_per_page = 0
-    if pagination_skip is None:
-        pagination_skip = 0
+        query_params = MessageQueryParams.parse_obj(request.query)
+    except ValidationError as e:
+        raise web.HTTPUnprocessableEntity(body=e.json(indent=4))
+
+    # If called from the messages/page/{page}.json endpoint, override the page
+    # parameters with the URL one
+    if url_page_param := request.match_info.get("page"):
+        query_params.page = int(url_page_param)
+
+    find_filters = query_params.to_mongodb_filters()
+
+    pagination_page = query_params.page
+    pagination_per_page = query_params.pagination
+    pagination_skip = (query_params.page - 1) * query_params.pagination
 
     messages = [
         msg
@@ -132,14 +203,14 @@ async def view_messages_list(request):
             projection={"_id": 0},
            limit=pagination_per_page,
             skip=pagination_skip,
-            sort=[("time", int(request.query.get("sort_order", "-1")))],
+            sort=[("time", query_params.sort_order.value)],
         )
     ]
 
     context = {"messages": messages}
 
     if pagination_per_page is not None:
-        if len(find_filters.keys()):
+        if find_filters:
             total_msgs = await Message.collection.count_documents(find_filters)
         else:
             total_msgs = await Message.collection.estimated_document_count()
@@ -173,11 +244,10 @@ async def messages_ws(request: web.Request):
     collection = CappedMessage.collection
     last_id = None
 
-    find_filters = await get_filters(request)
-    initial_count = int(request.query.get("history", 10))
-    initial_count = max(initial_count, 10)
-    # let's cap this to 200 historic messages max.
-    initial_count = min(initial_count, 200)
+    query_params = MessageQueryParams.parse_obj(request.query)
+    find_filters = query_params.to_mongodb_filters()
+
+    initial_count = query_params.history
 
     items = [
         item
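To make the new flow concrete, here is a minimal usage sketch of the MessageQueryParams model introduced above, exercised directly outside the web layer; it assumes pydantic and aleph_message are installed and that the model is importable as defined in the diff.

from pydantic import ValidationError

# Comma-separated list fields are split by the `split_str` pre-validator,
# and camelCase aliases such as "startDate" map to snake_case attributes.
params = MessageQueryParams.parse_obj({"chains": "ETH", "startDate": "100"})
print(params.chains, params.start_date)  # e.g. [<Chain.ETH: 'ETH'>] 100.0

# endDate < startDate trips `validate_field_dependencies`; view_messages_list
# turns the resulting ValidationError into an HTTP 422 response body.
try:
    MessageQueryParams.parse_obj({"startDate": "200", "endDate": "100"})
except ValidationError as e:
    print(e.json(indent=4))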

src/aleph/web/controllers/utils.py (+18 −1)
@@ -19,7 +19,9 @@ def get_pagination_params(request):
         with_pagination = pagination_param != 0
 
         if pagination_page < 1:
-            raise web.HTTPBadRequest(text=f"Query field 'page' must be ≥ 1, not {pagination_page}")
+            raise web.HTTPBadRequest(
+                text=f"Query field 'page' must be ≥ 1, not {pagination_page}"
+            )
 
         if not with_pagination:
             pagination_per_page = None
@@ -66,6 +68,21 @@ def iter_pages(self, left_edge=2, left_current=2, right_current=5, right_edge=2)
                 last = num
 
 
+def make_date_filters(start: float, end: float, filter_key: str):
+    filters = []
+    if start:
+        filters.append({filter_key: {"$gte": start}})
+    if end:
+        filters.append({filter_key: {"$lt": end}})
+
+    if len(filters) > 1:
+        return {"$and": filters}
+    if filters:
+        return filters[0]
+
+    return None
+
+
 def prepare_date_filters(request, filter_key):
     date_filters = None
 
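For reference, a quick sketch of the filters make_date_filters produces (the timestamps are arbitrary illustration values); the $lt operator on the end bound is what makes "endDate" exclusive, as noted in the breaking changes.

# Expected outputs, based on the implementation shown above.
print(make_date_filters(start=100.0, end=0, filter_key="time"))
# {'time': {'$gte': 100.0}}

print(make_date_filters(start=100.0, end=200.0, filter_key="time"))
# {'$and': [{'time': {'$gte': 100.0}}, {'time': {'$lt': 200.0}}]}

print(make_date_filters(start=0, end=0, filter_key="time"))
# None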
