Skip to content

Commit ee59b16

Browse files
authored
Fix not fetching cursor (#59)
* Fix not fetching cursor on insert/update * Fetch the cursor on insert/update/delete/copy into * Enable CTE tests * Enable further tests * Support for table and column comments * Include CTE test now bug is fixed * Run against nightly * Update pipenv * Work in SQLAlchemy 1.4
1 parent e5afdca commit ee59b16

File tree

9 files changed

+497
-127
lines changed

9 files changed

+497
-127
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
services:
1313
databend:
14-
image: datafuselabs/databend
14+
image: datafuselabs/databend:nightly
1515
env:
1616
QUERY_DEFAULT_USER: databend
1717
QUERY_DEFAULT_PASSWORD: databend

Pipfile.lock

Lines changed: 95 additions & 95 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databend_sqlalchemy/connector.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,13 @@ def escape_item(self, item):
4949
return self.escape_number(item)
5050
elif isinstance(item, timedelta):
5151
return self.escape_string(f"{item.total_seconds()} seconds") + "::interval"
52-
elif isinstance(item, (datetime, date, time, timedelta)):
53-
return self.escape_string(item.strftime("%Y-%m-%d %H:%M:%S"))
52+
elif isinstance(item, time):
53+
# N.B. Date here must match date in DatabendTime.literal_processor - 1970-01-01
54+
return self.escape_string(item.strftime("1970-01-01 %H:%M:%S.%f")) + "::timestamp"
55+
elif isinstance(item, datetime):
56+
return self.escape_string(item.strftime("%Y-%m-%d %H:%M:%S.%f")) + "::timestamp"
57+
elif isinstance(item, date):
58+
return self.escape_string(item.strftime("%Y-%m-%d")) + "::date"
5459
else:
5560
return self.escape_string(item)
5661

databend_sqlalchemy/databend_dialect.py

Lines changed: 191 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import datetime
3131
from types import NoneType
3232

33+
import sqlalchemy.engine.reflection
3334
import sqlalchemy.types as sqltypes
3435
from typing import Any, Dict, Optional, Union
3536
from sqlalchemy import util as sa_util
@@ -44,6 +45,7 @@
4445
Subquery,
4546
)
4647
from sqlalchemy.dialects.postgresql.base import PGCompiler, PGIdentifierPreparer
48+
from sqlalchemy import Table, MetaData, Column
4749
from sqlalchemy.types import (
4850
BIGINT,
4951
INTEGER,
@@ -670,7 +672,7 @@ def process(value):
670672
class DatabendDateTime(sqltypes.DATETIME):
671673
__visit_name__ = "DATETIME"
672674

673-
_reg = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)")
675+
_reg = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)\.(\d+)")
674676

675677
def result_processor(self, dialect, coltype):
676678
def process(value):
@@ -698,7 +700,7 @@ def process(value):
698700
class DatabendTime(sqltypes.TIME):
699701
__visit_name__ = "TIME"
700702

701-
_reg = re.compile(r"(?:\d+)-(?:\d+)-(?:\d+) (\d+):(\d+):(\d+)")
703+
_reg = re.compile(r"(?:\d+)-(?:\d+)-(?:\d+) (\d+):(\d+):(\d+)\.(\d+)")
702704

703705
def result_processor(self, dialect, coltype):
704706
def process(value):
@@ -720,7 +722,7 @@ def literal_processor(self, dialect):
720722
def process(value):
721723
if value is not None:
722724
from_min_value = datetime.datetime.combine(
723-
datetime.date(1000, 1, 1), value
725+
datetime.date(1970, 1, 1), value
724726
)
725727
time_str = from_min_value.isoformat(timespec="microseconds")
726728
return f"'{time_str}'"
@@ -800,6 +802,9 @@ class DatabendIdentifierPreparer(PGIdentifierPreparer):
800802

801803

802804
class DatabendCompiler(PGCompiler):
805+
iscopyintotable: bool = False
806+
iscopyintolocation: bool = False
807+
803808
def get_select_precolumns(self, select, **kw):
804809
# call the base implementation because Databend doesn't support DISTINCT ON
805810
return super(PGCompiler, self).get_select_precolumns(select, **kw)
@@ -971,6 +976,11 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
971976
)
972977

973978
def visit_copy_into(self, copy_into, **kw):
979+
if isinstance(copy_into.target, (TableClause,)):
980+
self.iscopyintotable = True
981+
else:
982+
self.iscopyintolocation = True
983+
974984
target = (
975985
self.preparer.format_table(copy_into.target)
976986
if isinstance(copy_into.target, (TableClause,))
@@ -1090,8 +1100,21 @@ def visit_google_cloud_storage(self, gcs: GoogleCloudStorage, **kw):
10901100
f")"
10911101
)
10921102

1103+
def visit_stage(self, stage, **kw):
1104+
if stage.path:
1105+
return f"@{stage.name}/{stage.path}"
1106+
return f"@{stage.name}"
1107+
10931108

10941109
class DatabendExecutionContext(default.DefaultExecutionContext):
1110+
iscopyintotable = False
1111+
iscopyintolocation = False
1112+
1113+
_copy_input_bytes: Optional[int] = None
1114+
_copy_output_bytes: Optional[int] = None
1115+
_copy_into_table_results: Optional[list[dict]] = None
1116+
_copy_into_location_results: dict = None
1117+
10951118
@sa_util.memoized_property
10961119
def should_autocommit(self):
10971120
return False # No DML supported, never autocommit
@@ -1102,6 +1125,38 @@ def create_server_side_cursor(self):
11021125
def create_default_cursor(self):
11031126
return self._dbapi_connection.cursor()
11041127

1128+
def post_exec(self):
1129+
self.iscopyintotable = getattr(self.compiled, 'iscopyintotable', False)
1130+
self.iscopyintolocation = getattr(self.compiled, 'iscopyintolocation', False)
1131+
if (self.isinsert or self.isupdate or self.isdelete or
1132+
self.iscopyintolocation or self.iscopyintotable):
1133+
result = self.cursor.fetchall()
1134+
if self.iscopyintotable:
1135+
self._copy_into_table_results = [
1136+
{
1137+
'file': row[0],
1138+
'rows_loaded': row[1],
1139+
'errors_seen': row[2],
1140+
'first_error': row[3],
1141+
'first_error_line': row[4],
1142+
} for row in result
1143+
]
1144+
self._rowcount = sum(c['rows_loaded'] for c in self._copy_into_table_results)
1145+
else:
1146+
self._rowcount = result[0][0]
1147+
if self.iscopyintolocation:
1148+
self._copy_into_location_results = {
1149+
'rows_unloaded': result[0][0],
1150+
'input_bytes': result[0][1],
1151+
'output_bytes': result[0][2],
1152+
}
1153+
1154+
def copy_into_table_results(self) -> list[dict]:
1155+
return self._copy_into_table_results
1156+
1157+
def copy_into_location_results(self) -> dict:
1158+
return self._copy_into_location_results
1159+
11051160

11061161
class DatabendTypeCompiler(compiler.GenericTypeCompiler):
11071162
def visit_ARRAY(self, type_, **kw):
@@ -1171,6 +1226,12 @@ def post_create_table(self, table):
11711226
if engine is not None:
11721227
table_opts.append(f" ENGINE={engine}")
11731228

1229+
if table.comment is not None:
1230+
comment = self.sql_compiler.render_literal_value(
1231+
table.comment, sqltypes.String()
1232+
)
1233+
table_opts.append(f" COMMENT={comment}")
1234+
11741235
cluster_keys = db_opts.get("cluster_by")
11751236
if cluster_keys is not None:
11761237
if isinstance(cluster_keys, str):
@@ -1192,6 +1253,37 @@ def post_create_table(self, table):
11921253

11931254
return " ".join(table_opts)
11941255

1256+
def get_column_specification(self, column, **kwargs):
1257+
colspec = super().get_column_specification(column, **kwargs)
1258+
comment = column.comment
1259+
if comment is not None:
1260+
literal = self.sql_compiler.render_literal_value(
1261+
comment, sqltypes.String()
1262+
)
1263+
colspec += " COMMENT " + literal
1264+
1265+
return colspec
1266+
1267+
def visit_set_table_comment(self, create, **kw):
1268+
return "ALTER TABLE %s COMMENT = %s" % (
1269+
self.preparer.format_table(create.element),
1270+
self.sql_compiler.render_literal_value(
1271+
create.element.comment, sqltypes.String()
1272+
),
1273+
)
1274+
1275+
def visit_drop_table_comment(self, create, **kw):
1276+
return "ALTER TABLE %s COMMENT = ''" % (
1277+
self.preparer.format_table(create.element)
1278+
)
1279+
1280+
def visit_set_column_comment(self, create, **kw):
1281+
return "ALTER TABLE %s MODIFY %s %s" % (
1282+
self.preparer.format_table(create.element.table),
1283+
self.preparer.format_column(create.element),
1284+
self.get_column_specification(create.element),
1285+
)
1286+
11951287

11961288
class DatabendDialect(default.DefaultDialect):
11971289
name = "databend"
@@ -1204,7 +1296,7 @@ class DatabendDialect(default.DefaultDialect):
12041296
supports_alter = True
12051297
supports_comments = False
12061298
supports_empty_insert = False
1207-
supports_is_distinct_from = False
1299+
supports_is_distinct_from = True
12081300
supports_multivalues_insert = True
12091301

12101302
supports_statement_cache = False
@@ -1316,7 +1408,7 @@ def has_table(self, connection, table_name, schema=None, **kw):
13161408
def get_columns(self, connection, table_name, schema=None, **kw):
13171409
query = text(
13181410
"""
1319-
select column_name, column_type, is_nullable
1411+
select column_name, column_type, is_nullable, nullif(column_comment, '')
13201412
from information_schema.columns
13211413
where table_name = :table_name
13221414
and table_schema = :schema_name
@@ -1337,6 +1429,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
13371429
"type": self._get_column_type(row[1]),
13381430
"nullable": get_is_nullable(row[2]),
13391431
"default": None,
1432+
"comment": row[3],
13401433
}
13411434
for row in result
13421435
]
@@ -1416,6 +1509,23 @@ def get_table_names(self, connection, schema=None, **kw):
14161509
result = connection.execute(query, dict(schema_name=schema))
14171510
return [row[0] for row in result]
14181511

1512+
@reflection.cache
1513+
def get_temp_table_names(self, connection, schema=None, **kw):
1514+
table_name_query = """
1515+
select name
1516+
from system.temporary_tables
1517+
where database = :schema_name
1518+
"""
1519+
query = text(table_name_query).bindparams(
1520+
bindparam("schema_name", type_=sqltypes.Unicode)
1521+
)
1522+
if schema is None:
1523+
schema = self.default_schema_name
1524+
1525+
result = connection.execute(query, dict(schema_name=schema))
1526+
return [row[0] for row in result]
1527+
1528+
14191529
@reflection.cache
14201530
def get_view_names(self, connection, schema=None, **kw):
14211531
view_name_query = """
@@ -1510,6 +1620,82 @@ def get_table_options(self, connection, table_name, schema=None, **kw):
15101620

15111621
return options
15121622

1623+
@reflection.cache
1624+
def get_table_comment(self, connection, table_name, schema, **kw):
1625+
query_text = """
1626+
SELECT comment
1627+
FROM system.tables
1628+
WHERE database = :schema_name
1629+
and name = :table_name
1630+
"""
1631+
query = text(query_text).bindparams(
1632+
bindparam("table_name", type_=sqltypes.Unicode),
1633+
bindparam("schema_name", type_=sqltypes.Unicode),
1634+
)
1635+
if schema is None:
1636+
schema = self.default_schema_name
1637+
1638+
result = connection.execute(
1639+
query, dict(table_name=table_name, schema_name=schema)
1640+
).one_or_none()
1641+
if not result:
1642+
raise NoSuchTableError(
1643+
f"{self.identifier_preparer.quote_identifier(schema)}."
1644+
f"{self.identifier_preparer.quote_identifier(table_name)}"
1645+
)
1646+
return {'text': result[0]} if result[0] else reflection.ReflectionDefaults.table_comment() if hasattr(reflection, 'ReflectionDefault') else {'text': None}
1647+
1648+
def _prepare_filter_names(self, filter_names):
1649+
if filter_names:
1650+
fn = [name for name in filter_names]
1651+
return True, {"filter_names": fn}
1652+
else:
1653+
return False, {}
1654+
1655+
def get_multi_table_comment(
1656+
self, connection, schema, filter_names, scope, kind, **kw
1657+
):
1658+
meta = MetaData()
1659+
all_tab_comments=Table(
1660+
"tables",
1661+
meta,
1662+
Column("database", VARCHAR, nullable=False),
1663+
Column("name", VARCHAR, nullable=False),
1664+
Column("comment", VARCHAR),
1665+
Column("table_type", VARCHAR),
1666+
schema='system',
1667+
).alias("a_tab_comments")
1668+
1669+
1670+
has_filter_names, params = self._prepare_filter_names(filter_names)
1671+
owner = schema or self.default_schema_name
1672+
1673+
table_types = set()
1674+
if reflection.ObjectKind.TABLE in kind:
1675+
table_types.add('BASE TABLE')
1676+
if reflection.ObjectKind.VIEW in kind:
1677+
table_types.add('VIEW')
1678+
1679+
query = select(
1680+
all_tab_comments.c.name, all_tab_comments.c.comment
1681+
).where(
1682+
all_tab_comments.c.database == owner,
1683+
all_tab_comments.c.table_type.in_(table_types),
1684+
sqlalchemy.true() if reflection.ObjectScope.DEFAULT in scope else sqlalchemy.false(),
1685+
)
1686+
if has_filter_names:
1687+
query = query.where(all_tab_comments.c.name.in_(bindparam("filter_names")))
1688+
1689+
result = connection.execute(query, params)
1690+
default_comment = reflection.ReflectionDefaults.table_comment
1691+
return (
1692+
(
1693+
(schema, table),
1694+
{"text": comment} if comment else default_comment(),
1695+
)
1696+
for table, comment in result
1697+
)
1698+
15131699
def do_rollback(self, dbapi_connection):
15141700
# No transactions
15151701
pass

databend_sqlalchemy/provision.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
from sqlalchemy.testing.provision import create_db
33
from sqlalchemy.testing.provision import drop_db
4-
from sqlalchemy.testing.provision import configure_follower, update_db_opts
4+
from sqlalchemy.testing.provision import configure_follower, update_db_opts, temp_table_keyword_args
55

66

77
@create_db.for_db("databend")
@@ -31,6 +31,10 @@ def _databend_drop_db(cfg, eng, ident):
3131
conn.exec_driver_sql("DROP DATABASE IF EXISTS %s_test_schema_2" % ident)
3232
conn.exec_driver_sql("DROP DATABASE IF EXISTS %s" % ident)
3333

34+
@temp_table_keyword_args.for_db("databend")
35+
def _databend_temp_table_keyword_args(cfg, eng):
36+
return {"prefixes": ["TEMPORARY"]}
37+
3438

3539
@configure_follower.for_db("databend")
3640
def _databend_configure_follower(config, ident):
@@ -39,5 +43,5 @@ def _databend_configure_follower(config, ident):
3943

4044
# Uncomment to debug SQL Statements in tests
4145
# @update_db_opts.for_db("databend")
42-
# def _mssql_update_db_opts(db_url, db_opts):
46+
# def _databend_update_db_opts(db_url, db_opts):
4347
# db_opts["echo"] = True

0 commit comments

Comments
 (0)