Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions tests/aio/test_discovery_detect_local_dc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from ydb import driver, connection
from ydb.aio import pool, nearest_dc


class MockEndpointInfo:
def __init__(self, address, port, location):
self.address = address
self.port = port
self.endpoint = f"{address}:{port}"
self.location = location
self.ssl = False
self.node_id = 1

def endpoints_with_options(self):
yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id))


class MockDiscoveryResult:
def __init__(self, self_location, endpoints):
self.self_location = self_location
self.endpoints = endpoints


@pytest.mark.asyncio
async def test_detect_local_dc_overrides_server_location():
"""Test that detected location overrides server's self_location for preferred endpoints."""
# Server reports dc1, but we detect dc2 as nearest
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.resolve = AsyncMock(return_value=mock_result)

preferred = []

def mock_init(self, endpoint, driver_config, endpoint_options=None):
self.endpoint = endpoint
self.node_id = 1

with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")):
with patch("ydb.aio.connection.Connection.__init__", mock_init):
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
)
discovery = pool.Discovery(
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
)
discovery._resolver = mock_resolver

original_add = discovery._cache.add
discovery._cache.add = lambda conn, pref=False: (
preferred.append(conn.endpoint) if pref else None,
original_add(conn, pref),
)[1]

await discovery.execute_discovery()

assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)"
assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred"


@pytest.mark.asyncio
async def test_detect_local_dc_failure_fallback():
"""Test that detection failure falls back to server's self_location."""
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.resolve = AsyncMock(return_value=mock_result)

preferred = []

def mock_init(self, endpoint, driver_config, endpoint_options=None):
self.endpoint = endpoint
self.node_id = 1

with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value=None)):
with patch("ydb.aio.connection.Connection.__init__", mock_init):
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
)
discovery = pool.Discovery(
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
)
discovery._resolver = mock_resolver

original_add = discovery._cache.add
discovery._cache.add = lambda conn, pref=False: (
preferred.append(conn.endpoint) if pref else None,
original_add(conn, pref),
)[1]

await discovery.execute_discovery()

assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)"
159 changes: 159 additions & 0 deletions tests/aio/test_nearest_dc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import asyncio
import pytest
from ydb.aio import nearest_dc


class MockEndpoint:
def __init__(self, address, port, location):
self.address = address
self.port = port
self.endpoint = f"{address}:{port}"
self.location = location


class MockWriter:
def __init__(self):
self.closed = False

def close(self):
self.closed = True

async def wait_closed(self):
await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_check_fastest_endpoint_empty():
assert await nearest_dc._check_fastest_endpoint([]) is None


@pytest.mark.asyncio
async def test_check_fastest_endpoint_all_fail(monkeypatch):
async def fake_open_connection(host, port):
raise OSError("connect failed")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("a", 1, "dc1"),
MockEndpoint("b", 1, "dc2"),
]
assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None


@pytest.mark.asyncio
async def test_check_fastest_endpoint_fastest_wins(monkeypatch):
async def fake_open_connection(host, port):
if host == "slow":
await asyncio.sleep(0.05)
return None, MockWriter()

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("slow", 1, "dc_slow"),
MockEndpoint("fast", 1, "dc_fast"),
]
winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2)
assert winner is not None
assert winner.location == "dc_fast"


@pytest.mark.asyncio
async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch):
async def fake_open_connection(host, port):
await asyncio.sleep(0.2)
return None, MockWriter()

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("hang1", 1, "dc1"),
MockEndpoint("hang2", 1, "dc2"),
]

winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05)

assert winner is None


@pytest.mark.asyncio
async def test_detect_local_dc_empty_endpoints():
with pytest.raises(ValueError, match="Empty endpoints"):
await nearest_dc.detect_local_dc([])


@pytest.mark.asyncio
async def test_detect_local_dc_single_location_returns_immediately(monkeypatch):
async def fail_if_called(*args, **kwargs):
raise AssertionError("open_connection should not be called for single location")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called)

endpoints = [
MockEndpoint("h1", 1, "dc1"),
MockEndpoint("h2", 1, "dc1"),
]
assert await nearest_dc.detect_local_dc(endpoints) == "dc1"


@pytest.mark.asyncio
async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch):
async def fake_open_connection(host, port):
raise OSError("connect failed")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("bad1", 9999, "dc1"),
MockEndpoint("bad2", 9999, "dc2"),
]
assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None


@pytest.mark.asyncio
async def test_detect_local_dc_returns_location_of_fastest(monkeypatch):
async def fake_open_connection(host, port):
if host == "dc1_host":
await asyncio.sleep(0.05)
return None, MockWriter()

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("dc1_host", 1, "dc1"),
MockEndpoint("dc2_host", 1, "dc2"),
]
assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2"


@pytest.mark.asyncio
async def test_detect_local_dc_respects_max_per_location(monkeypatch):
calls = []

async def fake_open_connection(host, port):
calls.append((host, port))
raise OSError("connect failed")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [
MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5)
]
await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2)

assert len(calls) == 4


@pytest.mark.asyncio
async def test_detect_local_dc_validates_max_per_location():
endpoints = [MockEndpoint("h1", 1, "dc1")]
with pytest.raises(ValueError, match="max_per_location must be >= 1"):
await nearest_dc.detect_local_dc(endpoints, max_per_location=0)


@pytest.mark.asyncio
async def test_detect_local_dc_validates_timeout():
endpoints = [MockEndpoint("h1", 1, "dc1")]
with pytest.raises(ValueError, match="timeout must be > 0"):
await nearest_dc.detect_local_dc(endpoints, timeout=0)
94 changes: 94 additions & 0 deletions tests/test_discovery_detect_local_dc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
from unittest.mock import Mock, MagicMock, patch
from ydb import driver, pool, nearest_dc, connection


class MockEndpointInfo:
def __init__(self, address, port, location):
self.address = address
self.port = port
self.endpoint = f"{address}:{port}"
self.location = location
self.ssl = False
self.node_id = 1

def endpoints_with_options(self):
yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id))


class MockDiscoveryResult:
def __init__(self, self_location, endpoints):
self.self_location = self_location
self.endpoints = endpoints


def test_detect_local_dc_overrides_server_location():
"""Test that detected location overrides server's self_location for preferred endpoints."""
# Server reports dc1, but we detect dc2 as nearest
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result
mock_resolver.context_resolve.return_value.__exit__.return_value = None

preferred = []

with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")):
with patch(
"ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1)
):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
)
discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config)
discovery._resolver = mock_resolver

original_add = discovery._cache.add
discovery._cache.add = lambda conn, pref=False: (
preferred.append(conn.endpoint) if pref else None,
original_add(conn, pref),
)[1]

discovery.execute_discovery()

assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)"
assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred"


def test_detect_local_dc_failure_fallback():
"""Test that detection failure falls back to server's self_location."""
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result
mock_resolver.context_resolve.return_value.__exit__.return_value = None

preferred = []

with patch.object(nearest_dc, "detect_local_dc", Mock(return_value=None)):
with patch(
"ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1)
):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
)
discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config)
discovery._resolver = mock_resolver

original_add = discovery._cache.add
discovery._cache.add = lambda conn, pref=False: (
preferred.append(conn.endpoint) if pref else None,
original_add(conn, pref),
)[1]

discovery.execute_discovery()

assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)"
Loading
Loading