1313
1414from newsroom import Service
1515from newsroom .auth .utils import user_has_section_allowed
16+ from newsroom .search import BoolQuery , BoolQueryParams , QueryStringQuery
1617from newsroom .search .config import (
1718 SearchGroupNestedConfig ,
1819 get_nested_config ,
@@ -48,7 +49,7 @@ def query_string(
4849 fields : List [str ] = ["*" ],
4950 multimatch_type : Literal ["cross_fields" , "best_fields" ] = "cross_fields" ,
5051 analyze_wildcard = False ,
51- ):
52+ ) -> QueryStringQuery :
5253 query_string_settings = app .config ["ELASTICSEARCH_SETTINGS" ]["settings" ]["query_string" ]
5354 return {
5455 "query_string" : {
@@ -106,6 +107,16 @@ class AdvancedSearchParams(TypedDict):
106107 fields : List [str ]
107108
108109
110+ class SearchArgs (TypedDict , total = False ):
111+ q : str
112+ id : str
113+ ids : List [str ]
114+ size : int
115+ bookmarks : str
116+ ignore_latest : bool
117+ filter : Union [Dict [str , str ], str ]
118+
119+
109120class SearchQuery (object ):
110121 """Class for storing the search parameters for validation and query generation"""
111122
@@ -120,14 +131,14 @@ def __init__(self):
120131 self .requested_products = []
121132 self .advanced : Optional [AdvancedSearchParams ] = None
122133
123- self .args = {}
134+ self .args : SearchArgs = {}
124135 self .lookup = {}
125136 self .projections = {}
126137 self .req = None
127138
128139 self .aggs = None
129140 self .source = {}
130- self .query = {"bool" : {"filter" : [], "must" : [], "must_not" : [], "should" : []}}
141+ self .query : BoolQuery = {"bool" : {"filter" : [], "must" : [], "must_not" : [], "should" : []}}
131142 self .highlight = None
132143 self .item_type = None
133144 self .planning_items_should = []
@@ -199,7 +210,7 @@ def _search_all_versions(self, search: SearchQuery, req, lookup):
199210
200211 # Now run a query only using the IDs from the above search
201212 # This final search makes sure pagination still works
202- search .query ["bool" ] = {"filter" : { "terms " : {"_id " : next_item_ids }}}
213+ search .query ["bool" ] = {"filter" : [{ "ids " : {"values " : next_item_ids }}] }
203214 self .gen_source_from_search (search )
204215 internal_req = self .get_internal_request (search )
205216 res = self .internal_get (internal_req , search .lookup )
@@ -330,10 +341,11 @@ def get_internal_request(self, search):
330341
331342 return internal_req
332343
333- def set_bool_query_from_filters (self , bool_query : Dict [ str , Any ], filters : Dict [str , Any ]):
344+ def set_bool_query_from_filters (self , bool_query : BoolQueryParams , filters : Dict [str , Any ]) -> None :
334345 for key , val in filters .items ():
335346 if not val :
336347 continue
348+ bool_query .setdefault ("must" , [])
337349 bool_query ["must" ].append (
338350 get_filter_query (key , val , self .get_aggregation_field (key ), get_nested_config ("items" , key ))
339351 )
@@ -576,7 +588,7 @@ def prefill_search_highlights(self, search, req):
576588 highlight_search .advanced = deepcopy (search .advanced )
577589
578590 # Set up the search query for filtering
579- self .apply_request_filter (highlight_search )
591+ self .apply_request_filter (highlight_search , highlights = True )
580592
581593 # Set up highlighting settings
582594 highlight_search .source .setdefault ("highlight" , {})
@@ -742,31 +754,34 @@ def get_product_filter(self, search, product):
742754 if product .get ("query" ):
743755 return self .query_string (product ["query" ])
744756
745- def apply_request_filter (self , search ):
757+ def parse_filters (self , search : SearchQuery ) -> Optional [Dict [str , Any ]]:
758+ if search .args .get ("filter" ):
759+ if isinstance (search .args ["filter" ], dict ):
760+ return search .args ["filter" ]
761+ else :
762+ try :
763+ return json .loads (search .args ["filter" ])
764+ except TypeError :
765+ raise BadParameterValueError ("Incorrect type supplied for filter parameter" )
766+ return None
767+
768+ def apply_request_filter (self , search : SearchQuery , highlights = False ) -> None :
746769 if search .args .get ("q" ):
747770 search .query ["bool" ].setdefault ("must" , []).append (
748771 self .query_string (search .args ["q" ], search .args .get ("default_operator" ) or "AND" )
749772 )
750773
751774 if search .args .get ("ids" ):
752- search .query ["bool" ]["must" ].append ({"terms " : {"_id " : search .args ["ids" ]}})
775+ search .query ["bool" ]["must" ].append ({"ids " : {"values " : search .args ["ids" ]}})
753776
754- filters = None
755- if search .args .get ("filter" ):
756- if isinstance (search .args ["filter" ], dict ):
757- filters = search .args ["filter" ]
758- else :
759- try :
760- filters = json .loads (search .args ["filter" ])
761- except TypeError :
762- raise BadParameterValueError ("Incorrect type supplied for filter parameter" )
777+ filters = self .parse_filters (search )
763778
764779 if not app .config .get ("FILTER_BY_POST_FILTER" , False ):
765780 if filters :
766781 if app .config .get ("FILTER_AGGREGATIONS" , True ):
767782 self .set_bool_query_from_filters (search .query ["bool" ], filters )
768- else :
769- search .query ["bool" ]["must" ].append (filters )
783+ elif isinstance ( filters , dict ) :
784+ search .query ["bool" ]["must" ].append (filters ) # type: ignore
770785
771786 if search .args .get ("created_from" ) or search .args .get ("created_to" ):
772787 search .query ["bool" ]["must" ].append (self .versioncreated_range (search .args ))
@@ -906,7 +921,7 @@ def get_matching_topics_for_item(self, topics, users, companies, query):
906921
907922 return topic_matches
908923
909- def apply_topic_args (self , topic , args = None ):
924+ def apply_topic_args (self , topic , args = None ) -> SearchArgs :
910925 if args is None :
911926 args = {}
912927
@@ -974,7 +989,7 @@ def get_items_by_query(self, search, size=10, aggs=None):
974989 internal_req = self .get_internal_request (search )
975990 return self .internal_get (internal_req , search .lookup )
976991
977- def query_string (self , query , default_operator = "AND" ):
992+ def query_string (self , query , default_operator = "AND" ) -> QueryStringQuery :
978993 fields_config_key = "WIRE_SEARCH_FIELDS" if self .section == "wire" else "AGENDA_SEARCH_FIELDS"
979994 fields = app .config .get (fields_config_key , ["*" ])
980995 return query_string (query , default_operator = default_operator , fields = fields )
0 commit comments