Skip to content

Commit 09897b6

Browse files
authored
PYTHON-5212 - Do not hold Topology lock while resetting pool (#2301)
1 parent e2e673e commit 09897b6

File tree

9 files changed

+230
-25
lines changed

9 files changed

+230
-25
lines changed

pymongo/asynchronous/pool.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
1920
import logging
@@ -860,8 +861,14 @@ async def _reset(
860861
# PoolClosedEvent but that reset() SHOULD close sockets *after*
861862
# publishing the PoolClearedEvent.
862863
if close:
863-
for conn in sockets:
864-
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
864+
if not _IS_SYNC:
865+
await asyncio.gather(
866+
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
867+
return_exceptions=True,
868+
)
869+
else:
870+
for conn in sockets:
871+
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
865872
if self.enabled_for_cmap:
866873
assert listeners is not None
867874
listeners.publish_pool_closed(self.address)
@@ -891,8 +898,14 @@ async def _reset(
891898
serverPort=self.address[1],
892899
serviceId=service_id,
893900
)
894-
for conn in sockets:
895-
await conn.close_conn(ConnectionClosedReason.STALE)
901+
if not _IS_SYNC:
902+
await asyncio.gather(
903+
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
904+
return_exceptions=True,
905+
)
906+
else:
907+
for conn in sockets:
908+
await conn.close_conn(ConnectionClosedReason.STALE)
896909

897910
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
898911
"""Updates the is_writable attribute on all sockets currently in the
@@ -938,8 +951,14 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
938951
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
939952
):
940953
close_conns.append(self.conns.pop())
941-
for conn in close_conns:
942-
await conn.close_conn(ConnectionClosedReason.IDLE)
954+
if not _IS_SYNC:
955+
await asyncio.gather(
956+
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
957+
return_exceptions=True,
958+
)
959+
else:
960+
for conn in close_conns:
961+
await conn.close_conn(ConnectionClosedReason.IDLE)
943962

944963
while True:
945964
async with self.size_cond:

pymongo/asynchronous/topology.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,6 @@ async def _process_change(
529529
if not _IS_SYNC:
530530
self._monitor_tasks.append(self._srv_monitor)
531531

532-
# Clear the pool from a failed heartbeat.
533-
if reset_pool:
534-
server = self._servers.get(server_description.address)
535-
if server:
536-
await server.pool.reset(interrupt_connections=interrupt_connections)
537-
538532
# Wake anything waiting in select_servers().
539533
self._condition.notify_all()
540534

@@ -557,6 +551,11 @@ async def on_change(
557551
# that didn't include this server.
558552
if self._opened and self._description.has_server(server_description.address):
559553
await self._process_change(server_description, reset_pool, interrupt_connections)
554+
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
555+
if reset_pool:
556+
server = self._servers.get(server_description.address)
557+
if server:
558+
await server.pool.reset(interrupt_connections=interrupt_connections)
560559

561560
async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
562561
"""Process a new seedlist on an opened topology.

pymongo/synchronous/pool.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
1920
import logging
@@ -858,8 +859,14 @@ def _reset(
858859
# PoolClosedEvent but that reset() SHOULD close sockets *after*
859860
# publishing the PoolClearedEvent.
860861
if close:
861-
for conn in sockets:
862-
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
862+
if not _IS_SYNC:
863+
asyncio.gather(
864+
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
865+
return_exceptions=True,
866+
)
867+
else:
868+
for conn in sockets:
869+
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
863870
if self.enabled_for_cmap:
864871
assert listeners is not None
865872
listeners.publish_pool_closed(self.address)
@@ -889,8 +896,14 @@ def _reset(
889896
serverPort=self.address[1],
890897
serviceId=service_id,
891898
)
892-
for conn in sockets:
893-
conn.close_conn(ConnectionClosedReason.STALE)
899+
if not _IS_SYNC:
900+
asyncio.gather(
901+
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
902+
return_exceptions=True,
903+
)
904+
else:
905+
for conn in sockets:
906+
conn.close_conn(ConnectionClosedReason.STALE)
894907

895908
def update_is_writable(self, is_writable: Optional[bool]) -> None:
896909
"""Updates the is_writable attribute on all sockets currently in the
@@ -934,8 +947,14 @@ def remove_stale_sockets(self, reference_generation: int) -> None:
934947
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
935948
):
936949
close_conns.append(self.conns.pop())
937-
for conn in close_conns:
938-
conn.close_conn(ConnectionClosedReason.IDLE)
950+
if not _IS_SYNC:
951+
asyncio.gather(
952+
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
953+
return_exceptions=True,
954+
)
955+
else:
956+
for conn in close_conns:
957+
conn.close_conn(ConnectionClosedReason.IDLE)
939958

940959
while True:
941960
with self.size_cond:

pymongo/synchronous/topology.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,6 @@ def _process_change(
529529
if not _IS_SYNC:
530530
self._monitor_tasks.append(self._srv_monitor)
531531

532-
# Clear the pool from a failed heartbeat.
533-
if reset_pool:
534-
server = self._servers.get(server_description.address)
535-
if server:
536-
server.pool.reset(interrupt_connections=interrupt_connections)
537-
538532
# Wake anything waiting in select_servers().
539533
self._condition.notify_all()
540534

@@ -557,6 +551,11 @@ def on_change(
557551
# that didn't include this server.
558552
if self._opened and self._description.has_server(server_description.address):
559553
self._process_change(server_description, reset_pool, interrupt_connections)
554+
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
555+
if reset_pool:
556+
server = self._servers.get(server_description.address)
557+
if server:
558+
server.pool.reset(interrupt_connections=interrupt_connections)
560559

561560
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
562561
"""Process a new seedlist on an opened topology.

test/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,14 @@ def require_sync(self, func):
826826
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
827827
)
828828

829+
def require_async(self, func):
830+
"""Run a test only if using the asynchronous API.""" # unasync: off
831+
return self._require(
832+
lambda: not _IS_SYNC,
833+
"This test only works with the asynchronous API", # unasync: off
834+
func=func,
835+
)
836+
829837
def mongos_seeds(self):
830838
return ",".join("{}:{}".format(*address) for address in self.mongoses)
831839

test/asynchronous/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,14 @@ def require_sync(self, func):
828828
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
829829
)
830830

831+
def require_async(self, func):
832+
"""Run a test only if using the asynchronous API.""" # unasync: off
833+
return self._require(
834+
lambda: not _IS_SYNC,
835+
"This test only works with the asynchronous API", # unasync: off
836+
func=func,
837+
)
838+
831839
def mongos_seeds(self):
832840
return ",".join("{}:{}".format(*address) for address in self.mongoses)
833841

test/asynchronous/test_discovery_and_monitoring.py

+73
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
import time
2324
from asyncio import StreamReader, StreamWriter
2425
from pathlib import Path
2526
from test.asynchronous.helpers import ConcurrentRunner
2627

28+
from pymongo.asynchronous.pool import AsyncConnection
29+
from pymongo.operations import _Op
30+
from pymongo.server_selectors import writable_server_selector
31+
2732
sys.path[0:0] = [""]
2833

2934
from test.asynchronous import (
@@ -370,6 +375,74 @@ async def test_pool_unpause(self):
370375
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
371376
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
372377

378+
@async_client_context.require_failCommand_appName
379+
@async_client_context.require_test_commands
380+
@async_client_context.require_async
381+
async def test_connection_close_does_not_block_other_operations(self):
382+
listener = CMAPHeartbeatListener()
383+
client = await self.async_single_client(
384+
appName="SDAMConnectionCloseTest",
385+
event_listeners=[listener],
386+
heartbeatFrequencyMS=500,
387+
minPoolSize=10,
388+
)
389+
server = await (await client._get_topology()).select_server(
390+
writable_server_selector, _Op.TEST
391+
)
392+
await async_wait_until(
393+
lambda: len(server._pool.conns) == 10,
394+
"pool initialized with 10 connections",
395+
)
396+
397+
await client.db.test.insert_one({"x": 1})
398+
close_delay = 0.1
399+
latencies = []
400+
should_exit = []
401+
402+
async def run_task():
403+
while True:
404+
start_time = time.monotonic()
405+
await client.db.test.find_one({})
406+
elapsed = time.monotonic() - start_time
407+
latencies.append(elapsed)
408+
if should_exit:
409+
break
410+
await asyncio.sleep(0.001)
411+
412+
task = ConcurrentRunner(target=run_task)
413+
await task.start()
414+
original_close = AsyncConnection.close_conn
415+
try:
416+
# Artificially delay the close operation to simulate a slow close
417+
async def mock_close(self, reason):
418+
await asyncio.sleep(close_delay)
419+
await original_close(self, reason)
420+
421+
AsyncConnection.close_conn = mock_close
422+
423+
fail_hello = {
424+
"mode": {"times": 4},
425+
"data": {
426+
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
427+
"errorCode": 91,
428+
"appName": "SDAMConnectionCloseTest",
429+
},
430+
}
431+
async with self.fail_point(fail_hello):
432+
# Wait for server heartbeat to fail
433+
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
434+
# Wait until all idle connections are closed to simulate real-world conditions
435+
await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10)
436+
# Wait for one more find to complete after the pool has been reset, then shutdown the task
437+
n = len(latencies)
438+
await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find")
439+
should_exit.append(True)
440+
await task.join()
441+
# No operation latency should not significantly exceed close_delay
442+
self.assertLessEqual(max(latencies), close_delay * 5.0)
443+
finally:
444+
AsyncConnection.close_conn = original_close
445+
373446

374447
class TestServerMonitoringMode(AsyncIntegrationTest):
375448
@async_client_context.require_no_serverless

test/test_discovery_and_monitoring.py

+71
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
import time
2324
from asyncio import StreamReader, StreamWriter
2425
from pathlib import Path
2526
from test.helpers import ConcurrentRunner
2627

28+
from pymongo.operations import _Op
29+
from pymongo.server_selectors import writable_server_selector
30+
from pymongo.synchronous.pool import Connection
31+
2732
sys.path[0:0] = [""]
2833

2934
from test import (
@@ -370,6 +375,72 @@ def test_pool_unpause(self):
370375
listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
371376
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
372377

378+
@client_context.require_failCommand_appName
379+
@client_context.require_test_commands
380+
@client_context.require_async
381+
def test_connection_close_does_not_block_other_operations(self):
382+
listener = CMAPHeartbeatListener()
383+
client = self.single_client(
384+
appName="SDAMConnectionCloseTest",
385+
event_listeners=[listener],
386+
heartbeatFrequencyMS=500,
387+
minPoolSize=10,
388+
)
389+
server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST)
390+
wait_until(
391+
lambda: len(server._pool.conns) == 10,
392+
"pool initialized with 10 connections",
393+
)
394+
395+
client.db.test.insert_one({"x": 1})
396+
close_delay = 0.1
397+
latencies = []
398+
should_exit = []
399+
400+
def run_task():
401+
while True:
402+
start_time = time.monotonic()
403+
client.db.test.find_one({})
404+
elapsed = time.monotonic() - start_time
405+
latencies.append(elapsed)
406+
if should_exit:
407+
break
408+
time.sleep(0.001)
409+
410+
task = ConcurrentRunner(target=run_task)
411+
task.start()
412+
original_close = Connection.close_conn
413+
try:
414+
# Artificially delay the close operation to simulate a slow close
415+
def mock_close(self, reason):
416+
time.sleep(close_delay)
417+
original_close(self, reason)
418+
419+
Connection.close_conn = mock_close
420+
421+
fail_hello = {
422+
"mode": {"times": 4},
423+
"data": {
424+
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
425+
"errorCode": 91,
426+
"appName": "SDAMConnectionCloseTest",
427+
},
428+
}
429+
with self.fail_point(fail_hello):
430+
# Wait for server heartbeat to fail
431+
listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
432+
# Wait until all idle connections are closed to simulate real-world conditions
433+
listener.wait_for_event(monitoring.ConnectionClosedEvent, 10)
434+
# Wait for one more find to complete after the pool has been reset, then shutdown the task
435+
n = len(latencies)
436+
wait_until(lambda: len(latencies) >= n + 1, "run one more find")
437+
should_exit.append(True)
438+
task.join()
439+
# No operation latency should not significantly exceed close_delay
440+
self.assertLessEqual(max(latencies), close_delay * 5.0)
441+
finally:
442+
Connection.close_conn = original_close
443+
373444

374445
class TestServerMonitoringMode(IntegrationTest):
375446
@client_context.require_no_serverless

0 commit comments

Comments
 (0)