30
30
import datetime
31
31
from types import NoneType
32
32
33
+ import sqlalchemy .engine .reflection
33
34
import sqlalchemy .types as sqltypes
34
35
from typing import Any , Dict , Optional , Union
35
36
from sqlalchemy import util as sa_util
44
45
Subquery ,
45
46
)
46
47
from sqlalchemy .dialects .postgresql .base import PGCompiler , PGIdentifierPreparer
48
+ from sqlalchemy import Table , MetaData , Column
47
49
from sqlalchemy .types import (
48
50
BIGINT ,
49
51
INTEGER ,
@@ -670,7 +672,7 @@ def process(value):
670
672
class DatabendDateTime (sqltypes .DATETIME ):
671
673
__visit_name__ = "DATETIME"
672
674
673
- _reg = re .compile (r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)" )
675
+ _reg = re .compile (r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)\.(\d+) " )
674
676
675
677
def result_processor (self , dialect , coltype ):
676
678
def process (value ):
@@ -698,7 +700,7 @@ def process(value):
698
700
class DatabendTime (sqltypes .TIME ):
699
701
__visit_name__ = "TIME"
700
702
701
- _reg = re .compile (r"(?:\d+)-(?:\d+)-(?:\d+) (\d+):(\d+):(\d+)" )
703
+ _reg = re .compile (r"(?:\d+)-(?:\d+)-(?:\d+) (\d+):(\d+):(\d+)\.(\d+) " )
702
704
703
705
def result_processor (self , dialect , coltype ):
704
706
def process (value ):
@@ -720,7 +722,7 @@ def literal_processor(self, dialect):
720
722
def process (value ):
721
723
if value is not None :
722
724
from_min_value = datetime .datetime .combine (
723
- datetime .date (1000 , 1 , 1 ), value
725
+ datetime .date (1970 , 1 , 1 ), value
724
726
)
725
727
time_str = from_min_value .isoformat (timespec = "microseconds" )
726
728
return f"'{ time_str } '"
@@ -800,6 +802,9 @@ class DatabendIdentifierPreparer(PGIdentifierPreparer):
800
802
801
803
802
804
class DatabendCompiler (PGCompiler ):
805
+ iscopyintotable : bool = False
806
+ iscopyintolocation : bool = False
807
+
803
808
def get_select_precolumns (self , select , ** kw ):
804
809
# call the base implementation because Databend doesn't support DISTINCT ON
805
810
return super (PGCompiler , self ).get_select_precolumns (select , ** kw )
@@ -971,6 +976,11 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
971
976
)
972
977
973
978
def visit_copy_into (self , copy_into , ** kw ):
979
+ if isinstance (copy_into .target , (TableClause ,)):
980
+ self .iscopyintotable = True
981
+ else :
982
+ self .iscopyintolocation = True
983
+
974
984
target = (
975
985
self .preparer .format_table (copy_into .target )
976
986
if isinstance (copy_into .target , (TableClause ,))
@@ -1090,8 +1100,21 @@ def visit_google_cloud_storage(self, gcs: GoogleCloudStorage, **kw):
1090
1100
f")"
1091
1101
)
1092
1102
1103
+ def visit_stage (self , stage , ** kw ):
1104
+ if stage .path :
1105
+ return f"@{ stage .name } /{ stage .path } "
1106
+ return f"@{ stage .name } "
1107
+
1093
1108
1094
1109
class DatabendExecutionContext (default .DefaultExecutionContext ):
1110
+ iscopyintotable = False
1111
+ iscopyintolocation = False
1112
+
1113
+ _copy_input_bytes : Optional [int ] = None
1114
+ _copy_output_bytes : Optional [int ] = None
1115
+ _copy_into_table_results : Optional [list [dict ]] = None
1116
+ _copy_into_location_results : dict = None
1117
+
1095
1118
@sa_util .memoized_property
1096
1119
def should_autocommit (self ):
1097
1120
return False # No DML supported, never autocommit
@@ -1102,6 +1125,38 @@ def create_server_side_cursor(self):
1102
1125
def create_default_cursor (self ):
1103
1126
return self ._dbapi_connection .cursor ()
1104
1127
1128
+ def post_exec (self ):
1129
+ self .iscopyintotable = getattr (self .compiled , 'iscopyintotable' , False )
1130
+ self .iscopyintolocation = getattr (self .compiled , 'iscopyintolocation' , False )
1131
+ if (self .isinsert or self .isupdate or self .isdelete or
1132
+ self .iscopyintolocation or self .iscopyintotable ):
1133
+ result = self .cursor .fetchall ()
1134
+ if self .iscopyintotable :
1135
+ self ._copy_into_table_results = [
1136
+ {
1137
+ 'file' : row [0 ],
1138
+ 'rows_loaded' : row [1 ],
1139
+ 'errors_seen' : row [2 ],
1140
+ 'first_error' : row [3 ],
1141
+ 'first_error_line' : row [4 ],
1142
+ } for row in result
1143
+ ]
1144
+ self ._rowcount = sum (c ['rows_loaded' ] for c in self ._copy_into_table_results )
1145
+ else :
1146
+ self ._rowcount = result [0 ][0 ]
1147
+ if self .iscopyintolocation :
1148
+ self ._copy_into_location_results = {
1149
+ 'rows_unloaded' : result [0 ][0 ],
1150
+ 'input_bytes' : result [0 ][1 ],
1151
+ 'output_bytes' : result [0 ][2 ],
1152
+ }
1153
+
1154
+ def copy_into_table_results (self ) -> list [dict ]:
1155
+ return self ._copy_into_table_results
1156
+
1157
+ def copy_into_location_results (self ) -> dict :
1158
+ return self ._copy_into_location_results
1159
+
1105
1160
1106
1161
class DatabendTypeCompiler (compiler .GenericTypeCompiler ):
1107
1162
def visit_ARRAY (self , type_ , ** kw ):
@@ -1171,6 +1226,12 @@ def post_create_table(self, table):
1171
1226
if engine is not None :
1172
1227
table_opts .append (f" ENGINE={ engine } " )
1173
1228
1229
+ if table .comment is not None :
1230
+ comment = self .sql_compiler .render_literal_value (
1231
+ table .comment , sqltypes .String ()
1232
+ )
1233
+ table_opts .append (f" COMMENT={ comment } " )
1234
+
1174
1235
cluster_keys = db_opts .get ("cluster_by" )
1175
1236
if cluster_keys is not None :
1176
1237
if isinstance (cluster_keys , str ):
@@ -1192,6 +1253,37 @@ def post_create_table(self, table):
1192
1253
1193
1254
return " " .join (table_opts )
1194
1255
1256
+ def get_column_specification (self , column , ** kwargs ):
1257
+ colspec = super ().get_column_specification (column , ** kwargs )
1258
+ comment = column .comment
1259
+ if comment is not None :
1260
+ literal = self .sql_compiler .render_literal_value (
1261
+ comment , sqltypes .String ()
1262
+ )
1263
+ colspec += " COMMENT " + literal
1264
+
1265
+ return colspec
1266
+
1267
+ def visit_set_table_comment (self , create , ** kw ):
1268
+ return "ALTER TABLE %s COMMENT = %s" % (
1269
+ self .preparer .format_table (create .element ),
1270
+ self .sql_compiler .render_literal_value (
1271
+ create .element .comment , sqltypes .String ()
1272
+ ),
1273
+ )
1274
+
1275
+ def visit_drop_table_comment (self , create , ** kw ):
1276
+ return "ALTER TABLE %s COMMENT = ''" % (
1277
+ self .preparer .format_table (create .element )
1278
+ )
1279
+
1280
+ def visit_set_column_comment (self , create , ** kw ):
1281
+ return "ALTER TABLE %s MODIFY %s %s" % (
1282
+ self .preparer .format_table (create .element .table ),
1283
+ self .preparer .format_column (create .element ),
1284
+ self .get_column_specification (create .element ),
1285
+ )
1286
+
1195
1287
1196
1288
class DatabendDialect (default .DefaultDialect ):
1197
1289
name = "databend"
@@ -1204,7 +1296,7 @@ class DatabendDialect(default.DefaultDialect):
1204
1296
supports_alter = True
1205
1297
supports_comments = False
1206
1298
supports_empty_insert = False
1207
- supports_is_distinct_from = False
1299
+ supports_is_distinct_from = True
1208
1300
supports_multivalues_insert = True
1209
1301
1210
1302
supports_statement_cache = False
@@ -1316,7 +1408,7 @@ def has_table(self, connection, table_name, schema=None, **kw):
1316
1408
def get_columns (self , connection , table_name , schema = None , ** kw ):
1317
1409
query = text (
1318
1410
"""
1319
- select column_name, column_type, is_nullable
1411
+ select column_name, column_type, is_nullable, nullif(column_comment, '')
1320
1412
from information_schema.columns
1321
1413
where table_name = :table_name
1322
1414
and table_schema = :schema_name
@@ -1337,6 +1429,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
1337
1429
"type" : self ._get_column_type (row [1 ]),
1338
1430
"nullable" : get_is_nullable (row [2 ]),
1339
1431
"default" : None ,
1432
+ "comment" : row [3 ],
1340
1433
}
1341
1434
for row in result
1342
1435
]
@@ -1416,6 +1509,23 @@ def get_table_names(self, connection, schema=None, **kw):
1416
1509
result = connection .execute (query , dict (schema_name = schema ))
1417
1510
return [row [0 ] for row in result ]
1418
1511
1512
+ @reflection .cache
1513
+ def get_temp_table_names (self , connection , schema = None , ** kw ):
1514
+ table_name_query = """
1515
+ select name
1516
+ from system.temporary_tables
1517
+ where database = :schema_name
1518
+ """
1519
+ query = text (table_name_query ).bindparams (
1520
+ bindparam ("schema_name" , type_ = sqltypes .Unicode )
1521
+ )
1522
+ if schema is None :
1523
+ schema = self .default_schema_name
1524
+
1525
+ result = connection .execute (query , dict (schema_name = schema ))
1526
+ return [row [0 ] for row in result ]
1527
+
1528
+
1419
1529
@reflection .cache
1420
1530
def get_view_names (self , connection , schema = None , ** kw ):
1421
1531
view_name_query = """
@@ -1510,6 +1620,82 @@ def get_table_options(self, connection, table_name, schema=None, **kw):
1510
1620
1511
1621
return options
1512
1622
1623
+ @reflection .cache
1624
+ def get_table_comment (self , connection , table_name , schema , ** kw ):
1625
+ query_text = """
1626
+ SELECT comment
1627
+ FROM system.tables
1628
+ WHERE database = :schema_name
1629
+ and name = :table_name
1630
+ """
1631
+ query = text (query_text ).bindparams (
1632
+ bindparam ("table_name" , type_ = sqltypes .Unicode ),
1633
+ bindparam ("schema_name" , type_ = sqltypes .Unicode ),
1634
+ )
1635
+ if schema is None :
1636
+ schema = self .default_schema_name
1637
+
1638
+ result = connection .execute (
1639
+ query , dict (table_name = table_name , schema_name = schema )
1640
+ ).one_or_none ()
1641
+ if not result :
1642
+ raise NoSuchTableError (
1643
+ f"{ self .identifier_preparer .quote_identifier (schema )} ."
1644
+ f"{ self .identifier_preparer .quote_identifier (table_name )} "
1645
+ )
1646
+ return {'text' : result [0 ]} if result [0 ] else reflection .ReflectionDefaults .table_comment () if hasattr (reflection , 'ReflectionDefault' ) else {'text' : None }
1647
+
1648
+ def _prepare_filter_names (self , filter_names ):
1649
+ if filter_names :
1650
+ fn = [name for name in filter_names ]
1651
+ return True , {"filter_names" : fn }
1652
+ else :
1653
+ return False , {}
1654
+
1655
+ def get_multi_table_comment (
1656
+ self , connection , schema , filter_names , scope , kind , ** kw
1657
+ ):
1658
+ meta = MetaData ()
1659
+ all_tab_comments = Table (
1660
+ "tables" ,
1661
+ meta ,
1662
+ Column ("database" , VARCHAR , nullable = False ),
1663
+ Column ("name" , VARCHAR , nullable = False ),
1664
+ Column ("comment" , VARCHAR ),
1665
+ Column ("table_type" , VARCHAR ),
1666
+ schema = 'system' ,
1667
+ ).alias ("a_tab_comments" )
1668
+
1669
+
1670
+ has_filter_names , params = self ._prepare_filter_names (filter_names )
1671
+ owner = schema or self .default_schema_name
1672
+
1673
+ table_types = set ()
1674
+ if reflection .ObjectKind .TABLE in kind :
1675
+ table_types .add ('BASE TABLE' )
1676
+ if reflection .ObjectKind .VIEW in kind :
1677
+ table_types .add ('VIEW' )
1678
+
1679
+ query = select (
1680
+ all_tab_comments .c .name , all_tab_comments .c .comment
1681
+ ).where (
1682
+ all_tab_comments .c .database == owner ,
1683
+ all_tab_comments .c .table_type .in_ (table_types ),
1684
+ sqlalchemy .true () if reflection .ObjectScope .DEFAULT in scope else sqlalchemy .false (),
1685
+ )
1686
+ if has_filter_names :
1687
+ query = query .where (all_tab_comments .c .name .in_ (bindparam ("filter_names" )))
1688
+
1689
+ result = connection .execute (query , params )
1690
+ default_comment = reflection .ReflectionDefaults .table_comment
1691
+ return (
1692
+ (
1693
+ (schema , table ),
1694
+ {"text" : comment } if comment else default_comment (),
1695
+ )
1696
+ for table , comment in result
1697
+ )
1698
+
1513
1699
def do_rollback (self , dbapi_connection ):
1514
1700
# No transactions
1515
1701
pass
0 commit comments