Skip to content

#325: add disable_discovery option to DriverConfig #666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 313 additions & 0 deletions tests/test_disable_discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
import pytest
import unittest.mock
import ydb
import asyncio
from ydb import _apis


TEST_ERROR = "Test error"
TEST_QUERY = "SELECT 1 + 2 AS sum"


@pytest.fixture
def mock_connection():
"""Mock a YDB connection to avoid actual connections."""
with unittest.mock.patch("ydb.connection.Connection.ready_factory") as mock_factory:
# Setup the mock to return a connection-like object
mock_connection = unittest.mock.MagicMock()
# Use the endpoint fixture value via the function parameter
mock_connection.endpoint = "localhost:2136" # Will be overridden in tests
mock_connection.node_id = "mock_node_id"
mock_factory.return_value = mock_connection
yield mock_factory


@pytest.fixture
def mock_aio_connection():
"""Mock a YDB async connection to avoid actual connections."""
with unittest.mock.patch("ydb.aio.connection.Connection.__init__") as mock_init:
# Setup the mock to return None (as __init__ does)
mock_init.return_value = None

# Mock connection_ready method
with unittest.mock.patch("ydb.aio.connection.Connection.connection_ready") as mock_ready:
# Create event loop if there isn't one currently
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

future = asyncio.Future()
future.set_result(None)
mock_ready.return_value = future
yield mock_init


def create_mock_discovery_resolver(path):
"""Create a mock discovery resolver that raises exception if called."""

def _mock_fixture():
with unittest.mock.patch(path) as mock_resolve:
# Configure mock to throw an exception if called
mock_resolve.side_effect = Exception("Discovery should not be executed when discovery is disabled")
yield mock_resolve

return _mock_fixture


# Mock discovery resolvers to verify no discovery requests are made
mock_discovery_resolver = pytest.fixture(
create_mock_discovery_resolver("ydb.resolver.DiscoveryEndpointsResolver.context_resolve")
)
mock_aio_discovery_resolver = pytest.fixture(
create_mock_discovery_resolver("ydb.aio.resolver.DiscoveryEndpointsResolver.resolve")
)


# Basic unit tests for DriverConfig
def test_driver_config_has_disable_discovery_option(endpoint, database):
"""Test that DriverConfig has the disable_discovery option."""
config = ydb.DriverConfig(endpoint=endpoint, database=database, disable_discovery=True)
assert hasattr(config, "disable_discovery")
assert config.disable_discovery is True


# Driver config fixtures
def create_driver_config(endpoint, database, disable_discovery):
"""Create a driver config with the given discovery setting."""
return ydb.DriverConfig(
endpoint=endpoint,
database=database,
disable_discovery=disable_discovery,
)


@pytest.fixture
def driver_config_disabled_discovery(endpoint, database):
"""A driver config with discovery disabled"""
return create_driver_config(endpoint, database, True)


@pytest.fixture
def driver_config_enabled_discovery(endpoint, database):
"""A driver config with discovery enabled (default)"""
return create_driver_config(endpoint, database, False)


# Common test assertions
def assert_discovery_disabled(driver):
"""Assert that discovery is disabled in the driver."""
assert "Discovery is disabled" in driver.discovery_debug_details()


def create_future_with_error():
"""Create a future with a test error."""
future = asyncio.Future()
future.set_exception(ydb.issues.Error(TEST_ERROR))
return future


def create_completed_future():
"""Create a completed future."""
future = asyncio.Future()
future.set_result(None)
return future


# Mock tests for synchronous driver
def test_sync_driver_discovery_disabled_mock(
driver_config_disabled_discovery, mock_connection, mock_discovery_resolver
):
"""Test that when disable_discovery=True, the discovery thread is not started and resolver is not called (mock)."""
with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)

try:
# Check that the discovery thread was not created
mock_discovery_class.assert_not_called()

# Check that discovery is disabled in debug details
assert_discovery_disabled(driver)

# Execute a dummy call to verify no discovery requests are made
try:
driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation")
except ydb.issues.Error:
pass # Expected exception, we just want to ensure no discovery occurs

# Verify the mock wasn't called
assert (
not mock_discovery_resolver.called
), "Discovery resolver should not be called when discovery is disabled"
finally:
# Clean up
driver.stop()


def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection):
"""Test that when disable_discovery=False, the discovery thread is started (mock)."""
with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
mock_discovery_instance = unittest.mock.MagicMock()
mock_discovery_class.return_value = mock_discovery_instance

driver = ydb.Driver(driver_config=driver_config_enabled_discovery)

try:
# Check that the discovery thread was created and started
mock_discovery_class.assert_called_once()
assert mock_discovery_instance.start.called
finally:
# Clean up
driver.stop()


# Helper for setting up async driver test mocks
def setup_async_driver_mocks():
"""Set up common mocks for async driver tests."""
mocks = {}

# Create mock for Discovery class
discovery_patcher = unittest.mock.patch("ydb.aio.pool.Discovery")
mocks["mock_discovery_class"] = discovery_patcher.start()

# Mock the event loop
loop_patcher = unittest.mock.patch("asyncio.get_event_loop")
mock_loop = loop_patcher.start()
mock_loop_instance = unittest.mock.MagicMock()
mock_loop.return_value = mock_loop_instance
mock_loop_instance.create_task.return_value = unittest.mock.MagicMock()
mocks["mock_loop"] = mock_loop

# Mock the connection pool stop method
stop_patcher = unittest.mock.patch("ydb.aio.pool.ConnectionPool.stop")
mock_stop = stop_patcher.start()
mock_stop.return_value = create_completed_future()
mocks["mock_stop"] = mock_stop

# Add cleanup for all patchers
mocks["patchers"] = [discovery_patcher, loop_patcher, stop_patcher]

return mocks


def teardown_async_mocks(mocks):
"""Clean up all mock patchers."""
for patcher in mocks["patchers"]:
patcher.stop()


# Mock tests for asynchronous driver
@pytest.mark.asyncio
async def test_aio_driver_discovery_disabled_mock(
driver_config_disabled_discovery, mock_aio_connection, mock_aio_discovery_resolver
):
"""Test that when disable_discovery=True, the discovery is not created and resolver is not called (mock)."""
mocks = setup_async_driver_mocks()

try:
# Mock the pool's call method to prevent unhandled exceptions
with unittest.mock.patch("ydb.aio.pool.ConnectionPool.__call__") as mock_call:
mock_call.return_value = create_future_with_error()

driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)

try:
# Check that the discovery class was not instantiated
mocks["mock_discovery_class"].assert_not_called()

# Check that discovery is disabled in debug details
assert_discovery_disabled(driver)

# Execute a dummy call to verify no discovery requests are made
try:
try:
await driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation")
except ydb.issues.Error:
pass # Expected exception, we just want to ensure no discovery occurs
except Exception as e:
if "discovery is disabled" in str(e).lower():
raise # If the error is related to discovery being disabled, re-raise it
pass # Other exceptions are expected as we're using mocks

# Verify the mock wasn't called
assert (
not mock_aio_discovery_resolver.called
), "Discovery resolver should not be called when discovery is disabled"
finally:
# The stop method is already mocked, so we don't need to await it
pass
finally:
teardown_async_mocks(mocks)


@pytest.mark.asyncio
async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection):
"""Test that when disable_discovery=False, the discovery is created (mock)."""
mocks = setup_async_driver_mocks()

try:
mock_discovery_instance = unittest.mock.MagicMock()
mocks["mock_discovery_class"].return_value = mock_discovery_instance

driver = ydb.aio.Driver(driver_config=driver_config_enabled_discovery)

try:
# Check that the discovery class was instantiated
mocks["mock_discovery_class"].assert_called_once()
assert driver is not None # Use the driver variable to avoid F841
finally:
# The stop method is already mocked, so we don't need to await it
pass
finally:
teardown_async_mocks(mocks)


# Common integration test logic
def perform_integration_test_checks(driver, is_async=False):
"""Common assertions for integration tests."""
assert_discovery_disabled(driver)


# Integration tests with real YDB
def test_integration_disable_discovery(driver_config_disabled_discovery):
"""Integration test that tests the disable_discovery feature with a real YDB container."""
# Create driver with discovery disabled
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
try:
driver.wait(timeout=15)
perform_integration_test_checks(driver)

# Try to execute a simple query to ensure it works with discovery disabled
with ydb.SessionPool(driver) as pool:

def query_callback(session):
result_sets = session.transaction().execute(TEST_QUERY, commit_tx=True)
assert len(result_sets) == 1
assert result_sets[0].rows[0].sum == 3

pool.retry_operation_sync(query_callback)
finally:
driver.stop(timeout=10)


@pytest.mark.asyncio
async def test_integration_aio_disable_discovery(driver_config_disabled_discovery):
"""Integration test that tests the disable_discovery feature with a real YDB container (async)."""
# Create driver with discovery disabled
driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
try:
await driver.wait(timeout=15)
perform_integration_test_checks(driver, is_async=True)

# Try to execute a simple query to ensure it works with discovery disabled
session_pool = ydb.aio.SessionPool(driver, size=10)

async def query_callback(session):
result_sets = await session.transaction().execute(TEST_QUERY, commit_tx=True)
assert len(result_sets) == 1
assert result_sets[0].rows[0].sum == 3

try:
await session_pool.retry_operation(query_callback)
finally:
await session_pool.stop()
finally:
await driver.stop(timeout=10)
31 changes: 25 additions & 6 deletions ydb/aio/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,27 @@ def __init__(self, driver_config):
self._store = ConnectionsCache(driver_config.use_all_nodes)
self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config)
self._stopped = False
self._discovery = Discovery(self._store, self._driver_config)

self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())
if driver_config.disable_discovery:
# If discovery is disabled, just add the initial endpoint to the store
async def init_connection():
ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
await ready_connection.connection_ready(
ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10)
)
self._store.add(ready_connection)

# Create and schedule the task to initialize the connection
self._discovery = None
self._discovery_task = asyncio.get_event_loop().create_task(init_connection())
else:
# Start discovery as usual
self._discovery = Discovery(self._store, self._driver_config)
self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())

async def stop(self, timeout=10):
self._discovery.stop()
if self._discovery:
self._discovery.stop()
await self._grpc_init.close()
try:
await asyncio.wait_for(self._discovery_task, timeout=timeout)
Expand All @@ -215,15 +230,18 @@ async def stop(self, timeout=10):
def _on_disconnected(self, connection):
async def __wrapper__():
await connection.close()
self._discovery.notify_disconnected()
if self._discovery:
self._discovery.notify_disconnected()

return __wrapper__

async def wait(self, timeout=7, fail_fast=False):
await self._store.get(fast_fail=fail_fast, wait_timeout=timeout)

def discovery_debug_details(self):
return self._discovery.discovery_debug_details()
if self._discovery:
return self._discovery.discovery_debug_details()
return "Discovery is disabled, using only the initial endpoint"

async def __aenter__(self):
return self
Expand All @@ -248,7 +266,8 @@ async def __call__(
try:
connection = await self._store.get(preferred_endpoint, fast_fail=fast_fail, wait_timeout=wait_timeout)
except BaseException:
self._discovery.notify_disconnected()
if self._discovery:
self._discovery.notify_disconnected()
raise

return await connection(
Expand Down
Loading
Loading