Skip to content

Commit c0f09ca

Browse files
authored
[Cosmos] fix response_hook not being passed through for complex queries (#40696)
* pass response_hook through aggregate query pipelines * Update CHANGELOG.md * add tests, remove comments, pylint * leftover print statements * add response hook to global statistics queries * fix tests
1 parent 610a33b commit c0f09ca

23 files changed

+265
-45
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
* Fixed bug where change feed requests would not respect the partition key filter. See [PR 40677](https://github.com/Azure/azure-sdk-for-python/pull/40677).
1313
* Fixed how the environment variables in the sdk are parsed. See [PR 40303](https://github.com/Azure/azure-sdk-for-python/pull/40303).
1414
* Fixed health check to check the first write region when it is not specified in the preferred regions. See [PR 40588](https://github.com/Azure/azure-sdk-for-python/pull/40588).
15+
* Fixed `response_hook` not getting called for aggregate queries. See [PR 40696](https://github.com/Azure/azure-sdk-for-python/pull/40696).
1516
* Fixed bug where writes were being retried for 5xx status codes for patch and replace. See [PR 40672](https://github.com/Azure/azure-sdk-for-python/pull/40672).
1617

1718
#### Other Changes

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,8 @@ def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInse
11121112
options,
11131113
fetch_function=fetch_fn,
11141114
collection_link=database_or_container_link,
1115-
page_iterator_class=query_iterable.QueryIterable
1115+
page_iterator_class=query_iterable.QueryIterable,
1116+
response_hook=response_hook
11161117
)
11171118

11181119
def QueryItemsChangeFeed(

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class _DocumentProducer(object):
4040
result of each.
4141
"""
4242

43-
def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options):
43+
def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options,
44+
response_hook):
4445
"""
4546
Constructor
4647
"""
@@ -60,7 +61,8 @@ def __init__(self, partition_key_target_range, client, collection_link, query, d
6061
collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link)
6162

6263
async def fetch_fn(options):
63-
return await self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"])
64+
return await self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"],
65+
response_hook=response_hook)
6466

6567
self._ex_context = _DefaultQueryExecutionContext(client, self._options, fetch_fn)
6668

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl
4747
to _MultiExecutionContextAggregator
4848
"""
4949

50-
def __init__(self, client, resource_link, query, options, fetch_function):
50+
def __init__(self, client, resource_link, query, options, fetch_function, response_hook):
5151
"""
5252
Constructor
5353
"""
@@ -57,6 +57,7 @@ def __init__(self, client, resource_link, query, options, fetch_function):
5757
self._resource_link = resource_link
5858
self._query = query
5959
self._fetch_function = fetch_function
60+
self._response_hook = response_hook
6061

6162
async def __anext__(self):
6263
"""Returns the next query result.
@@ -127,7 +128,8 @@ async def _create_pipelined_execution_context(self, query_execution_info):
127128
self._resource_link,
128129
self._query,
129130
self._options,
130-
query_execution_info)
131+
query_execution_info,
132+
self._response_hook)
131133
await execution_context_aggregator._configure_partition_ranges()
132134
elif query_execution_info.has_hybrid_search_query_info():
133135
hybrid_search_query_info = query_execution_info._query_execution_info['hybridSearchQueryInfo']
@@ -137,14 +139,16 @@ async def _create_pipelined_execution_context(self, query_execution_info):
137139
self._resource_link,
138140
self._options,
139141
query_execution_info,
140-
hybrid_search_query_info)
142+
hybrid_search_query_info,
143+
self._response_hook)
141144
await execution_context_aggregator._run_hybrid_search()
142145
else:
143146
execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator(self._client,
144147
self._resource_link,
145148
self._query,
146149
self._options,
147-
query_execution_info)
150+
query_execution_info,
151+
self._response_hook)
148152
await execution_context_aggregator._configure_partition_ranges()
149153
return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator,
150154
query_execution_info)

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class _HybridSearchContextAggregator(_QueryExecutionContextBase):
4545
"""
4646

4747
def __init__(self, client, resource_link, options, partitioned_query_execution_info,
48-
hybrid_search_query_info):
48+
hybrid_search_query_info, response_hook):
4949
super(_HybridSearchContextAggregator, self).__init__(client, options)
5050

5151
# use the routing provider in the client
@@ -57,6 +57,7 @@ def __init__(self, client, resource_link, options, partitioned_query_execution_i
5757
self._final_results = []
5858
self._aggregated_global_statistics = None
5959
self._document_producer_comparator = None
60+
self._response_hook = response_hook
6061

6162
async def _run_hybrid_search(self):
6263
# Check if we need to run global statistics queries, and if so do for every partition in the container
@@ -75,6 +76,7 @@ async def _run_hybrid_search(self):
7576
global_statistics_query,
7677
self._document_producer_comparator,
7778
self._options,
79+
self._response_hook
7880
)
7981
)
8082

@@ -117,6 +119,7 @@ async def _run_hybrid_search(self):
117119
rewritten_query['rewrittenQuery'],
118120
self._document_producer_comparator,
119121
self._options,
122+
self._response_hook
120123
)
121124
)
122125
# verify all document producers have items/ no splits
@@ -222,6 +225,7 @@ async def _repair_document_producer(self, query, target_all_ranges=False):
222225
query,
223226
self._document_producer_comparator,
224227
self._options,
228+
self._response_hook
225229
)
226230
)
227231

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def peek(self):
6262
def size(self):
6363
return len(self._heap)
6464

65-
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info):
65+
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook):
6666
super(_MultiExecutionContextAggregator, self).__init__(client, options)
6767

6868
# use the routing provider in the client
@@ -72,6 +72,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i
7272
self._query = query
7373
self._partitioned_query_ex_info = partitioned_query_ex_info
7474
self._sort_orders = partitioned_query_ex_info.get_order_by()
75+
self._response_hook = response_hook
7576

7677
if self._sort_orders:
7778
self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders)
@@ -154,6 +155,7 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range
154155
query,
155156
self._document_producer_comparator,
156157
self._options,
158+
self._response_hook
157159
)
158160

159161
async def _get_target_partition_key_range(self):

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class _NonStreamingOrderByContextAggregator(_QueryExecutionContextBase):
2222
by the user.
2323
"""
2424

25-
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info):
25+
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook):
2626
super(_NonStreamingOrderByContextAggregator, self).__init__(client, options)
2727

2828
# use the routing provider in the client
@@ -35,6 +35,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i
3535
self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue()
3636
self._doc_producers = []
3737
self._document_producer_comparator = document_producer._NonStreamingOrderByComparator(self._sort_orders)
38+
self._response_hook = response_hook
3839

3940

4041
async def __anext__(self):
@@ -99,6 +100,7 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range
99100
query,
100101
self._document_producer_comparator,
101102
self._options,
103+
self._response_hook
102104
)
103105

104106
async def _get_target_partition_key_range(self):

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class _DocumentProducer(object):
3939
result of each.
4040
"""
4141

42-
def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options):
42+
def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options,
43+
response_hook):
4344
"""
4445
Constructor
4546
"""
@@ -59,7 +60,8 @@ def __init__(self, partition_key_target_range, client, collection_link, query, d
5960
collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link)
6061

6162
def fetch_fn(options):
62-
return self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"])
63+
return self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"],
64+
response_hook=response_hook)
6365

6466
self._ex_context = _DefaultQueryExecutionContext(client, self._options, fetch_fn)
6567

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl
7777
to _MultiExecutionContextAggregator
7878
"""
7979

80-
def __init__(self, client, resource_link, query, options, fetch_function):
80+
def __init__(self, client, resource_link, query, options, fetch_function, response_hook):
8181
"""
8282
Constructor
8383
"""
@@ -87,6 +87,7 @@ def __init__(self, client, resource_link, query, options, fetch_function):
8787
self._resource_link = resource_link
8888
self._query = query
8989
self._fetch_function = fetch_function
90+
self._response_hook = response_hook
9091

9192
def __next__(self):
9293
"""Returns the next query result.
@@ -160,7 +161,8 @@ def _create_pipelined_execution_context(self, query_execution_info):
160161
self._resource_link,
161162
self._query,
162163
self._options,
163-
query_execution_info)
164+
query_execution_info,
165+
self._response_hook)
164166
elif query_execution_info.has_hybrid_search_query_info():
165167
hybrid_search_query_info = query_execution_info._query_execution_info['hybridSearchQueryInfo']
166168
_verify_valid_hybrid_search_query(hybrid_search_query_info)
@@ -169,15 +171,17 @@ def _create_pipelined_execution_context(self, query_execution_info):
169171
self._resource_link,
170172
self._options,
171173
query_execution_info,
172-
hybrid_search_query_info)
174+
hybrid_search_query_info,
175+
self._response_hook)
173176
execution_context_aggregator._run_hybrid_search()
174177
else:
175178
execution_context_aggregator = \
176179
multi_execution_aggregator._MultiExecutionContextAggregator(self._client,
177180
self._resource_link,
178181
self._query,
179182
self._options,
180-
query_execution_info)
183+
query_execution_info,
184+
self._response_hook)
181185
return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator,
182186
query_execution_info)
183187

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class _HybridSearchContextAggregator(_QueryExecutionContextBase):
150150
"""
151151

152152
def __init__(self, client, resource_link, options,
153-
partitioned_query_execution_info, hybrid_search_query_info):
153+
partitioned_query_execution_info, hybrid_search_query_info, response_hook):
154154
super(_HybridSearchContextAggregator, self).__init__(client, options)
155155

156156
# use the routing provider in the client
@@ -162,6 +162,7 @@ def __init__(self, client, resource_link, options,
162162
self._final_results = []
163163
self._aggregated_global_statistics = None
164164
self._document_producer_comparator = None
165+
self._response_hook = response_hook
165166

166167
def _run_hybrid_search(self):
167168
# Check if we need to run global statistics queries, and if so do for every partition in the container
@@ -180,6 +181,7 @@ def _run_hybrid_search(self):
180181
global_statistics_query,
181182
self._document_producer_comparator,
182183
self._options,
184+
self._response_hook
183185
)
184186
)
185187

@@ -221,6 +223,7 @@ def _run_hybrid_search(self):
221223
rewritten_query['rewrittenQuery'],
222224
self._document_producer_comparator,
223225
self._options,
226+
self._response_hook
224227
)
225228
)
226229
# verify all document producers have items/ no splits
@@ -347,6 +350,7 @@ def _repair_document_producer(self, query, target_all_ranges=False):
347350
query,
348351
self._document_producer_comparator,
349352
self._options,
353+
self._response_hook
350354
)
351355
)
352356

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def peek(self):
6363
def size(self):
6464
return len(self._heap)
6565

66-
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info):
66+
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook):
6767
super(_MultiExecutionContextAggregator, self).__init__(client, options)
6868

6969
# use the routing provider in the client
@@ -73,6 +73,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i
7373
self._query = query
7474
self._partitioned_query_ex_info = partitioned_query_ex_info
7575
self._sort_orders = partitioned_query_ex_info.get_order_by()
76+
self._response_hook = response_hook
7677

7778
if self._sort_orders:
7879
self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders)
@@ -186,6 +187,7 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range
186187
query,
187188
self._document_producer_comparator,
188189
self._options,
190+
self._response_hook
189191
)
190192

191193
def _get_target_partition_key_range(self):

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class _NonStreamingOrderByContextAggregator(_QueryExecutionContextBase):
2222
by the user.
2323
"""
2424

25-
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info):
25+
def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook):
2626
super(_NonStreamingOrderByContextAggregator, self).__init__(client, options)
2727

2828
# use the routing provider in the client
@@ -33,6 +33,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i
3333
self._partitioned_query_ex_info = partitioned_query_ex_info
3434
self._sort_orders = partitioned_query_ex_info.get_order_by()
3535
self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue()
36+
self._response_hook = response_hook
3637

3738
# will be a list of (partition_min, partition_max) tuples
3839
targetPartitionRanges = self._get_target_partition_key_range()
@@ -142,6 +143,7 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range
142143
query,
143144
self._document_producer_comparator,
144145
self._options,
146+
self._response_hook
145147
)
146148

147149
def _get_target_partition_key_range(self):

sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
database_link=None,
4444
partition_key=None,
4545
continuation_token=None,
46+
response_hook=None,
4647
):
4748
"""Instantiates a QueryIterable for non-client side partitioning queries.
4849
@@ -73,8 +74,7 @@ def __init__(
7374
self._database_link = database_link
7475
self._partition_key = partition_key
7576
self._ex_context = execution_dispatcher._ProxyQueryExecutionContext(
76-
self._client, self._collection_link, self._query, self._options, self._fetch_function
77-
)
77+
self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook)
7878
super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token)
7979

8080
def _unpack(self, block):

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2278,7 +2278,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca
22782278
options,
22792279
fetch_function=fetch_fn,
22802280
collection_link=database_or_container_link,
2281-
page_iterator_class=query_iterable.QueryIterable
2281+
page_iterator_class=query_iterable.QueryIterable,
2282+
response_hook=response_hook
22822283
)
22832284

22842285
def QueryItemsChangeFeed(

0 commit comments

Comments
 (0)