Skip to content

Commit

Permalink
Merge pull request #548 from ydb-platform/tx_retryer
Browse files Browse the repository at this point in the history
Transactional retryer
  • Loading branch information
vgvoleg authored Jan 23, 2025
2 parents 228bb52 + ba5c216 commit 24badc7
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/aio/query/test_query_session_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import pytest
import ydb

from typing import Optional

from ydb.aio.query.pool import QuerySessionPool
from ydb.aio.query.session import QuerySession, QuerySessionStateEnum
from ydb.aio.query.transaction import QueryTxContext


class TestQuerySessionPool:
Expand Down Expand Up @@ -55,6 +59,43 @@ async def callee(session: QuerySession):
with pytest.raises(CustomException):
await pool.retry_operation_async(callee)

@pytest.mark.parametrize(
"tx_mode",
[
(None),
(ydb.QuerySerializableReadWrite()),
(ydb.QuerySnapshotReadOnly()),
(ydb.QueryOnlineReadOnly()),
(ydb.QueryStaleReadOnly()),
],
)
@pytest.mark.asyncio
async def test_retry_tx_normal(self, pool: QuerySessionPool, tx_mode: Optional[ydb.BaseQueryTxMode]):
retry_no = 0

async def callee(tx: QueryTxContext):
nonlocal retry_no
if retry_no < 2:
retry_no += 1
raise ydb.Unavailable("Fake fast backoff error")
result_stream = await tx.execute("SELECT 1")
return [result_set async for result_set in result_stream]

result = await pool.retry_tx_async(callee=callee, tx_mode=tx_mode)
assert len(result) == 1
assert retry_no == 2

@pytest.mark.asyncio
async def test_retry_tx_raises(self, pool: QuerySessionPool):
class CustomException(Exception):
pass

async def callee(tx: QueryTxContext):
raise CustomException()

with pytest.raises(CustomException):
await pool.retry_tx_async(callee)

@pytest.mark.asyncio
async def test_pool_size_limit_logic(self, pool: QuerySessionPool):
target_size = 5
Expand Down
39 changes: 39 additions & 0 deletions tests/query/test_query_session_pool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest
import ydb

from typing import Optional

from ydb.query.pool import QuerySessionPool
from ydb.query.session import QuerySession, QuerySessionStateEnum
from ydb.query.transaction import QueryTxContext


class TestQuerySessionPool:
Expand Down Expand Up @@ -46,6 +50,41 @@ def callee(session: QuerySession):
with pytest.raises(CustomException):
pool.retry_operation_sync(callee)

@pytest.mark.parametrize(
"tx_mode",
[
(None),
(ydb.QuerySerializableReadWrite()),
(ydb.QuerySnapshotReadOnly()),
(ydb.QueryOnlineReadOnly()),
(ydb.QueryStaleReadOnly()),
],
)
def test_retry_tx_normal(self, pool: QuerySessionPool, tx_mode: Optional[ydb.BaseQueryTxMode]):
retry_no = 0

def callee(tx: QueryTxContext):
nonlocal retry_no
if retry_no < 2:
retry_no += 1
raise ydb.Unavailable("Fake fast backoff error")
result_stream = tx.execute("SELECT 1")
return [result_set for result_set in result_stream]

result = pool.retry_tx_sync(callee=callee, tx_mode=tx_mode)
assert len(result) == 1
assert retry_no == 2

def test_retry_tx_raises(self, pool: QuerySessionPool):
class CustomException(Exception):
pass

def callee(tx: QueryTxContext):
raise CustomException()

with pytest.raises(CustomException):
pool.retry_tx_sync(callee)

def test_pool_size_limit_logic(self, pool: QuerySessionPool):
target_size = 5
pool._size = target_size
Expand Down
35 changes: 35 additions & 0 deletions ydb/aio/query/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
RetrySettings,
retry_operation_async,
)
from ...query.base import BaseQueryTxMode
from ...query.base import QueryClientSettings
from ... import convert
from ..._grpc.grpcwrapper import common_utils
from ..._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -122,6 +124,39 @@ async def wrapped_callee():

return await retry_operation_async(wrapped_callee, retry_settings)

async def retry_tx_async(
self,
callee: Callable,
tx_mode: Optional[BaseQueryTxMode] = None,
retry_settings: Optional[RetrySettings] = None,
*args,
**kwargs,
):
"""Special interface to execute a bunch of commands with transaction in a safe, retriable way.
:param callee: A function, that works with session.
:param tx_mode: Transaction mode, which is a one from the following choises:
1) QuerySerializableReadWrite() which is default mode;
2) QueryOnlineReadOnly(allow_inconsistent_reads=False);
3) QuerySnapshotReadOnly();
4) QueryStaleReadOnly().
:param retry_settings: RetrySettings object.
:return: Result sets or exception in case of execution errors.
"""

tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite()
retry_settings = RetrySettings() if retry_settings is None else retry_settings

async def wrapped_callee():
async with self.checkout() as session:
async with session.transaction(tx_mode=tx_mode) as tx:
result = await callee(tx, *args, **kwargs)
await tx.commit()
return result

return await retry_operation_async(wrapped_callee, retry_settings)

async def execute_with_retries(
self,
query: str,
Expand Down
35 changes: 35 additions & 0 deletions ydb/query/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import queue

from .base import BaseQueryTxMode
from .base import QueryClientSettings
from .session import (
QuerySession,
Expand All @@ -20,6 +21,7 @@
from .. import convert
from ..settings import BaseRequestSettings
from .._grpc.grpcwrapper import common_utils
from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,6 +140,39 @@ def wrapped_callee():

return retry_operation_sync(wrapped_callee, retry_settings)

def retry_tx_sync(
self,
callee: Callable,
tx_mode: Optional[BaseQueryTxMode] = None,
retry_settings: Optional[RetrySettings] = None,
*args,
**kwargs,
):
"""Special interface to execute a bunch of commands with transaction in a safe, retriable way.
:param callee: A function, that works with session.
:param tx_mode: Transaction mode, which is a one from the following choises:
1) QuerySerializableReadWrite() which is default mode;
2) QueryOnlineReadOnly(allow_inconsistent_reads=False);
3) QuerySnapshotReadOnly();
4) QueryStaleReadOnly().
:param retry_settings: RetrySettings object.
:return: Result sets or exception in case of execution errors.
"""

tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite()
retry_settings = RetrySettings() if retry_settings is None else retry_settings

def wrapped_callee():
with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session:
with session.transaction(tx_mode=tx_mode) as tx:
result = callee(tx, *args, **kwargs)
tx.commit()
return result

return retry_operation_sync(wrapped_callee, retry_settings)

def execute_with_retries(
self,
query: str,
Expand Down

0 comments on commit 24badc7

Please sign in to comment.