1
- from typing import Any , Dict , List , Optional , Set
1
+ import asyncio
2
+ import logging
3
+ from enum import IntEnum
4
+ from typing import Any , Dict , List , Optional , Mapping
2
5
3
- from aleph .model .messages import CappedMessage , Message
4
6
from aiohttp import web
5
- from aiohttp .web_exceptions import HTTPBadRequest
6
- import asyncio
7
- from pymongo .cursor import CursorType
7
+ from aleph_message .models import MessageType , ItemHash , Chain
8
8
from bson .objectid import ObjectId
9
- from aleph . web . controllers . utils import Pagination , cond_output , prepare_date_filters
10
- import logging
9
+ from pydantic import BaseModel , Field , validator , ValidationError , root_validator
10
+ from pymongo . cursor import CursorType
11
11
12
- LOGGER = logging .getLogger ("MESSAGES" )
13
-
14
- KNOWN_QUERY_FIELDS = {
15
- "sort_order" ,
16
- "msgType" ,
17
- "addresses" ,
18
- "refs" ,
19
- "contentHashes" ,
20
- "contentKeys" ,
21
- "contentTypes" ,
22
- "chains" ,
23
- "channels" ,
24
- "tags" ,
25
- "hashes" ,
26
- "history" ,
27
- "pagination" ,
28
- "page" , # page is handled in Pagination.get_pagination_params
29
- "startDate" ,
30
- "endDate" ,
31
- }
32
-
33
-
34
- async def get_filters (request : web .Request ):
35
- def get_query_list_field (field : str , separator = "," ) -> Optional [List [str ]]:
36
- field_str = request .query .get (field , None )
37
- return field_str .split (separator ) if field_str is not None else None
38
-
39
- unknown_query_fields : Set [str ] = set (request .query .keys ()).difference (
40
- KNOWN_QUERY_FIELDS
41
- )
42
- if unknown_query_fields :
43
- raise ValueError (f"Unknown query fields: { unknown_query_fields } " )
44
-
45
- find_filters : Dict [str , Any ] = {}
46
-
47
- msg_type = request .query .get ("msgType" , None )
48
-
49
- filters : List [Dict [str , Any ]] = []
50
- addresses = get_query_list_field ("addresses" )
51
- refs = get_query_list_field ("refs" )
52
- content_keys = get_query_list_field ("contentKeys" )
53
- content_hashes = get_query_list_field ("contentHashes" )
54
- content_types = get_query_list_field ("contentTypes" )
55
- chains = get_query_list_field ("chains" )
56
- channels = get_query_list_field ("channels" )
57
- tags = get_query_list_field ("tags" )
58
- hashes = get_query_list_field ("hashes" )
59
-
60
- date_filters = prepare_date_filters (request , "time" )
61
-
62
- if msg_type is not None :
63
- filters .append ({"type" : msg_type })
64
-
65
- if addresses is not None :
66
- filters .append (
67
- {
68
- "$or" : [
69
- {"content.address" : {"$in" : addresses }},
70
- {"sender" : {"$in" : addresses }},
71
- ]
72
- }
73
- )
12
+ from aleph .model .messages import CappedMessage , Message
13
+ from aleph .web .controllers .utils import (
14
+ LIST_FIELD_SEPARATOR ,
15
+ Pagination ,
16
+ cond_output ,
17
+ make_date_filters ,
18
+ )
74
19
75
- if content_hashes is not None :
76
- filters .append ({"content.item_hash" : {"$in" : content_hashes }})
20
+ LOGGER = logging .getLogger (__name__ )
77
21
78
- if content_keys is not None :
79
- filters .append ({"content.key" : {"$in" : content_keys }})
80
22
81
- if content_types is not None :
82
- filters .append ({"content.type" : {"$in" : content_types }})
23
+ DEFAULT_MESSAGES_PER_PAGE = 20
24
+ DEFAULT_PAGE = 1
25
+ DEFAULT_WS_HISTORY = 10
83
26
84
- if refs is not None :
85
- filters .append ({"content.ref" : {"$in" : refs }})
86
27
87
- if tags is not None :
88
- filters .append ({"content.content.tags" : {"$elemMatch" : {"$in" : tags }}})
28
+ class SortOrder (IntEnum ):
29
+ ASCENDING = 1
30
+ DESCENDING = - 1
89
31
90
- if chains is not None :
91
- filters .append ({"chain" : {"$in" : chains }})
92
32
93
- if channels is not None :
94
- filters .append ({"channel" : {"$in" : channels }})
33
+ class MessageQueryParams (BaseModel ):
34
+ sort_order : SortOrder = Field (
35
+ default = SortOrder .DESCENDING ,
36
+ description = "Order in which messages should be listed: "
37
+ "-1 means most recent messages first, 1 means older messages first." ,
38
+ )
39
+ message_type : Optional [MessageType ] = Field (
40
+ default = None , alias = "msgType" , description = "Message type."
41
+ )
42
+ addresses : Optional [List [str ]] = Field (
43
+ default = None , description = "Accepted values for the 'sender' field."
44
+ )
45
+ refs : Optional [List [str ]] = Field (
46
+ default = None , description = "Accepted values for the 'content.ref' field."
47
+ )
48
+ content_hashes : Optional [List [ItemHash ]] = Field (
49
+ default = None ,
50
+ alias = "contentHashes" ,
51
+ description = "Accepted values for the 'content.item_hash' field." ,
52
+ )
53
+ content_keys : Optional [List [ItemHash ]] = Field (
54
+ default = None ,
55
+ alias = "contentKeys" ,
56
+ description = "Accepted values for the 'content.keys' field." ,
57
+ )
58
+ content_types : Optional [List [ItemHash ]] = Field (
59
+ default = None ,
60
+ alias = "contentTypes" ,
61
+ description = "Accepted values for the 'content.type' field." ,
62
+ )
63
+ chains : Optional [List [Chain ]] = Field (
64
+ default = None , description = "Accepted values for the 'chain' field."
65
+ )
66
+ channels : Optional [List [str ]] = Field (
67
+ default = None , description = "Accepted values for the 'channel' field."
68
+ )
69
+ tags : Optional [List [str ]] = Field (
70
+ default = None , description = "Accepted values for the 'content.content.tag' field."
71
+ )
72
+ hashes : Optional [List [ItemHash ]] = Field (
73
+ default = None , description = "Accepted values for the 'item_hash' field."
74
+ )
75
+ history : Optional [int ] = Field (
76
+ DEFAULT_WS_HISTORY ,
77
+ ge = 10 ,
78
+ lt = 200 ,
79
+ description = "Accepted values for the 'item_hash' field." ,
80
+ )
81
+ pagination : int = Field (
82
+ default = DEFAULT_MESSAGES_PER_PAGE ,
83
+ ge = 0 ,
84
+ description = "Maximum number of messages to return. Specifying 0 removes this limit." ,
85
+ )
86
+ page : int = Field (
87
+ default = DEFAULT_PAGE , ge = 1 , description = "Offset in pages. Starts at 1."
88
+ )
89
+ start_date : float = Field (
90
+ default = 0 ,
91
+ ge = 0 ,
92
+ alias = "startDate" ,
93
+ description = "Start date timestamp. If specified, only messages with "
94
+ "a time field greater or equal to this value will be returned." ,
95
+ )
96
+ end_date : float = Field (
97
+ default = 0 ,
98
+ ge = 0 ,
99
+ alias = "endDate" ,
100
+ description = "End date timestamp. If specified, only messages with "
101
+ "a time field lower than this value will be returned." ,
102
+ )
95
103
96
- if hashes is not None :
97
- filters .append (
98
- {"$or" : [{"item_hash" : {"$in" : hashes }}, {"tx_hash" : {"$in" : hashes }}]}
99
- )
104
+ @root_validator
105
+ def validate_field_dependencies (cls , values ):
106
+ start_date = values .get ("start_date" )
107
+ end_date = values .get ("end_date" )
108
+ if start_date and end_date and (end_date < start_date ):
109
+ raise ValueError ("end date cannot be lower than start date." )
110
+ return values
111
+
112
+ @validator (
113
+ "addresses" ,
114
+ "content_hashes" ,
115
+ "content_keys" ,
116
+ "content_types" ,
117
+ "chains" ,
118
+ "channels" ,
119
+ "tags" ,
120
+ pre = True ,
121
+ )
122
+ def split_str (cls , v ):
123
+ if isinstance (v , str ):
124
+ return v .split (LIST_FIELD_SEPARATOR )
125
+ return v
126
+
127
+ def to_mongodb_filters (self ) -> Mapping [str , Any ]:
128
+ filters : List [Dict [str , Any ]] = []
129
+
130
+ if self .message_type is not None :
131
+ filters .append ({"type" : self .message_type })
132
+
133
+ if self .addresses is not None :
134
+ filters .append (
135
+ {
136
+ "$or" : [
137
+ {"content.address" : {"$in" : self .addresses }},
138
+ {"sender" : {"$in" : self .addresses }},
139
+ ]
140
+ }
141
+ )
100
142
101
- if date_filters is not None :
102
- filters .append (date_filters )
143
+ if self .content_hashes is not None :
144
+ filters .append ({"content.item_hash" : {"$in" : self .content_hashes }})
145
+ if self .content_keys is not None :
146
+ filters .append ({"content.key" : {"$in" : self .content_keys }})
147
+ if self .content_types is not None :
148
+ filters .append ({"content.type" : {"$in" : self .content_types }})
149
+ if self .refs is not None :
150
+ filters .append ({"content.ref" : {"$in" : self .refs }})
151
+ if self .tags is not None :
152
+ filters .append ({"content.content.tags" : {"$elemMatch" : {"$in" : self .tags }}})
153
+ if self .chains is not None :
154
+ filters .append ({"chain" : {"$in" : self .chains }})
155
+ if self .channels is not None :
156
+ filters .append ({"channel" : {"$in" : self .channels }})
157
+ if self .hashes is not None :
158
+ filters .append (
159
+ {
160
+ "$or" : [
161
+ {"item_hash" : {"$in" : self .hashes }},
162
+ {"tx_hash" : {"$in" : self .hashes }},
163
+ ]
164
+ }
165
+ )
166
+
167
+ date_filters = make_date_filters (
168
+ start = self .start_date , end = self .end_date , filter_key = "time"
169
+ )
170
+ if date_filters :
171
+ filters .append (date_filters )
103
172
104
- if len (filters ) > 0 :
105
- find_filters = {"$and" : filters } if len (filters ) > 1 else filters [0 ]
173
+ and_filter = {}
174
+ if filters :
175
+ and_filter = {"$and" : filters } if len (filters ) > 1 else filters [0 ]
106
176
107
- return find_filters
177
+ return and_filter
108
178
109
179
110
180
async def view_messages_list (request ):
111
181
"""Messages list view with filters"""
112
182
113
183
try :
114
- find_filters = await get_filters (request )
115
- except ValueError as error :
116
- raise HTTPBadRequest (body = error .args [0 ])
117
-
118
- (
119
- pagination_page ,
120
- pagination_per_page ,
121
- pagination_skip ,
122
- ) = Pagination .get_pagination_params (request )
123
- if pagination_per_page is None :
124
- pagination_per_page = 0
125
- if pagination_skip is None :
126
- pagination_skip = 0
184
+ query_params = MessageQueryParams .parse_obj (request .query )
185
+ except ValidationError as e :
186
+ raise web .HTTPUnprocessableEntity (body = e .json (indent = 4 ))
187
+
188
+ # If called from the messages/page/{page}.json endpoint, override the page
189
+ # parameters with the URL one
190
+ if url_page_param := request .match_info .get ("page" ):
191
+ query_params .page = int (url_page_param )
192
+
193
+ find_filters = query_params .to_mongodb_filters ()
194
+
195
+ pagination_page = query_params .page
196
+ pagination_per_page = query_params .pagination
197
+ pagination_skip = (query_params .page - 1 ) * query_params .pagination
127
198
128
199
messages = [
129
200
msg
@@ -132,14 +203,14 @@ async def view_messages_list(request):
132
203
projection = {"_id" : 0 },
133
204
limit = pagination_per_page ,
134
205
skip = pagination_skip ,
135
- sort = [("time" , int ( request . query . get ( " sort_order" , "-1" )) )],
206
+ sort = [("time" , query_params . sort_order . value )],
136
207
)
137
208
]
138
209
139
210
context = {"messages" : messages }
140
211
141
212
if pagination_per_page is not None :
142
- if len ( find_filters . keys ()) :
213
+ if find_filters :
143
214
total_msgs = await Message .collection .count_documents (find_filters )
144
215
else :
145
216
total_msgs = await Message .collection .estimated_document_count ()
@@ -173,11 +244,10 @@ async def messages_ws(request: web.Request):
173
244
collection = CappedMessage .collection
174
245
last_id = None
175
246
176
- find_filters = await get_filters (request )
177
- initial_count = int (request .query .get ("history" , 10 ))
178
- initial_count = max (initial_count , 10 )
179
- # let's cap this to 200 historic messages max.
180
- initial_count = min (initial_count , 200 )
247
+ query_params = MessageQueryParams .parse_obj (request .query )
248
+ find_filters = query_params .to_mongodb_filters ()
249
+
250
+ initial_count = query_params .history
181
251
182
252
items = [
183
253
item
0 commit comments