Skip to content

Commit 8affa61

Browse files
mdesmethashhar
authored andcommitted
Add Cursor.describe to retrieve the schema of a query
1 parent d96bff2 commit 8affa61

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

tests/integration/test_dbapi_integration.py

+67
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import trino
2222
from tests.integration.conftest import trino_version
2323
from trino import constants
24+
from trino.dbapi import DescribeOutput
2425
from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError
2526
from trino.transaction import IsolationLevel
2627

@@ -1155,3 +1156,69 @@ def test_connection_without_timezone(run_trino):
11551156
assert session_tz == localzone or \
11561157
(session_tz == "UTC" and localzone == "Etc/UTC") \
11571158
# Workaround for difference between Trino timezone and tzlocal for UTC
1159+
1160+
1161+
def test_describe(run_trino):
1162+
_, host, port = run_trino
1163+
1164+
trino_connection = trino.dbapi.Connection(
1165+
host=host, port=port, user="test", catalog="tpch",
1166+
)
1167+
cur = trino_connection.cursor()
1168+
1169+
result = cur.describe("SELECT 1, DECIMAL '1.0' as a")
1170+
1171+
assert result == [
1172+
DescribeOutput(name='_col0', catalog='', schema='', table='', type='integer', type_size=4, aliased=False),
1173+
DescribeOutput(name='a', catalog='', schema='', table='', type='decimal(2,1)', type_size=8, aliased=True)
1174+
]
1175+
1176+
1177+
def test_describe_table_query(run_trino):
1178+
_, host, port = run_trino
1179+
1180+
trino_connection = trino.dbapi.Connection(
1181+
host=host, port=port, user="test", catalog="tpch",
1182+
)
1183+
cur = trino_connection.cursor()
1184+
1185+
result = cur.describe("SELECT * from tpch.tiny.nation")
1186+
1187+
assert result == [
1188+
DescribeOutput(
1189+
name='nationkey',
1190+
catalog='tpch',
1191+
schema='tiny',
1192+
table='nation',
1193+
type='bigint',
1194+
type_size=8,
1195+
aliased=False,
1196+
),
1197+
DescribeOutput(
1198+
name='name',
1199+
catalog='tpch',
1200+
schema='tiny',
1201+
table='nation',
1202+
type='varchar(25)',
1203+
type_size=0,
1204+
aliased=False,
1205+
),
1206+
DescribeOutput(
1207+
name='regionkey',
1208+
catalog='tpch',
1209+
schema='tiny',
1210+
table='nation',
1211+
type='bigint',
1212+
type_size=8,
1213+
aliased=False,
1214+
),
1215+
DescribeOutput(
1216+
name='comment',
1217+
catalog='tpch',
1218+
schema='tiny',
1219+
table='nation',
1220+
type='varchar(152)',
1221+
type_size=0,
1222+
aliased=False,
1223+
)
1224+
]

trino/dbapi.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import math
2323
import uuid
2424
from decimal import Decimal
25-
from typing import Any, Dict, List, Optional # NOQA for mypy types
25+
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
2626

2727
import trino.client
2828
import trino.exceptions
@@ -223,6 +223,20 @@ def cursor(self, experimental_python_types: bool = None):
223223
)
224224

225225

226+
class DescribeOutput(NamedTuple):
227+
name: str
228+
catalog: str
229+
schema: str
230+
table: str
231+
type: str
232+
type_size: int
233+
aliased: bool
234+
235+
@classmethod
236+
def from_row(cls, row: List[Any]):
237+
return cls(*row)
238+
239+
226240
class Cursor(object):
227241
"""Database cursor.
228242
@@ -523,6 +537,28 @@ def fetchmany(self, size=None) -> List[List[Any]]:
523537

524538
return result
525539

540+
def describe(self, sql: str) -> List[DescribeOutput]:
541+
"""
542+
List the output columns of a SQL statement, including the column name (or alias), catalog, schema, table, type,
543+
type size in bytes, and a boolean indicating if the column is aliased.
544+
545+
:param sql: SQL statement
546+
"""
547+
statement_name = self._generate_unique_statement_name()
548+
self._prepare_statement(sql, statement_name)
549+
try:
550+
sql = f"DESCRIBE OUTPUT {statement_name}"
551+
self._query = trino.client.TrinoQuery(
552+
self._request,
553+
sql=sql,
554+
experimental_python_types=self._experimental_pyton_types,
555+
)
556+
result = self._query.execute()
557+
finally:
558+
self._deallocate_prepared_statement(statement_name)
559+
560+
return list(map(lambda x: DescribeOutput.from_row(x), result))
561+
526562
def genall(self):
527563
return self._query.result
528564

0 commit comments

Comments
 (0)