Skip to content

Commit 3a9de8b

Browse files
mdesmethashhar
authored andcommitted
Support "segment" cursor style
1 parent 43ed692 commit 3a9de8b

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

tests/integration/test_dbapi_integration.py

+23
Original file line numberDiff line numberDiff line change
@@ -1861,6 +1861,29 @@ def test_select_query_spooled_segments(trino_connection):
18611861
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"
18621862

18631863

1864+
@pytest.mark.skipif(
1865+
trino_version() <= 466,
1866+
reason="spooling protocol was introduced in version 466"
1867+
)
1868+
def test_segments_cursor(trino_connection):
1869+
if trino_connection._client_session.encoding is None:
1870+
with pytest.raises(ValueError, match=".*encoding.*"):
1871+
trino_connection.cursor("segment")
1872+
return
1873+
cur = trino_connection.cursor("segment")
1874+
cur.execute("""SELECT l.*
1875+
FROM tpch.tiny.lineitem l, TABLE(sequence(
1876+
start => 1,
1877+
stop => 5,
1878+
step => 1)) n""")
1879+
rows = cur.fetchall()
1880+
assert len(rows) > 0
1881+
for spooled_data, spooled_segment in rows:
1882+
assert spooled_data.encoding == trino_connection._client_session.encoding
1883+
assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}"
1884+
assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}"
1885+
1886+
18641887
def get_cursor(legacy_prepared_statements, run_trino):
18651888
host, port = run_trino
18661889

trino/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ def __init__(
799799
request: TrinoRequest,
800800
query: str,
801801
legacy_primitive_types: bool = False,
802+
fetch_mode: Literal["mapped", "segments"] = "mapped"
802803
) -> None:
803804
self._query_id: Optional[str] = None
804805
self._stats: Dict[Any, Any] = {}
@@ -815,6 +816,7 @@ def __init__(
815816
self._result: Optional[TrinoResult] = None
816817
self._legacy_primitive_types = legacy_primitive_types
817818
self._row_mapper: Optional[RowMapper] = None
819+
self._fetch_mode = fetch_mode
818820

819821
@property
820822
def query_id(self) -> Optional[str]:
@@ -919,6 +921,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
919921
# spooling protocol
920922
rows = cast(_SpooledProtocolResponseTO, rows)
921923
segments = self._to_segments(rows)
924+
if self._fetch_mode == "segments":
925+
return segments
922926
return list(SegmentIterator(segments, self._row_mapper))
923927
elif isinstance(status.rows, list):
924928
return self._row_mapper.map(rows)

trino/dbapi.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def _create_request(self):
268268
self.request_timeout,
269269
)
270270

271-
def cursor(self, legacy_primitive_types: bool = None):
271+
def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None):
272272
"""Return a new :py:class:`Cursor` object using the connection."""
273273
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
274274
if self.transaction is None:
@@ -277,11 +277,21 @@ def cursor(self, legacy_primitive_types: bool = None):
277277
request = self.transaction.request
278278
else:
279279
request = self._create_request()
280-
return Cursor(
280+
281+
cursor_class = {
282+
# Add any custom Cursor classes here
283+
"segment": SegmentCursor,
284+
"row": Cursor
285+
}.get(cursor_style.lower(), Cursor)
286+
287+
return cursor_class(
281288
self,
282289
request,
283-
# if legacy params are not explicitly set in Cursor, take them from Connection
284-
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
290+
legacy_primitive_types=(
291+
legacy_primitive_types
292+
if legacy_primitive_types is not None
293+
else self.legacy_primitive_types
294+
)
285295
)
286296

287297
def _use_legacy_prepared_statements(self):
@@ -714,6 +724,28 @@ def close(self):
714724
# but also any other outstanding queries executed through this cursor.
715725

716726

727+
class SegmentCursor(Cursor):
728+
def __init__(
729+
self,
730+
connection,
731+
request,
732+
legacy_primitive_types: bool = False):
733+
super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types)
734+
if self.connection._client_session.encoding is None:
735+
raise ValueError("SegmentCursor can only be used if encoding is set on the connection")
736+
737+
def execute(self, operation, params=None):
738+
if params:
739+
# TODO: refactor code to allow for params to be supported
740+
raise ValueError("params not supported")
741+
742+
self._query = trino.client.TrinoQuery(self._request, query=operation,
743+
legacy_primitive_types=self._legacy_primitive_types,
744+
fetch_mode="segments")
745+
self._iterator = iter(self._query.execute())
746+
return self
747+
748+
717749
Date = datetime.date
718750
Time = datetime.time
719751
Timestamp = datetime.datetime

0 commit comments

Comments
 (0)