Skip to content

Commit 774741c

Browse files
authored
Merge pull request #757 from ydb-platform/use_session_node_id
Use session's node_id as preferred endpoint
2 parents 7b29e34 + 5cec34c commit 774741c

File tree

4 files changed

+170
-1
lines changed

4 files changed

+170
-1
lines changed

tests/aio/query/test_query_session.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ydb
66
from ydb import QueryExplainResultFormat
77
from ydb.aio.query.session import QuerySession
8+
from ydb.connection import EndpointKey
89

910

1011
def _check_session_not_ready(session: QuerySession):
@@ -161,3 +162,82 @@ async def callee(session: QuerySession):
161162
assert "Lookup" in plan_lookup_string
162163
finally:
163164
await pool.execute_with_retries("DROP TABLE test_explain")
165+
166+
167+
class TestAsyncQuerySessionPreferredEndpoint:
168+
def test_endpoint_key_is_none_before_create(self, session: QuerySession):
169+
assert session._endpoint_key is None
170+
171+
@pytest.mark.asyncio
172+
async def test_endpoint_key_is_set_after_create(self, session: QuerySession):
173+
await session.create()
174+
assert session.node_id is not None
175+
assert session._endpoint_key is not None
176+
assert isinstance(session._endpoint_key, EndpointKey)
177+
assert session._endpoint_key.node_id == session.node_id
178+
179+
@pytest.mark.asyncio
180+
async def test_session_uses_preferred_endpoint_on_execute(self, session: QuerySession):
181+
await session.create()
182+
original_driver_call = session._driver
183+
184+
calls = []
185+
186+
async def mock_driver_call(*args, **kwargs):
187+
calls.append(kwargs)
188+
return await original_driver_call(*args, **kwargs)
189+
190+
session._driver = mock_driver_call
191+
192+
async with await session.execute("select 1;") as results:
193+
async for _ in results:
194+
pass
195+
196+
assert len(calls) > 0
197+
assert "preferred_endpoint" in calls[0]
198+
assert calls[0]["preferred_endpoint"] is not None
199+
assert calls[0]["preferred_endpoint"].node_id == session.node_id
200+
201+
@pytest.mark.asyncio
202+
async def test_session_uses_preferred_endpoint_on_delete(self, session: QuerySession):
203+
await session.create()
204+
original_driver_call = session._driver
205+
206+
calls = []
207+
208+
async def mock_driver_call(*args, **kwargs):
209+
calls.append(kwargs)
210+
return await original_driver_call(*args, **kwargs)
211+
212+
session._driver = mock_driver_call
213+
214+
await session.delete()
215+
216+
assert len(calls) > 0
217+
assert "preferred_endpoint" in calls[0]
218+
assert calls[0]["preferred_endpoint"] is not None
219+
assert calls[0]["preferred_endpoint"].node_id == session.node_id
220+
221+
@pytest.mark.asyncio
222+
async def test_transaction_uses_preferred_endpoint(self, session: QuerySession):
223+
await session.create()
224+
original_driver_call = session._driver
225+
226+
calls = []
227+
228+
async def mock_driver_call(*args, **kwargs):
229+
calls.append(kwargs)
230+
return await original_driver_call(*args, **kwargs)
231+
232+
session._driver = mock_driver_call
233+
234+
async with session.transaction() as tx:
235+
async with await tx.execute("select 1;") as results:
236+
async for _ in results:
237+
pass
238+
239+
execute_calls = [c for c in calls if "preferred_endpoint" in c]
240+
assert len(execute_calls) > 0
241+
for call in execute_calls:
242+
assert call["preferred_endpoint"] is not None
243+
assert call["preferred_endpoint"].node_id == session.node_id

tests/query/test_query_session.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ydb import QuerySessionPool
1010
from ydb.query.base import QueryStatsMode, QueryExplainResultFormat
1111
from ydb.query.session import QuerySession
12+
from ydb.connection import EndpointKey
1213

1314

1415
def _check_session_not_ready(session: QuerySession):
@@ -226,3 +227,78 @@ def callee(session: QuerySession):
226227
assert "Lookup" in plan_lookup_string
227228
finally:
228229
pool.execute_with_retries("DROP TABLE test_explain")
230+
231+
232+
class TestQuerySessionPreferredEndpoint:
233+
def test_endpoint_key_is_none_before_create(self, session: QuerySession):
234+
assert session._endpoint_key is None
235+
236+
def test_endpoint_key_is_set_after_create(self, session: QuerySession):
237+
session.create()
238+
assert session.node_id is not None
239+
assert session._endpoint_key is not None
240+
assert isinstance(session._endpoint_key, EndpointKey)
241+
assert session._endpoint_key.node_id == session.node_id
242+
243+
def test_session_uses_preferred_endpoint_on_execute(self, session: QuerySession):
244+
session.create()
245+
original_driver_call = session._driver
246+
247+
calls = []
248+
249+
def mock_driver_call(*args, **kwargs):
250+
calls.append(kwargs)
251+
return original_driver_call(*args, **kwargs)
252+
253+
session._driver = mock_driver_call
254+
255+
with session.execute("select 1;") as results:
256+
for _ in results:
257+
pass
258+
259+
assert len(calls) > 0
260+
assert "preferred_endpoint" in calls[0]
261+
assert calls[0]["preferred_endpoint"] is not None
262+
assert calls[0]["preferred_endpoint"].node_id == session.node_id
263+
264+
def test_session_uses_preferred_endpoint_on_delete(self, session: QuerySession):
265+
session.create()
266+
original_driver_call = session._driver
267+
268+
calls = []
269+
270+
def mock_driver_call(*args, **kwargs):
271+
calls.append(kwargs)
272+
return original_driver_call(*args, **kwargs)
273+
274+
session._driver = mock_driver_call
275+
276+
session.delete()
277+
278+
assert len(calls) > 0
279+
assert "preferred_endpoint" in calls[0]
280+
assert calls[0]["preferred_endpoint"] is not None
281+
assert calls[0]["preferred_endpoint"].node_id == session.node_id
282+
283+
def test_transaction_uses_preferred_endpoint(self, session: QuerySession):
284+
session.create()
285+
original_driver_call = session._driver
286+
287+
calls = []
288+
289+
def mock_driver_call(*args, **kwargs):
290+
calls.append(kwargs)
291+
return original_driver_call(*args, **kwargs)
292+
293+
session._driver = mock_driver_call
294+
295+
with session.transaction() as tx:
296+
with tx.execute("select 1;") as results:
297+
for _ in results:
298+
pass
299+
300+
execute_calls = [c for c in calls if "preferred_endpoint" in c]
301+
assert len(execute_calls) > 0
302+
for call in execute_calls:
303+
assert call["preferred_endpoint"] is not None
304+
assert call["preferred_endpoint"].node_id == session.node_id

ydb/query/session.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from .. import _apis, issues, _utilities
1717
from ..settings import BaseRequestSettings
18-
from ..connection import _RpcState as RpcState
18+
from ..connection import _RpcState as RpcState, EndpointKey
1919
from .._grpc.grpcwrapper import common_utils
2020
from .._grpc.grpcwrapper import ydb_query as _ydb_query
2121
from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public
@@ -85,6 +85,12 @@ def node_id(self) -> Optional[int]:
8585
def is_active(self) -> bool:
8686
return self._session_id is not None and not self._closed
8787

88+
@property
89+
def _endpoint_key(self) -> Optional[EndpointKey]:
90+
if self._node_id is None:
91+
return None
92+
return EndpointKey(endpoint=None, node_id=self._node_id)
93+
8894
@property
8995
def is_closed(self) -> bool:
9096
return self._closed
@@ -141,6 +147,7 @@ def _delete_call(self, settings: Optional[BaseRequestSettings] = None) -> "BaseQ
141147
wrap_result=wrapper_delete_session,
142148
wrap_args=(self,),
143149
settings=settings,
150+
preferred_endpoint=self._endpoint_key,
144151
)
145152

146153
def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]:
@@ -149,6 +156,7 @@ def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]:
149156
_apis.QueryService.Stub,
150157
_apis.QueryService.AttachSession,
151158
settings=self._attach_settings,
159+
preferred_endpoint=self._endpoint_key,
152160
)
153161

154162
def _execute_call(
@@ -188,6 +196,7 @@ def _execute_call(
188196
_apis.QueryService.Stub,
189197
_apis.QueryService.ExecuteQuery,
190198
settings=settings,
199+
preferred_endpoint=self._endpoint_key,
191200
)
192201

193202

ydb/query/transaction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxCo
259259
wrap_tx_begin_response,
260260
settings,
261261
(self.session, self._tx_state, self),
262+
preferred_endpoint=self.session._endpoint_key,
262263
)
263264

264265
def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
@@ -272,6 +273,7 @@ def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxC
272273
wrap_tx_commit_response,
273274
settings,
274275
(self.session, self._tx_state, self),
276+
preferred_endpoint=self.session._endpoint_key,
275277
)
276278

277279
def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
@@ -285,6 +287,7 @@ def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryT
285287
wrap_tx_rollback_response,
286288
settings,
287289
(self.session, self._tx_state, self),
290+
preferred_endpoint=self.session._endpoint_key,
288291
)
289292

290293
def _execute_call(
@@ -327,6 +330,7 @@ def _execute_call(
327330
_apis.QueryService.Stub,
328331
_apis.QueryService.ExecuteQuery,
329332
settings=settings,
333+
preferred_endpoint=self.session._endpoint_key,
330334
)
331335

332336
def _move_to_beginned(self, tx_id: str) -> None:

0 commit comments

Comments
 (0)