@@ -268,7 +268,7 @@ def _create_request(self):
268
268
self .request_timeout ,
269
269
)
270
270
271
- def cursor (self , legacy_primitive_types : bool = None ):
271
+ def cursor (self , cursor_style : str = "row" , legacy_primitive_types : bool = None ):
272
272
"""Return a new :py:class:`Cursor` object using the connection."""
273
273
if self .isolation_level != IsolationLevel .AUTOCOMMIT :
274
274
if self .transaction is None :
@@ -277,11 +277,21 @@ def cursor(self, legacy_primitive_types: bool = None):
277
277
request = self .transaction .request
278
278
else :
279
279
request = self ._create_request ()
280
- return Cursor (
280
+
281
+ cursor_class = {
282
+ # Add any custom Cursor classes here
283
+ "segment" : SegmentCursor ,
284
+ "row" : Cursor
285
+ }.get (cursor_style .lower (), Cursor )
286
+
287
+ return cursor_class (
281
288
self ,
282
289
request ,
283
- # if legacy params are not explicitly set in Cursor, take them from Connection
284
- legacy_primitive_types if legacy_primitive_types is not None else self .legacy_primitive_types
290
+ legacy_primitive_types = (
291
+ legacy_primitive_types
292
+ if legacy_primitive_types is not None
293
+ else self .legacy_primitive_types
294
+ )
285
295
)
286
296
287
297
def _use_legacy_prepared_statements (self ):
@@ -714,6 +724,28 @@ def close(self):
714
724
# but also any other outstanding queries executed through this cursor.
715
725
716
726
727
+ class SegmentCursor (Cursor ):
728
+ def __init__ (
729
+ self ,
730
+ connection ,
731
+ request ,
732
+ legacy_primitive_types : bool = False ):
733
+ super ().__init__ (connection , request , legacy_primitive_types = legacy_primitive_types )
734
+ if self .connection ._client_session .encoding is None :
735
+ raise ValueError ("SegmentCursor can only be used if encoding is set on the connection" )
736
+
737
+ def execute (self , operation , params = None ):
738
+ if params :
739
+ # TODO: refactor code to allow for params to be supported
740
+ raise ValueError ("params not supported" )
741
+
742
+ self ._query = trino .client .TrinoQuery (self ._request , query = operation ,
743
+ legacy_primitive_types = self ._legacy_primitive_types ,
744
+ fetch_mode = "segments" )
745
+ self ._iterator = iter (self ._query .execute ())
746
+ return self
747
+
748
+
717
749
Date = datetime .date
718
750
Time = datetime .time
719
751
Timestamp = datetime .datetime
0 commit comments