Skip to content

Commit d542f7b

Browse files
committed
Add sql execute to Connection
Closes #159 Co-authored-by Denis Ignatenko <[email protected]>
1 parent e99662d commit d542f7b

File tree

6 files changed

+209
-6
lines changed

6 files changed

+209
-6
lines changed

Diff for: tarantool/connection.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
RequestSubscribe,
3535
RequestUpdate,
3636
RequestUpsert,
37-
RequestAuthenticate
37+
RequestAuthenticate,
38+
RequestExecute
3839
)
3940
from tarantool.space import Space
4041
from tarantool.const import (
@@ -257,17 +258,18 @@ def _read_response(self):
257258

258259
def _send_request_wo_reconnect(self, request):
259260
'''
260-
:rtype: `Response` instance
261+
:rtype: `Response` instance or subclass
261262
262263
:raise: NetworkError
263264
'''
264-
assert isinstance(request, Request)
265+
if not isinstance(request, Request):
266+
raise TypeError("Expected Request instance")
265267

266268
response = None
267269
while True:
268270
try:
269271
self._socket.sendall(bytes(request))
270-
response = Response(self, self._read_response())
272+
response = request.response_class(self, self._read_response())
271273
break
272274
except SchemaReloadException as e:
273275
self.update_schema(e.schema_version)
@@ -792,3 +794,22 @@ def generate_sync(self):
792794
Need override for async io connection
793795
'''
794796
return 0
797+
798+
def execute(self, query, params=None):
799+
'''
800+
Execute SQL request.
801+
802+
:param query: SQL syntax query
803+
:type query: str
804+
805+
:param params: Bind values to use in query
806+
:type params: list, dict
807+
808+
:return: query result data
809+
:rtype: `Response` instance
810+
'''
811+
if not params:
812+
params = []
813+
request = RequestExecute(self, query, params)
814+
response = self._send_request(request)
815+
return response

Diff for: tarantool/const.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
#
3030
IPROTO_DATA = 0x30
3131
IPROTO_ERROR = 0x31
32+
#
33+
IPROTO_METADATA = 0x32
34+
IPROTO_SQL_TEXT = 0x40
35+
IPROTO_SQL_BIND = 0x41
36+
IPROTO_SQL_INFO = 0x42
37+
IPROTO_SQL_INFO_ROW_COUNT = 0x00
38+
IPROTO_SQL_INFO_AUTOINCREMENT_IDS = 0x01
3239

3340
IPROTO_GREETING_SIZE = 128
3441
IPROTO_BODY_MAX_LEN = 2147483648
@@ -44,6 +51,7 @@
4451
REQUEST_TYPE_EVAL = 8
4552
REQUEST_TYPE_UPSERT = 9
4653
REQUEST_TYPE_CALL = 10
54+
REQUEST_TYPE_EXECUTE = 11
4755
REQUEST_TYPE_PING = 64
4856
REQUEST_TYPE_JOIN = 65
4957
REQUEST_TYPE_SUBSCRIBE = 66

Diff for: tarantool/request.py

+34
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44
Request types definitions
55
'''
66

7+
import collections
78
import msgpack
89
import hashlib
910

11+
try:
12+
collectionsAbc = collections.abc
13+
except AttributeError:
14+
collectionsAbc = collections
1015

16+
17+
from tarantool.error import DatabaseError
1118
from tarantool.const import (
1219
IPROTO_CODE,
1320
IPROTO_SYNC,
@@ -27,6 +34,8 @@
2734
IPROTO_OPS,
2835
# IPROTO_INDEX_BASE,
2936
IPROTO_SCHEMA_ID,
37+
IPROTO_SQL_TEXT,
38+
IPROTO_SQL_BIND,
3039
REQUEST_TYPE_OK,
3140
REQUEST_TYPE_PING,
3241
REQUEST_TYPE_SELECT,
@@ -37,11 +46,13 @@
3746
REQUEST_TYPE_UPSERT,
3847
REQUEST_TYPE_CALL16,
3948
REQUEST_TYPE_CALL,
49+
REQUEST_TYPE_EXECUTE,
4050
REQUEST_TYPE_EVAL,
4151
REQUEST_TYPE_AUTHENTICATE,
4252
REQUEST_TYPE_JOIN,
4353
REQUEST_TYPE_SUBSCRIBE
4454
)
55+
from tarantool.response import Response, ResponseExecute
4556
from tarantool.utils import (
4657
strxor,
4758
binary_types
@@ -64,6 +75,7 @@ def __init__(self, conn):
6475
self.conn = conn
6576
self._sync = None
6677
self._body = ''
78+
self.response_class = Response
6779

6880
packer_kwargs = dict()
6981

@@ -360,3 +372,25 @@ def __init__(self, conn, sync):
360372
request_body = self._dumps({IPROTO_CODE: self.request_type,
361373
IPROTO_SYNC: sync})
362374
self._body = request_body
375+
376+
377+
class RequestExecute(Request):
378+
'''
379+
Represents EXECUTE SQL request
380+
'''
381+
request_type = REQUEST_TYPE_EXECUTE
382+
383+
# pylint: disable=W0231
384+
def __init__(self, conn, sql, args):
385+
super(RequestExecute, self).__init__(conn)
386+
if isinstance(args, collectionsAbc.Mapping):
387+
args = [{":%s" % name: value} for name, value in args.items()]
388+
elif not isinstance(args, collectionsAbc.Sequence):
389+
raise TypeError("Parameter type '%s' is not supported. "
390+
"Must be a mapping or sequence" % type(args))
391+
392+
request_body = self._dumps({IPROTO_SQL_TEXT: sql,
393+
IPROTO_SQL_BIND: args})
394+
395+
self._body = request_body
396+
self.response_class = ResponseExecute

Diff for: tarantool/response.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
IPROTO_ERROR,
1818
IPROTO_SYNC,
1919
IPROTO_SCHEMA_ID,
20-
REQUEST_TYPE_ERROR
20+
REQUEST_TYPE_ERROR,
21+
IPROTO_SQL_INFO,
22+
IPROTO_SQL_INFO_ROW_COUNT,
23+
IPROTO_SQL_INFO_AUTOINCREMENT_IDS
2124
)
2225
from tarantool.error import (
2326
DatabaseError,
@@ -268,3 +271,51 @@ def __str__(self):
268271
return ''.join(output)
269272

270273
__repr__ = __str__
274+
275+
276+
class ResponseExecute(Response):
277+
@property
278+
def autoincrement_ids(self):
279+
"""
280+
Returns the list of tuples containing
281+
values of autoincrement fields for a DML request
282+
and None for a DQL request (NOT result set size)
283+
284+
:rtype: list or None
285+
"""
286+
if self._body is None:
287+
raise InterfaceError("Trying to access data, when there's no data")
288+
info = self._body.get(IPROTO_SQL_INFO)
289+
290+
if info is None:
291+
return None
292+
293+
autoincrement_ids = info.get(IPROTO_SQL_INFO_AUTOINCREMENT_IDS)
294+
295+
return autoincrement_ids
296+
297+
@property
298+
def affected_row_count(self):
299+
"""
300+
Returns the number of affected rows for responses
301+
to DML requests and None for other requests
302+
:rtype: int
303+
"""
304+
if self._body is None:
305+
raise InterfaceError("Trying to access data, when there's no data")
306+
307+
info = self._body.get(IPROTO_SQL_INFO)
308+
309+
if info is None:
310+
return None
311+
312+
return info.get(IPROTO_SQL_INFO_ROW_COUNT)
313+
314+
@property
315+
def rows(self):
316+
"""
317+
:rtype: list or None
318+
"""
319+
if self._return_code != 0:
320+
return None
321+
return self._body.get(IPROTO_DATA)

Diff for: unit/suites/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
from .test_protocol import TestSuite_Protocol
1111
from .test_reconnect import TestSuite_Reconnect
1212
from .test_mesh import TestSuite_Mesh
13+
from .test_execute import TestSuite_Execute
1314

1415
test_cases = (TestSuite_Schema_UnicodeConnection,
1516
TestSuite_Schema_BinaryConnection,
1617
TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect,
17-
TestSuite_Mesh)
18+
TestSuite_Mesh, TestSuite_Execute)
1819

1920
def load_tests(loader, tests, pattern):
2021
suite = unittest.TestSuite()

Diff for: unit/suites/test_execute.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from __future__ import print_function
4+
5+
import sys
6+
import unittest
7+
8+
import tarantool
9+
from .lib.tarantool_server import TarantoolServer
10+
11+
12+
class TestSuite_Execute(unittest.TestCase):
13+
ddl = 'create table %s (id INTEGER PRIMARY KEY AUTOINCREMENT, ' \
14+
'name varchar(20))'
15+
16+
ddl_params = [
17+
{'id': None, 'name': 'Michael'},
18+
{'id': None, 'name': 'Mary'},
19+
{'id': None, 'name': 'John'},
20+
{'id': None, 'name': 'Ruth'},
21+
{'id': None, 'name': 'Rachel'}
22+
]
23+
24+
@classmethod
25+
def setUpClass(self):
26+
print(' EXECUTE '.center(70, '='), file=sys.stderr)
27+
print('-' * 70, file=sys.stderr)
28+
self.srv = TarantoolServer()
29+
self.srv.script = 'unit/suites/box.lua'
30+
self.srv.start()
31+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'])
32+
33+
def setUp(self):
34+
# prevent a remote tarantool from clean our session
35+
if self.srv.is_started():
36+
self.srv.touch_lock()
37+
self.con.flush_schema()
38+
39+
# grant full access to guest
40+
self.srv.admin("box.schema.user.grant('guest', 'create,read,write,"
41+
"execute', 'universe')")
42+
43+
@classmethod
44+
def tearDownClass(self):
45+
self.con.close()
46+
self.srv.stop()
47+
self.srv.clean()
48+
49+
def _populate_data(self, table_name):
50+
query = "insert into %s values (:id, :name)" % table_name
51+
for param in self.ddl_params:
52+
self.con.execute(query, param)
53+
54+
def _create_table(self, table_name):
55+
return self.con.execute(self.ddl % table_name)
56+
57+
def test_dml_response(self):
58+
table_name = 'foo'
59+
response = self._create_table(table_name)
60+
self.assertEqual(response.autoincrement_ids, None)
61+
self.assertEqual(response.affected_row_count, 1)
62+
self.assertEqual(response.data, None)
63+
64+
query = "insert into %s values (:id, :name)" % table_name
65+
66+
for num, param in enumerate(self.ddl_params, start=1):
67+
response = self.con.execute(query, param)
68+
self.assertEqual(response.autoincrement_ids[0], num)
69+
self.assertEqual(response.affected_row_count, 1)
70+
self.assertEqual(response.data, None)
71+
72+
query = "delete from %s where id in (4, 5)" % table_name
73+
response = self.con.execute(query)
74+
self.assertEqual(response.autoincrement_ids, None)
75+
self.assertEqual(response.affected_row_count, 2)
76+
self.assertEqual(response.data, None)
77+
78+
def test_dql_response(self):
79+
table_name = 'bar'
80+
self._create_table(table_name)
81+
self._populate_data(table_name)
82+
83+
select_query = "select name from %s where id in (1, 3, 5)" % table_name
84+
response = self.con.execute(select_query)
85+
self.assertEqual(response.autoincrement_ids, None)
86+
self.assertEqual(response.affected_row_count, None)
87+
expected_data = [['Michael'], ['John'], ['Rachel']]
88+
self.assertListEqual(response.data, expected_data)

0 commit comments

Comments
 (0)