Skip to content

Commit 489b743

Browse files
gridnevvvitVitalii Gridnev
authored and
Vitalii Gridnev
committed
add compression support to ydb sdk
1 parent bd7a4e3 commit 489b743

File tree

6 files changed

+64
-5
lines changed

6 files changed

+64
-5
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 2.1.0 ##
2+
3+
* add compression support to ydb sdk
4+
15
## 1.1.16 ##
26

37
* alias `kikimr.public.sdk.python.client` is deprecated. use `import ydb` instead.

test-requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ docker==5.0.0
1010
docker-compose==1.29.2
1111
dockerpty==0.4.1
1212
docopt==0.6.2
13-
grpcio==1.38.0
13+
grpcio>=1.38.0
1414
idna==3.2
1515
importlib-metadata==4.6.1
1616
iniconfig==1.1.1

tests/aio/test_connection_pool.py

+19
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,25 @@ async def test_async_call(endpoint, database):
2121
await driver.stop()
2222

2323

24+
@pytest.mark.asyncio
25+
async def test_gzip_compression(endpoint, database):
26+
driver_config = ydb.DriverConfig(
27+
endpoint,
28+
database,
29+
credentials=ydb.construct_credentials_from_environ(),
30+
root_certificates=ydb.load_ydb_root_certificate(),
31+
compression=ydb.RPCCompression.Gzip,
32+
)
33+
34+
driver = Driver(driver_config=driver_config)
35+
36+
await driver.scheme_client.make_directory(
37+
"/local/lol",
38+
settings=ydb.BaseRequestSettings().with_compression(ydb.RPCCompression.Deflate),
39+
)
40+
await driver.stop()
41+
42+
2443
@pytest.mark.asyncio
2544
async def test_other_credentials(endpoint, database):
2645
driver = Driver(endpoint=endpoint, database=database)

ydb/connection.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,22 @@ def channel_factory(
300300
logger.debug("Channel options: {}".format(options))
301301

302302
if driver_config.root_certificates is None and not driver_config.secure_channel:
303-
return channel_provider.insecure_channel(endpoint, options)
303+
return channel_provider.insecure_channel(
304+
endpoint, options, compression=getattr(driver_config, "compression", None)
305+
)
306+
304307
root_certificates = driver_config.root_certificates
305308
if root_certificates is None:
306309
root_certificates = default_pem.load_default_pem()
307310
credentials = grpc.ssl_channel_credentials(
308311
root_certificates, driver_config.private_key, driver_config.certificate_chain
309312
)
310-
return channel_provider.secure_channel(endpoint, credentials, options)
313+
return channel_provider.secure_channel(
314+
endpoint,
315+
credentials,
316+
options,
317+
compression=getattr(driver_config, "compression", None),
318+
)
311319

312320

313321
class Connection(object):
@@ -405,7 +413,12 @@ def future(
405413
rpc_state, timeout, metadata = self._prepare_call(
406414
stub, rpc_name, request, settings
407415
)
408-
rendezvous, result_future = rpc_state.future(request, timeout, metadata)
416+
rendezvous, result_future = rpc_state.future(
417+
request,
418+
timeout,
419+
metadata,
420+
compression=getattr(settings, "compression", None),
421+
)
409422
rendezvous.add_done_callback(
410423
lambda resp_future: _on_response_callback(
411424
rpc_state,
@@ -443,7 +456,12 @@ def __call__(
443456
stub, rpc_name, request, settings
444457
)
445458
try:
446-
response = rpc_state(request, timeout, metadata)
459+
response = rpc_state(
460+
request,
461+
timeout,
462+
metadata,
463+
compression=getattr(settings, "compression", None),
464+
)
447465
_log_response(rpc_state, response)
448466
return (
449467
response

ydb/driver.py

+12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from . import tracing
44
import six
55
import os
6+
import grpc
67

78
if six.PY2:
89
Any = None
@@ -39,6 +40,14 @@ def parse_connection_string(connection_string):
3940
return p.scheme + "://" + p.netloc, database[0]
4041

4142

43+
class RPCCompression:
44+
"""Indicates the compression method to be used for an RPC."""
45+
46+
NoCompression = grpc.Compression.NoCompression
47+
Deflate = grpc.Compression.Deflate
48+
Gzip = grpc.Compression.Gzip
49+
50+
4251
def default_credentials(credentials=None, tracer=None):
4352
tracer = tracer if tracer is not None else tracing.Tracer(None)
4453
with tracer.trace("Driver.default_credentials") as ctx:
@@ -94,6 +103,7 @@ class DriverConfig(object):
94103
"tracer",
95104
"grpc_lb_policy_name",
96105
"discovery_request_timeout",
106+
"compression",
97107
)
98108

99109
def __init__(
@@ -115,6 +125,7 @@ def __init__(
115125
tracer=None,
116126
grpc_lb_policy_name="round_robin",
117127
discovery_request_timeout=10,
128+
compression=None,
118129
):
119130
"""
120131
A driver config to initialize a driver instance
@@ -159,6 +170,7 @@ def __init__(
159170
self.tracer = tracer if tracer is not None else tracing.Tracer(None)
160171
self.grpc_lb_policy_name = grpc_lb_policy_name
161172
self.discovery_request_timeout = discovery_request_timeout
173+
self.compression = compression
162174

163175
def set_database(self, database):
164176
self.database = database

ydb/settings.py

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class BaseRequestSettings(object):
99
"cancel_after",
1010
"operation_timeout",
1111
"tracer",
12+
"compression",
1213
)
1314

1415
def __init__(self):
@@ -20,6 +21,11 @@ def __init__(self):
2021
self.timeout = None
2122
self.cancel_after = None
2223
self.operation_timeout = None
24+
self.compression = None
25+
26+
def with_compression(self, compression):
27+
self.compression = compression
28+
return self
2329

2430
def with_trace_id(self, trace_id):
2531
"""

0 commit comments

Comments
 (0)