Skip to content

Commit d96bff2

Browse files
mdesmethashhar
authored andcommitted
Support VARBINARY query parameter
1 parent a83dcfe commit d96bff2

File tree

4 files changed

+36
-4
lines changed

4 files changed

+36
-4
lines changed

tests/integration/test_dbapi_integration.py

+19
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,25 @@ def test_null_date_with_time_zone(trino_connection):
533533
assert rows[0][0] is None
534534

535535

536+
@pytest.mark.parametrize(
537+
"binary_input",
538+
[
539+
bytearray("a", "utf-8"),
540+
bytearray("a", "ascii"),
541+
bytearray(b'\x00\x00\x00\x00'),
542+
bytearray(4),
543+
bytearray([1, 2, 3]),
544+
],
545+
)
546+
def test_binary_query_param(trino_connection, binary_input):
547+
cur = trino_connection.cursor(experimental_python_types=True)
548+
549+
cur.execute("SELECT ?", params=(binary_input,))
550+
rows = cur.fetchall()
551+
552+
assert rows[0][0] == binary_input
553+
554+
536555
def test_array_query_param(trino_connection):
537556
cur = trino_connection.cursor()
538557

tests/integration/test_types_integration.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,9 @@ def test_char(trino_connection):
121121

122122
def test_varbinary(trino_connection):
123123
SqlTest(trino_connection) \
124-
.add_field(sql="X'65683F'", python='ZWg/') \
125-
.add_field(sql="X''", python='') \
126-
.add_field(sql="CAST('' AS VARBINARY)", python='') \
127-
.add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂') \
124+
.add_field(sql="X'65683F'", python=b'eh?') \
125+
.add_field(sql="X''", python=b'') \
126+
.add_field(sql="CAST('' AS VARBINARY)", python=b'') \
128127
.add_field(sql="CAST(null AS VARBINARY)", python=None) \
129128
.execute()
130129

trino/client.py

+10
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from __future__ import annotations
3636

3737
import abc
38+
import base64
3839
import copy
3940
import functools
4041
import os
@@ -1072,6 +1073,13 @@ def map(self, value) -> Optional[datetime]:
10721073
).round_to(self.precision).to_python_type()
10731074

10741075

1076+
class BinaryValueMapper(ValueMapper[bytes]):
1077+
def map(self, value) -> Optional[bytes]:
1078+
if value is None:
1079+
return None
1080+
return base64.b64decode(value.encode("utf8"))
1081+
1082+
10751083
class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
10761084
def __init__(self, mapper: ValueMapper[Any]):
10771085
self.mapper = mapper
@@ -1157,6 +1165,8 @@ def _create_value_mapper(self, column) -> ValueMapper:
11571165
return TimeValueMapper(self._get_precision(column))
11581166
elif col_type == 'date':
11591167
return DateValueMapper()
1168+
elif col_type == 'varbinary':
1169+
return BinaryValueMapper()
11601170
else:
11611171
return NoOpValueMapper()
11621172

trino/dbapi.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Fetch methods returns rows as a list of lists on purpose to let the caller
1818
decide to convert then to a list of tuples.
1919
"""
20+
import binascii
2021
import datetime
2122
import math
2223
import uuid
@@ -400,6 +401,9 @@ def _format_prepared_param(self, param):
400401
if isinstance(param, Decimal):
401402
return "DECIMAL '%s'" % param
402403

404+
if isinstance(param, (bytes, bytearray)):
405+
return "X'%s'" % binascii.hexlify(param).decode("utf-8")
406+
403407
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
404408

405409
def _deallocate_prepared_statement(self, statement_name: str) -> None:

0 commit comments

Comments
 (0)