Skip to content

Commit d8b216b

Browse files
committed
Small fixes
1 parent d2f2429 commit d8b216b

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

trino/client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,8 @@ def columns(self):
753753
while not self._columns and not self.finished and not self.cancelled:
754754
# Columns are not returned immediately after query is submitted.
755755
# Continue fetching data until columns information is available and push fetched rows into buffer.
756-
self._result.rows += self.fetch()
756+
if self._result:
757+
self._result.rows += self.fetch()
757758
return self._columns
758759

759760
@property

trino/dbapi.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,12 @@ def stats(self) -> Optional[Dict[Any, Any]]:
338338
return self._query.stats
339339
return None
340340

341+
@property
342+
def query_id(self) -> Optional[str]:
343+
if self._query is not None:
344+
return self._query.query_id
345+
return None
346+
341347
@property
342348
def warnings(self) -> Optional[List[Dict[Any, Any]]]:
343349
if self._query is not None:
@@ -505,6 +511,7 @@ def executemany(self, operation: str, seq_of_params: Any) -> None:
505511
for parameters in seq_of_params[:-1]:
506512
self.execute(operation, parameters)
507513
self.fetchall()
514+
assert self._query is not None
508515
if self._query.update_type is None:
509516
raise NotSupportedError("Query must return update type")
510517
if seq_of_params:
@@ -586,8 +593,10 @@ def describe(self, sql: str) -> List[DescribeOutput]:
586593

587594
return list(map(lambda x: DescribeOutput.from_row(x), result))
588595

589-
def genall(self) -> trino.client.TrinoResult:
590-
return self._query.result
596+
def genall(self) -> Any:
597+
if self._query:
598+
return self._query.result
599+
return None
591600

592601
def fetchall(self) -> List[List[Any]]:
593602
return list(self.genall())
@@ -625,6 +634,8 @@ def __init__(self, *values: str):
625634
self.values = [v.lower() for v in values]
626635

627636
def __eq__(self, other: object) -> bool:
637+
if not isinstance(other, str):
638+
return NotImplemented
628639
return other.lower() in self.values
629640

630641

0 commit comments

Comments
 (0)