Skip to content

Commit 51db95b

Browse files
api: custom packer and unpacker factories
After this patch, user may pass `packer_factory` and `unpacker_factory` options to a connection. They will be used instead of the default ones. `packer_factory` is expected to be a function with the only one parameter: connection object, which returns a new msgpack.Packer object. `unpacker_factory` is expected to be a function with the only one parameter: connection object, which returns a new msgpack.Unpacker object. `packer_factory` supersedes `encoding` option. `unpacker_factory` supersedes `encoding` and `use_list` options. User may implement `encoding` and `use_list` support in its custom packer or unpacker if they wish so. User may refer to request submodule `packer_factory` and response submodule `unpacker_factory` as an example (these factories are used by default.) Closes #145, #190, #191
1 parent 70b0a3d commit 51db95b

File tree

8 files changed

+224
-15
lines changed

8 files changed

+224
-15
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## Unreleased
88

99
### Added
10+
- Support custom packer and unpacker factories (#191).
1011

1112
### Changed
1213

tarantool/connection.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121

2222
import msgpack
2323

24-
from tarantool.response import Response
24+
from tarantool.response import (
25+
unpacker_factory as default_unpacker_factory,
26+
Response,
27+
)
2528
from tarantool.request import (
29+
packer_factory as default_packer_factory,
2630
Request,
2731
# RequestOK,
2832
RequestCall,
@@ -357,7 +361,9 @@ def __init__(self, host, port,
357361
ssl_key_file=DEFAULT_SSL_KEY_FILE,
358362
ssl_cert_file=DEFAULT_SSL_CERT_FILE,
359363
ssl_ca_file=DEFAULT_SSL_CA_FILE,
360-
ssl_ciphers=DEFAULT_SSL_CIPHERS):
364+
ssl_ciphers=DEFAULT_SSL_CIPHERS,
365+
packer_factory=default_packer_factory,
366+
unpacker_factory=default_unpacker_factory):
361367
"""
362368
:param host: Server hostname or IP address. Use ``None`` for
363369
Unix sockets.
@@ -395,6 +401,16 @@ def __init__(self, host, port,
395401
:param encoding: ``'utf-8'`` or ``None``. Use ``None`` to work
396402
with non-UTF8 strings.
397403
404+
If non-default
405+
:paramref:`~tarantool.Connection.packer_factory` option is
406+
used, :paramref:`~tarantool.Connection.encoding` option
407+
value is ignored on encode until the factory explicitly uses
408+
its value. If non-default
409+
:paramref:`~tarantool.Connection.unpacker_factory` option is
410+
used, :paramref:`~tarantool.Connection.encoding` option
411+
value is ignored on decode until the factory explicitly uses
412+
its value.
413+
398414
If ``'utf-8'``, pack Unicode string (:obj:`str`) to
399415
MessagePack string (`mp_str`_) and unpack MessagePack string
400416
(`mp_str`_) Unicode string (:obj:`str`), pack :obj:`bytes`
@@ -429,6 +445,13 @@ def __init__(self, host, port,
429445
:param use_list:
430446
If ``True``, unpack MessagePack array (`mp_array`_) to
431447
:obj:`list`. Otherwise, unpack to :obj:`tuple`.
448+
449+
If non-default
450+
:paramref:`~tarantool.Connection.unpacker_factory` option is
451+
used,
452+
:paramref:`~tarantool.Connection.use_list` option value is
453+
ignored on decode until the factory explicitly uses its
454+
value.
432455
:type use_list: :obj:`bool`, optional
433456
434457
:param call_16:
@@ -463,6 +486,23 @@ def __init__(self, host, port,
463486
suites the connection can use.
464487
:type ssl_ciphers: :obj:`str` or :obj:`None`, optional
465488
489+
:param packer_factory: Request MessagePack packer factory.
490+
Supersedes :paramref:`~tarantool.Connection.encoding`. See
491+
:func:`~tarantool.request.packer_factory` for example of
492+
a packer factory.
493+
:type packer_factory:
494+
callable[[:obj:`~tarantool.Connection`], :obj:`~msgpack.Packer`],
495+
optional
496+
497+
:param unpacker_factory: Response MessagePack unpacker factory.
498+
Supersedes :paramref:`~tarantool.Connection.encoding` and
499+
:paramref:`~tarantool.Connection.use_list`. See
500+
:func:`~tarantool.response.unpacker_factory` for example of
501+
an unpacker factory.
502+
:type unpacker_factory:
503+
callable[[:obj:`~tarantool.Connection`], :obj:`~msgpack.Unpacker`],
504+
optional
505+
466506
:raise: :exc:`~tarantool.error.ConfigurationError`,
467507
:meth:`~tarantool.Connection.connect` exceptions
468508
@@ -514,6 +554,8 @@ def __init__(self, host, port,
514554
IPROTO_FEATURE_ERROR_EXTENSION: False,
515555
IPROTO_FEATURE_WATCHERS: False,
516556
}
557+
self._packer_factory_impl = packer_factory
558+
self._unpacker_factory_impl = unpacker_factory
517559

518560
if connect_now:
519561
self.connect()
@@ -1749,3 +1791,9 @@ def _check_features(self):
17491791
features_list = [val for val in CONNECTOR_FEATURES if val in server_features]
17501792
for val in features_list:
17511793
self._features[val] = True
1794+
1795+
def _packer_factory(self):
1796+
return self._packer_factory_impl(self)
1797+
1798+
def _unpacker_factory(self):
1799+
return self._unpacker_factory_impl(self)

tarantool/request.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363

6464
from tarantool.msgpack_ext.packer import default as packer_default
6565

66-
def build_packer(conn):
66+
def packer_factory(conn):
6767
"""
6868
Build packer to pack request.
6969
@@ -148,7 +148,7 @@ def __init__(self, conn):
148148
self._body = ''
149149
self.response_class = Response
150150

151-
self.packer = build_packer(conn)
151+
self.packer = conn._packer_factory()
152152

153153
def _dumps(self, src):
154154
"""

tarantool/response.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook
3434

35-
def build_unpacker(conn):
35+
def unpacker_factory(conn):
3636
"""
3737
Build unpacker to unpack request response.
3838
@@ -108,7 +108,7 @@ def __init__(self, conn, response):
108108
# created in the __new__().
109109
# super(Response, self).__init__()
110110

111-
unpacker = build_unpacker(conn)
111+
unpacker = conn._unpacker_factory()
112112

113113
unpacker.feed(response)
114114
header = unpacker.unpack()

test/suites/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@
2222
from .test_package import TestSuite_Package
2323
from .test_error_ext import TestSuite_ErrorExt
2424
from .test_push import TestSuite_Push
25+
from .test_connection import TestSuite_Connection
2526

2627
test_cases = (TestSuite_Schema_UnicodeConnection,
2728
TestSuite_Schema_BinaryConnection,
2829
TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect,
2930
TestSuite_Mesh, TestSuite_Execute, TestSuite_DBAPI,
3031
TestSuite_Encoding, TestSuite_Pool, TestSuite_Ssl,
3132
TestSuite_Decimal, TestSuite_UUID, TestSuite_Datetime,
32-
TestSuite_Interval, TestSuite_ErrorExt, TestSuite_Push,)
33+
TestSuite_Interval, TestSuite_ErrorExt, TestSuite_Push,
34+
TestSuite_Connection,)
3335

3436
def load_tests(loader, tests, pattern):
3537
suite = unittest.TestSuite()

test/suites/test_connection.py

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import sys
2+
import unittest
3+
4+
import decimal
5+
import msgpack
6+
7+
import tarantool
8+
import tarantool.msgpack_ext.decimal as ext_decimal
9+
10+
from .lib.skip import skip_or_run_decimal_test, skip_or_run_varbinary_test
11+
from .lib.tarantool_server import TarantoolServer
12+
13+
class TestSuite_Connection(unittest.TestCase):
14+
@classmethod
15+
def setUpClass(self):
16+
print(' CONNECTION '.center(70, '='), file=sys.stderr)
17+
print('-' * 70, file=sys.stderr)
18+
self.srv = TarantoolServer()
19+
self.srv.script = 'test/suites/box.lua'
20+
self.srv.start()
21+
22+
self.adm = self.srv.admin
23+
self.adm(r"""
24+
box.schema.user.create('test', {password = 'test', if_not_exists = true})
25+
box.schema.user.grant('test', 'read,write,execute', 'universe')
26+
27+
box.schema.create_space('space_varbin')
28+
29+
box.space['space_varbin']:format({
30+
{
31+
'id',
32+
type = 'number',
33+
is_nullable = false
34+
},
35+
{
36+
'varbin',
37+
type = 'varbinary',
38+
is_nullable = false,
39+
}
40+
})
41+
42+
box.space['space_varbin']:create_index('id', {
43+
type = 'tree',
44+
parts = {1, 'number'},
45+
unique = true})
46+
47+
box.space['space_varbin']:create_index('varbin', {
48+
type = 'tree',
49+
parts = {2, 'varbinary'},
50+
unique = true})
51+
""")
52+
53+
def setUp(self):
54+
# prevent a remote tarantool from clean our session
55+
if self.srv.is_started():
56+
self.srv.touch_lock()
57+
58+
@skip_or_run_decimal_test
59+
def test_custom_packer(self):
60+
def my_ext_type_encoder(obj):
61+
if isinstance(obj, decimal.Decimal):
62+
obj = obj + 1
63+
return msgpack.ExtType(ext_decimal.EXT_ID, ext_decimal.encode(obj, None))
64+
raise TypeError("Unknown type: %r" % (obj,))
65+
66+
def my_packer_factory(_):
67+
return msgpack.Packer(default=my_ext_type_encoder)
68+
69+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
70+
user='test', password='test',
71+
packer_factory=my_packer_factory)
72+
73+
resp = self.con.eval("return ...", (decimal.Decimal('27756'),))
74+
self.assertSequenceEqual(resp, [decimal.Decimal('27757')])
75+
76+
def test_custom_packer_supersedes_encoding(self):
77+
def my_packer_factory(_):
78+
return msgpack.Packer(use_bin_type=False)
79+
80+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
81+
user='test', password='test',
82+
encoding='utf-8',
83+
packer_factory=my_packer_factory)
84+
85+
# bytes -> mp_str (string) for encoding=None
86+
# bytes -> mp_bin (varbinary) for encoding='utf-8'
87+
resp = self.con.eval("return type(...)", (bytes(bytearray.fromhex('DEADBEAF0103')),))
88+
self.assertSequenceEqual(resp, ['string'])
89+
90+
@skip_or_run_decimal_test
91+
def test_custom_unpacker(self):
92+
def my_ext_type_decoder(code, data):
93+
if code == ext_decimal.EXT_ID:
94+
return ext_decimal.decode(data, None) - 1
95+
raise NotImplementedError("Unknown msgpack extension type code %d" % (code,))
96+
97+
def my_unpacker_factory(_):
98+
if msgpack.version >= (1, 0, 0):
99+
return msgpack.Unpacker(ext_hook=my_ext_type_decoder, strict_map_key=False)
100+
return msgpack.Unpacker(ext_hook=my_ext_type_decoder)
101+
102+
103+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
104+
user='test', password='test',
105+
unpacker_factory=my_unpacker_factory)
106+
107+
resp = self.con.eval("return require('decimal').new('27756')")
108+
self.assertSequenceEqual(resp, [decimal.Decimal('27755')])
109+
110+
@skip_or_run_varbinary_test
111+
def test_custom_unpacker_supersedes_encoding(self):
112+
def my_unpacker_factory(_):
113+
if msgpack.version >= (0, 5, 2):
114+
if msgpack.version >= (1, 0, 0):
115+
return msgpack.Unpacker(raw=True, strict_map_key=False)
116+
117+
return msgpack.Unpacker(raw=True)
118+
return msgpack.Unpacker(encoding=None)
119+
120+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
121+
user='test', password='test',
122+
encoding='utf-8',
123+
unpacker_factory=my_unpacker_factory)
124+
125+
data_id = 1
126+
data_hex = 'DEADBEAF'
127+
data = bytes(bytearray.fromhex(data_hex))
128+
space = 'space_varbin'
129+
130+
self.con.execute("""
131+
INSERT INTO "%s" VALUES (%d, x'%s');
132+
""" % (space, data_id, data_hex))
133+
134+
resp = self.con.execute("""
135+
SELECT * FROM "%s" WHERE "varbin" == x'%s';
136+
""" % (space, data_hex))
137+
self.assertSequenceEqual(resp, [[data_id, data]])
138+
139+
def test_custom_unpacker_supersedes_use_list(self):
140+
def my_unpacker_factory(_):
141+
if msgpack.version >= (1, 0, 0):
142+
return msgpack.Unpacker(use_list=False, strict_map_key=False)
143+
return msgpack.Unpacker(use_list=False)
144+
145+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
146+
user='test', password='test',
147+
use_list=True,
148+
unpacker_factory=my_unpacker_factory)
149+
150+
resp = self.con.eval("return {1, 2, 3}")
151+
self.assertIsInstance(resp[0], tuple)
152+
153+
@classmethod
154+
def tearDown(self):
155+
if hasattr(self, 'con'):
156+
self.con.close()
157+
158+
@classmethod
159+
def tearDownClass(self):
160+
self.srv.stop()
161+
self.srv.clean()

test/suites/test_error_ext.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
from tarantool.msgpack_ext.packer import default as packer_default
1010
from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook
11-
from tarantool.request import build_packer
12-
from tarantool.response import build_unpacker
1311

1412
from .lib.tarantool_server import TarantoolServer
1513
from .lib.skip import skip_or_run_error_ext_type_test
@@ -273,7 +271,7 @@ def test_msgpack_decode(self):
273271
unpacker_ext_hook(
274272
3,
275273
case['msgpack'],
276-
build_unpacker(conn)
274+
conn._unpacker_factory(),
277275
),
278276
case['python'])
279277

@@ -330,7 +328,7 @@ def test_msgpack_encode(self):
330328
case = self.cases[name]
331329
conn = getattr(self, case['conn'])
332330

333-
self.assertEqual(packer_default(case['python'], build_packer(conn)),
331+
self.assertEqual(packer_default(case['python'], conn._packer_factory()),
334332
msgpack.ExtType(code=3, data=case['msgpack']))
335333

336334
@skip_or_run_error_ext_type_test

test/suites/test_interval.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from tarantool.msgpack_ext.packer import default as packer_default
1111
from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook
12-
from tarantool.response import build_unpacker
1312

1413
from .lib.tarantool_server import TarantoolServer
1514
from .lib.skip import skip_or_run_datetime_test
@@ -154,7 +153,7 @@ def test_msgpack_decode(self):
154153
self.assertEqual(unpacker_ext_hook(
155154
6,
156155
case['msgpack'],
157-
build_unpacker(self.con),
156+
self.con._unpacker_factory(),
158157
),
159158
case['python'])
160159

@@ -206,13 +205,13 @@ def test_unknown_field_decode(self):
206205
case = b'\x01\x09\xce\x00\x98\x96\x80'
207206
self.assertRaisesRegex(
208207
MsgpackError, 'Unknown interval field id 9',
209-
lambda: unpacker_ext_hook(6, case, build_unpacker(self.con)))
208+
lambda: unpacker_ext_hook(6, case, self.con._unpacker_factory()))
210209

211210
def test_unknown_adjust_decode(self):
212211
case = b'\x02\x07\xce\x00\x98\x96\x80\x08\x03'
213212
self.assertRaisesRegex(
214213
MsgpackError, '3 is not a valid Adjust',
215-
lambda: unpacker_ext_hook(6, case, build_unpacker(self.con)))
214+
lambda: unpacker_ext_hook(6, case, self.con._unpacker_factory()))
216215

217216

218217
arithmetic_cases = {

0 commit comments

Comments
 (0)