@@ -56,8 +56,6 @@ class TestHealthCheckAsync:
56
56
connectionPolicy = test_config .TestConfig .connectionPolicy
57
57
TEST_DATABASE_ID = test_config .TestConfig .TEST_DATABASE_ID
58
58
TEST_CONTAINER_SINGLE_PARTITION_ID = test_config .TestConfig .TEST_SINGLE_PARTITION_CONTAINER_ID
59
- # health check in all these tests should check the endpoints for the first two write regions and the first two read regions
60
- # without checking the same endpoint twice
61
59
62
60
@pytest .mark .parametrize ("preferred_location, use_write_global_endpoint, use_read_global_endpoint" , health_check ())
63
61
async def test_health_check_success_startup_async (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
@@ -79,10 +77,7 @@ async def test_health_check_success_startup_async(self, setup, preferred_locatio
79
77
expected_regional_routing_context = []
80
78
81
79
locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , REGION_1 )
82
- if use_read_global_endpoint :
83
- assert mock_get_database_account_check .counter == 1
84
- else :
85
- assert mock_get_database_account_check .counter == 2
80
+ assert mock_get_database_account_check .counter == 2
86
81
endpoint = self .host if use_read_global_endpoint else locational_endpoint
87
82
expected_regional_routing_context .append (RegionalRoutingContext (endpoint , endpoint ))
88
83
locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , REGION_2 )
@@ -107,12 +102,8 @@ async def test_health_check_failure_startup_async(self, setup, preferred_locatio
107
102
_global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub = self .original_getDatabaseAccountStub
108
103
expected_endpoints = []
109
104
110
- if not use_read_global_endpoint :
111
- for region in REGIONS :
112
- locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , region )
113
- expected_endpoints .append (locational_endpoint )
114
- else :
115
- locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , REGION_2 )
105
+ for region in REGIONS :
106
+ locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , region )
116
107
expected_endpoints .append (locational_endpoint )
117
108
118
109
unavailable_endpoint_info = client .client_connection ._global_endpoint_manager .location_cache .location_unavailability_info_by_endpoint
@@ -149,7 +140,7 @@ async def test_health_check_background_fail(self, setup):
149
140
_global_endpoint_manager_async ._GlobalEndpointManager ._endpoints_health_check = self .original_health_check
150
141
151
142
@pytest .mark .parametrize ("preferred_location, use_write_global_endpoint, use_read_global_endpoint" , health_check ())
152
- async def test_health_check_success (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
143
+ async def test_health_check_success_async (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
153
144
# checks the background health check works as expected when all endpoints healthy
154
145
self .original_getDatabaseAccountStub = _global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub
155
146
self .original_getDatabaseAccountCheck = _cosmos_client_connection_async .CosmosClientConnection ._GetDatabaseAccountCheck
@@ -183,7 +174,7 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g
183
174
184
175
185
176
@pytest .mark .parametrize ("preferred_location, use_write_global_endpoint, use_read_global_endpoint" , health_check ())
186
- async def test_health_check_failure (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
177
+ async def test_health_check_failure_async (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
187
178
# checks the background health check works as expected when all endpoints unhealthy - it should mark the endpoints unavailable
188
179
setup [COLLECTION ].client_connection ._global_endpoint_manager .location_cache .location_unavailability_info_by_endpoint .clear ()
189
180
self .original_getDatabaseAccountStub = _global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub
@@ -198,16 +189,12 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g
198
189
for i in range (2 ):
199
190
await setup [COLLECTION ].create_item (body = {'id' : 'item' + str (uuid .uuid4 ()), 'pk' : 'pk' })
200
191
# wait for background task to finish
201
- await asyncio .sleep (1 )
192
+ await asyncio .sleep (2 )
202
193
finally :
203
194
_global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub = self .original_getDatabaseAccountStub
204
195
setup [COLLECTION ].client_connection .connection_policy .PreferredLocations = self .original_preferred_locations
205
196
206
- if not use_write_global_endpoint :
207
- num_unavailable_endpoints = len (REGIONS )
208
- else :
209
- num_unavailable_endpoints = 1
210
-
197
+ num_unavailable_endpoints = len (REGIONS )
211
198
unavailable_endpoint_info = setup [COLLECTION ].client_connection ._global_endpoint_manager .location_cache .location_unavailability_info_by_endpoint
212
199
assert len (unavailable_endpoint_info ) == num_unavailable_endpoints
213
200
0 commit comments