Skip to content

Commit c529767

Browse files
Add sql execute to Connection
Closes #159 Co-authored-by: Denis Ignatenko <[email protected]>
1 parent e99662d commit c529767

File tree

6 files changed

+212
-5
lines changed

6 files changed

+212
-5
lines changed

tarantool/connection.py

+37-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
RequestSubscribe,
3535
RequestUpdate,
3636
RequestUpsert,
37-
RequestAuthenticate
37+
RequestAuthenticate,
38+
RequestExecute
3839
)
3940
from tarantool.space import Space
4041
from tarantool.const import (
@@ -257,7 +258,7 @@ def _read_response(self):
257258

258259
def _send_request_wo_reconnect(self, request):
259260
'''
260-
:rtype: `Response` instance
261+
:rtype: `Response` instance or subclass
261262
262263
:raise: NetworkError
263264
'''
@@ -267,7 +268,7 @@ def _send_request_wo_reconnect(self, request):
267268
while True:
268269
try:
269270
self._socket.sendall(bytes(request))
270-
response = Response(self, self._read_response())
271+
response = request.response_class(self, self._read_response())
271272
break
272273
except SchemaReloadException as e:
273274
self.update_schema(e.schema_version)
@@ -792,3 +793,36 @@ def generate_sync(self):
792793
Need override for async io connection
793794
'''
794795
return 0
796+
797+
def execute(self, query, params=None):
798+
'''
799+
Execute SQL request.
800+
801+
Tarantool binary protocol for SQL requests
802+
supports "qmark" and "named" param styles.
803+
Sequence of values can be used for "qmark" style.
804+
A mapping is used for "named" param style
805+
without leading colon in the keys.
806+
807+
Example for "qmark" arguments:
808+
>>> args = ['[email protected]']
809+
>>> c.execute('select * from "users" where "email"=?', args)
810+
811+
Example for "named" arguments:
812+
>>> args = {'email': '[email protected]'}
813+
>>> c.execute('select * from "users" where "email"=:email', args)
814+
815+
:param query: SQL syntax query
816+
:type query: str
817+
818+
:param params: Bind values to use in the query.
819+
:type params: list, dict
820+
821+
:return: query result data
822+
:rtype: `Response` instance
823+
'''
824+
if not params:
825+
params = []
826+
request = RequestExecute(self, query, params)
827+
response = self._send_request(request)
828+
return response

tarantool/const.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
#
3030
IPROTO_DATA = 0x30
3131
IPROTO_ERROR = 0x31
32+
#
33+
IPROTO_METADATA = 0x32
34+
IPROTO_SQL_TEXT = 0x40
35+
IPROTO_SQL_BIND = 0x41
36+
IPROTO_SQL_INFO = 0x42
37+
IPROTO_SQL_INFO_ROW_COUNT = 0x00
38+
IPROTO_SQL_INFO_AUTOINCREMENT_IDS = 0x01
3239

3340
IPROTO_GREETING_SIZE = 128
3441
IPROTO_BODY_MAX_LEN = 2147483648
@@ -44,6 +51,7 @@
4451
REQUEST_TYPE_EVAL = 8
4552
REQUEST_TYPE_UPSERT = 9
4653
REQUEST_TYPE_CALL = 10
54+
REQUEST_TYPE_EXECUTE = 11
4755
REQUEST_TYPE_PING = 64
4856
REQUEST_TYPE_JOIN = 65
4957
REQUEST_TYPE_SUBSCRIBE = 66

tarantool/request.py

+33
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44
Request types definitions
55
'''
66

7+
import collections
78
import msgpack
89
import hashlib
910

11+
try:
12+
collectionsAbc = collections.abc
13+
except AttributeError:
14+
collectionsAbc = collections
1015

16+
17+
from tarantool.error import DatabaseError
1118
from tarantool.const import (
1219
IPROTO_CODE,
1320
IPROTO_SYNC,
@@ -27,6 +34,8 @@
2734
IPROTO_OPS,
2835
# IPROTO_INDEX_BASE,
2936
IPROTO_SCHEMA_ID,
37+
IPROTO_SQL_TEXT,
38+
IPROTO_SQL_BIND,
3039
REQUEST_TYPE_OK,
3140
REQUEST_TYPE_PING,
3241
REQUEST_TYPE_SELECT,
@@ -37,11 +46,13 @@
3746
REQUEST_TYPE_UPSERT,
3847
REQUEST_TYPE_CALL16,
3948
REQUEST_TYPE_CALL,
49+
REQUEST_TYPE_EXECUTE,
4050
REQUEST_TYPE_EVAL,
4151
REQUEST_TYPE_AUTHENTICATE,
4252
REQUEST_TYPE_JOIN,
4353
REQUEST_TYPE_SUBSCRIBE
4454
)
55+
from tarantool.response import Response, ResponseExecute
4556
from tarantool.utils import (
4657
strxor,
4758
binary_types
@@ -64,6 +75,7 @@ def __init__(self, conn):
6475
self.conn = conn
6576
self._sync = None
6677
self._body = ''
78+
self.response_class = Response
6779

6880
packer_kwargs = dict()
6981

@@ -360,3 +372,24 @@ def __init__(self, conn, sync):
360372
request_body = self._dumps({IPROTO_CODE: self.request_type,
361373
IPROTO_SYNC: sync})
362374
self._body = request_body
375+
376+
377+
class RequestExecute(Request):
378+
'''
379+
Represents EXECUTE SQL request
380+
'''
381+
request_type = REQUEST_TYPE_EXECUTE
382+
383+
def __init__(self, conn, sql, args):
384+
super(RequestExecute, self).__init__(conn)
385+
if isinstance(args, collectionsAbc.Mapping):
386+
args = [{":%s" % name: value} for name, value in args.items()]
387+
elif not isinstance(args, collectionsAbc.Sequence):
388+
raise TypeError("Parameter type '%s' is not supported. "
389+
"Must be a mapping or sequence" % type(args))
390+
391+
request_body = self._dumps({IPROTO_SQL_TEXT: sql,
392+
IPROTO_SQL_BIND: args})
393+
394+
self._body = request_body
395+
self.response_class = ResponseExecute

tarantool/response.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
IPROTO_ERROR,
1818
IPROTO_SYNC,
1919
IPROTO_SCHEMA_ID,
20-
REQUEST_TYPE_ERROR
20+
REQUEST_TYPE_ERROR,
21+
IPROTO_SQL_INFO,
22+
IPROTO_SQL_INFO_ROW_COUNT,
23+
IPROTO_SQL_INFO_AUTOINCREMENT_IDS
2124
)
2225
from tarantool.error import (
2326
DatabaseError,
@@ -268,3 +271,43 @@ def __str__(self):
268271
return ''.join(output)
269272

270273
__repr__ = __str__
274+
275+
276+
class ResponseExecute(Response):
277+
@property
278+
def autoincrement_ids(self):
279+
"""
280+
Returns a list with the new primary-key value
281+
(or values) for an INSERT in a table defined with
282+
PRIMARY KEY AUTOINCREMENT
283+
(NOT result set size)
284+
285+
:rtype: list or None
286+
"""
287+
if self._return_code != 0:
288+
return None
289+
info = self._body.get(IPROTO_SQL_INFO)
290+
291+
if info is None:
292+
return None
293+
294+
autoincrement_ids = info.get(IPROTO_SQL_INFO_AUTOINCREMENT_IDS)
295+
296+
return autoincrement_ids
297+
298+
@property
299+
def affected_row_count(self):
300+
"""
301+
Returns the number of changed rows for responses
302+
to DML requests and None for DQL requests.
303+
304+
:rtype: int
305+
"""
306+
if self._return_code != 0:
307+
return None
308+
info = self._body.get(IPROTO_SQL_INFO)
309+
310+
if info is None:
311+
return None
312+
313+
return info.get(IPROTO_SQL_INFO_ROW_COUNT)

unit/suites/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
from .test_protocol import TestSuite_Protocol
1111
from .test_reconnect import TestSuite_Reconnect
1212
from .test_mesh import TestSuite_Mesh
13+
from .test_execute import TestSuite_Execute
1314

1415
test_cases = (TestSuite_Schema_UnicodeConnection,
1516
TestSuite_Schema_BinaryConnection,
1617
TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect,
17-
TestSuite_Mesh)
18+
TestSuite_Mesh, TestSuite_Execute)
1819

1920
def load_tests(loader, tests, pattern):
2021
suite = unittest.TestSuite()

unit/suites/test_execute.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from __future__ import print_function
4+
5+
import sys
6+
import unittest
7+
8+
import tarantool
9+
from .lib.tarantool_server import TarantoolServer
10+
11+
12+
class TestSuite_Execute(unittest.TestCase):
13+
ddl = 'create table %s (id INTEGER PRIMARY KEY AUTOINCREMENT, ' \
14+
'name varchar(20))'
15+
16+
dml_params = [
17+
{'id': None, 'name': 'Michael'},
18+
{'id': None, 'name': 'Mary'},
19+
{'id': None, 'name': 'John'},
20+
{'id': None, 'name': 'Ruth'},
21+
{'id': None, 'name': 'Rachel'}
22+
]
23+
24+
@classmethod
25+
def setUpClass(self):
26+
print(' EXECUTE '.center(70, '='), file=sys.stderr)
27+
print('-' * 70, file=sys.stderr)
28+
self.srv = TarantoolServer()
29+
self.srv.script = 'unit/suites/box.lua'
30+
self.srv.start()
31+
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'])
32+
33+
def setUp(self):
34+
# prevent a remote tarantool from clean our session
35+
if self.srv.is_started():
36+
self.srv.touch_lock()
37+
self.con.flush_schema()
38+
39+
# grant full access to guest
40+
self.srv.admin("box.schema.user.grant('guest', 'create,read,write,"
41+
"execute', 'universe')")
42+
43+
@classmethod
44+
def tearDownClass(self):
45+
self.con.close()
46+
self.srv.stop()
47+
self.srv.clean()
48+
49+
def _populate_data(self, table_name):
50+
query = "insert into %s values (:id, :name)" % table_name
51+
for param in self.dml_params:
52+
self.con.execute(query, param)
53+
54+
def _create_table(self, table_name):
55+
return self.con.execute(self.ddl % table_name)
56+
57+
def test_dml_response(self):
58+
table_name = 'foo'
59+
response = self._create_table(table_name)
60+
self.assertEqual(response.autoincrement_ids, None)
61+
self.assertEqual(response.affected_row_count, 1)
62+
self.assertEqual(response.data, None)
63+
64+
query = "insert into %s values (:id, :name)" % table_name
65+
66+
for num, param in enumerate(self.dml_params, start=1):
67+
response = self.con.execute(query, param)
68+
self.assertEqual(response.autoincrement_ids[0], num)
69+
self.assertEqual(response.affected_row_count, 1)
70+
self.assertEqual(response.data, None)
71+
72+
query = "delete from %s where id in (4, 5)" % table_name
73+
response = self.con.execute(query)
74+
self.assertEqual(response.autoincrement_ids, None)
75+
self.assertEqual(response.affected_row_count, 2)
76+
self.assertEqual(response.data, None)
77+
78+
def test_dql_response(self):
79+
table_name = 'bar'
80+
self._create_table(table_name)
81+
self._populate_data(table_name)
82+
83+
select_query = "select name from %s where id in (1, 3, 5)" % table_name
84+
response = self.con.execute(select_query)
85+
self.assertEqual(response.autoincrement_ids, None)
86+
self.assertEqual(response.affected_row_count, None)
87+
expected_data = [['Michael'], ['John'], ['Rachel']]
88+
self.assertListEqual(response.data, expected_data)

0 commit comments

Comments
 (0)