Skip to content

Commit 139dde2

Browse files
committed
feat: add nearest DC detection with TCP race
1 parent 1272b08 commit 139dde2

File tree

7 files changed

+659
-4
lines changed

7 files changed

+659
-4
lines changed

tests/aio/test_nearest_dc.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import asyncio
2+
import pytest
3+
from ydb.aio import nearest_dc
4+
5+
6+
class MockEndpoint:
7+
def __init__(self, address, port, location):
8+
self.address = address
9+
self.port = port
10+
self.endpoint = f"{address}:{port}"
11+
self.location = location
12+
13+
14+
class MockWriter:
15+
def __init__(self):
16+
self.closed = False
17+
18+
def close(self):
19+
self.closed = True
20+
21+
async def wait_closed(self):
22+
await asyncio.sleep(0)
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_check_fastest_endpoint_empty():
27+
assert await nearest_dc._check_fastest_endpoint([]) is None
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_check_fastest_endpoint_all_fail(monkeypatch):
32+
async def fake_open_connection(host, port):
33+
raise OSError("connect failed")
34+
35+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)
36+
37+
endpoints = [
38+
MockEndpoint("a", 1, "dc1"),
39+
MockEndpoint("b", 1, "dc2"),
40+
]
41+
assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_check_fastest_endpoint_fastest_wins(monkeypatch):
46+
async def fake_open_connection(host, port):
47+
if host == "slow":
48+
await asyncio.sleep(0.05)
49+
return None, MockWriter()
50+
51+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)
52+
53+
endpoints = [
54+
MockEndpoint("slow", 1, "dc_slow"),
55+
MockEndpoint("fast", 1, "dc_fast"),
56+
]
57+
winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2)
58+
assert winner is not None
59+
assert winner.location == "dc_fast"
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch):
64+
async def fake_open_connection(host, port):
65+
await asyncio.sleep(0.2)
66+
return None, MockWriter()
67+
68+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)
69+
70+
endpoints = [
71+
MockEndpoint("hang1", 1, "dc1"),
72+
MockEndpoint("hang2", 1, "dc2"),
73+
]
74+
75+
winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05)
76+
77+
assert winner is None
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_detect_local_dc_empty_endpoints():
82+
with pytest.raises(ValueError, match="Empty endpoints"):
83+
await nearest_dc.detect_local_dc([])
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_detect_local_dc_single_location_returns_immediately(monkeypatch):
88+
async def fail_if_called(*args, **kwargs):
89+
raise AssertionError("open_connection should not be called for single location")
90+
91+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called)
92+
93+
endpoints = [
94+
MockEndpoint("h1", 1, "dc1"),
95+
MockEndpoint("h2", 1, "dc1"),
96+
]
97+
assert await nearest_dc.detect_local_dc(endpoints) == "dc1"
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch):
102+
async def fake_open_connection(host, port):
103+
raise OSError("connect failed")
104+
105+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)
106+
107+
endpoints = [
108+
MockEndpoint("bad1", 9999, "dc1"),
109+
MockEndpoint("bad2", 9999, "dc2"),
110+
]
111+
assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1"
112+
113+
114+
@pytest.mark.asyncio
115+
async def test_detect_local_dc_returns_location_of_fastest(monkeypatch):
116+
async def fake_open_connection(host, port):
117+
if host == "dc1_host":
118+
await asyncio.sleep(0.05)
119+
return None, MockWriter()
120+
121+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)
122+
123+
endpoints = [
124+
MockEndpoint("dc1_host", 1, "dc1"),
125+
MockEndpoint("dc2_host", 1, "dc2"),
126+
]
127+
assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2"
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_detect_local_dc_respects_max_per_location(monkeypatch):
132+
calls = []
133+
134+
async def fake_open_connection(host, port):
135+
calls.append((host, port))
136+
raise OSError("connect failed")
137+
138+
monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)
139+
140+
endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [
141+
MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5)
142+
]
143+
await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2)
144+
145+
assert len(calls) == 4

tests/test_nearest_dc.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import time
2+
import pytest
3+
from ydb import nearest_dc
4+
5+
6+
class MockEndpoint:
7+
def __init__(self, address, port, location):
8+
self.address = address
9+
self.port = port
10+
self.endpoint = f"{address}:{port}"
11+
self.location = location
12+
13+
14+
class DummySock:
15+
def close(self):
16+
pass
17+
18+
19+
def test_check_fastest_endpoint_empty():
20+
assert nearest_dc._check_fastest_endpoint([]) is None
21+
22+
23+
def test_check_fastest_endpoint_all_fail(monkeypatch):
24+
def fake_create_connection(addr_port, timeout=None):
25+
raise OSError("connect failed")
26+
27+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection)
28+
29+
endpoints = [
30+
MockEndpoint("a", 1, "dc1"),
31+
MockEndpoint("b", 1, "dc2"),
32+
]
33+
assert nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None
34+
35+
36+
def test_check_fastest_endpoint_fastest_wins(monkeypatch):
37+
def fake_create_connection(addr_port, timeout=None):
38+
host, _ = addr_port
39+
if host == "slow":
40+
time.sleep(0.05)
41+
return DummySock()
42+
43+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection)
44+
45+
endpoints = [
46+
MockEndpoint("slow", 1, "dc_slow"),
47+
MockEndpoint("fast", 1, "dc_fast"),
48+
]
49+
winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2)
50+
assert winner is not None
51+
assert winner.location == "dc_fast"
52+
53+
54+
def test_check_fastest_endpoint_respects_main_timeout(monkeypatch):
55+
def fake_create_connection(addr_port, timeout=None):
56+
time.sleep(0.2)
57+
return DummySock()
58+
59+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection)
60+
61+
endpoints = [
62+
MockEndpoint("hang1", 1, "dc1"),
63+
MockEndpoint("hang2", 1, "dc2"),
64+
]
65+
66+
winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05)
67+
68+
assert winner is None
69+
70+
71+
def test_detect_local_dc_empty_endpoints():
72+
with pytest.raises(ValueError, match="Empty endpoints"):
73+
nearest_dc.detect_local_dc([])
74+
75+
76+
def test_detect_local_dc_single_location_returns_immediately(monkeypatch):
77+
def fail_if_called(*args, **kwargs):
78+
raise AssertionError("create_connection should not be called for single location")
79+
80+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fail_if_called)
81+
82+
endpoints = [
83+
MockEndpoint("h1", 1, "dc1"),
84+
MockEndpoint("h2", 1, "dc1"),
85+
]
86+
assert nearest_dc.detect_local_dc(endpoints) == "dc1"
87+
88+
89+
def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch):
90+
def fake_create_connection(addr_port, timeout=None):
91+
raise OSError("connect failed")
92+
93+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection)
94+
95+
endpoints = [
96+
MockEndpoint("bad1", 9999, "dc1"),
97+
MockEndpoint("bad2", 9999, "dc2"),
98+
]
99+
assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1"
100+
101+
102+
def test_detect_local_dc_returns_location_of_fastest(monkeypatch):
103+
def fake_create_connection(addr_port, timeout=None):
104+
host, _ = addr_port
105+
if host == "dc1_host":
106+
time.sleep(0.05)
107+
return DummySock()
108+
109+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection)
110+
111+
endpoints = [
112+
MockEndpoint("dc1_host", 1, "dc1"),
113+
MockEndpoint("dc2_host", 1, "dc2"),
114+
]
115+
assert nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2"
116+
117+
118+
def test_detect_local_dc_respects_max_per_location(monkeypatch):
119+
calls = []
120+
121+
def fake_create_connection(addr_port, timeout=None):
122+
calls.append(addr_port)
123+
raise OSError("connect failed")
124+
125+
monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection)
126+
127+
endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [
128+
MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5)
129+
]
130+
nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2)
131+
132+
assert len(calls) == 4

0 commit comments

Comments
 (0)