Skip to content

Commit f9eb9ff

Browse files
committed
Fix: refuse time and pagination parameters for message WS
Problem: the message websocket could receive "startDate"/"endDate" and pagination parameters. Solution: use a slightly different Pydantic model for the websocket to reject these parameters.
1 parent 8cbd9a9 commit f9eb9ff

File tree

1 file changed

+48
-40
lines changed

1 file changed

+48
-40
lines changed

src/aleph/web/controllers/messages.py

+48-40
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from enum import IntEnum
4-
from typing import Any, Dict, List, Optional, Mapping
4+
from typing import Any, List, Optional, Mapping
55

66
from aiohttp import web
77
from aleph_message.models import MessageType, ItemHash, Chain
@@ -14,7 +14,6 @@
1414
LIST_FIELD_SEPARATOR,
1515
Pagination,
1616
cond_output,
17-
make_date_filters,
1817
)
1918

2019
LOGGER = logging.getLogger(__name__)
@@ -30,7 +29,7 @@ class SortOrder(IntEnum):
3029
DESCENDING = -1
3130

3231

33-
class MessageQueryParams(BaseModel):
32+
class BaseMessageQueryParams(BaseModel):
3433
sort_order: SortOrder = Field(
3534
default=SortOrder.DESCENDING,
3635
description="Order in which messages should be listed: "
@@ -72,34 +71,6 @@ class MessageQueryParams(BaseModel):
7271
hashes: Optional[List[ItemHash]] = Field(
7372
default=None, description="Accepted values for the 'item_hash' field."
7473
)
75-
history: Optional[int] = Field(
76-
DEFAULT_WS_HISTORY,
77-
ge=10,
78-
lt=200,
79-
description="Accepted values for the 'item_hash' field.",
80-
)
81-
pagination: int = Field(
82-
default=DEFAULT_MESSAGES_PER_PAGE,
83-
ge=0,
84-
description="Maximum number of messages to return. Specifying 0 removes this limit.",
85-
)
86-
page: int = Field(
87-
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
88-
)
89-
start_date: float = Field(
90-
default=0,
91-
ge=0,
92-
alias="startDate",
93-
description="Start date timestamp. If specified, only messages with "
94-
"a time field greater or equal to this value will be returned.",
95-
)
96-
end_date: float = Field(
97-
default=0,
98-
ge=0,
99-
alias="endDate",
100-
description="End date timestamp. If specified, only messages with "
101-
"a time field lower than this value will be returned.",
102-
)
10374

10475
@root_validator
10576
def validate_field_dependencies(cls, values):
@@ -124,8 +95,8 @@ def split_str(cls, v):
12495
return v.split(LIST_FIELD_SEPARATOR)
12596
return v
12697

127-
def to_mongodb_filters(self) -> Mapping[str, Any]:
128-
filters: List[Dict[str, Any]] = []
98+
def to_filter_list(self) -> List[Mapping[str, Any]]:
99+
filters: List[Mapping[str, Any]] = []
129100

130101
if self.message_type is not None:
131102
filters.append({"type": self.message_type})
@@ -164,19 +135,56 @@ def to_mongodb_filters(self) -> Mapping[str, Any]:
164135
}
165136
)
166137

167-
date_filters = make_date_filters(
168-
start=self.start_date, end=self.end_date, filter_key="time"
169-
)
170-
if date_filters:
171-
filters.append(date_filters)
138+
return filters
139+
140+
def to_mongodb_filters(self) -> Mapping[str, Any]:
141+
filters = self.to_filter_list()
142+
return self._make_and_filter(filters)
172143

173-
and_filter = {}
144+
@staticmethod
145+
def _make_and_filter(filters: List[Mapping[str, Any]]) -> Mapping[str, Any]:
146+
and_filter: Mapping[str, Any] = {}
174147
if filters:
175148
and_filter = {"$and": filters} if len(filters) > 1 else filters[0]
176149

177150
return and_filter
178151

179152

153+
class MessageQueryParams(BaseMessageQueryParams):
154+
pagination: int = Field(
155+
default=DEFAULT_MESSAGES_PER_PAGE,
156+
ge=0,
157+
description="Maximum number of messages to return. Specifying 0 removes this limit.",
158+
)
159+
page: int = Field(
160+
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
161+
)
162+
163+
start_date: float = Field(
164+
default=0,
165+
ge=0,
166+
alias="startDate",
167+
description="Start date timestamp. If specified, only messages with "
168+
"a time field greater or equal to this value will be returned.",
169+
)
170+
end_date: float = Field(
171+
default=0,
172+
ge=0,
173+
alias="endDate",
174+
description="End date timestamp. If specified, only messages with "
175+
"a time field lower than this value will be returned.",
176+
)
177+
178+
179+
class WsMessageQueryParams(BaseMessageQueryParams):
180+
history: Optional[int] = Field(
181+
DEFAULT_WS_HISTORY,
182+
ge=10,
183+
lt=200,
184+
description="Accepted values for the 'item_hash' field.",
185+
)
186+
187+
180188
async def view_messages_list(request):
181189
"""Messages list view with filters"""
182190

@@ -244,7 +252,7 @@ async def messages_ws(request: web.Request):
244252
collection = CappedMessage.collection
245253
last_id = None
246254

247-
query_params = MessageQueryParams.parse_obj(request.query)
255+
query_params = WsMessageQueryParams.parse_obj(request.query)
248256
find_filters = query_params.to_mongodb_filters()
249257

250258
initial_count = query_params.history

0 commit comments

Comments
 (0)