Skip to content

Commit a83cb22

Browse files
committed
Compile Merge On clause + Fix spacing
1 parent 60fc1d0 commit a83cb22

File tree

4 files changed

+18
-7
lines changed

4 files changed

+18
-7
lines changed

databend_sqlalchemy/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def mogrify(self, query, parameters):
163163
def execute(self, operation, parameters=None):
164164
"""Prepare and execute a database operation (query or command)."""
165165

166-
# ToDo - Fix this, which is preventing the execution of blank DDL sunch as CREATE INDEX statements which aren't currently supported
166+
# ToDo - Fix this, which is preventing the execution of blank DDL such as CREATE INDEX statements which aren't currently supported
167167
# Seems hard to fix when statements are coming from metadata.create_all()
168168
if not operation:
169169
return

databend_sqlalchemy/databend_dialect.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -906,13 +906,17 @@ def visit_merge(self, merge, **kw):
906906
)
907907
elif isinstance(merge.source, Subquery):
908908
source = merge.source._compiler_dispatch(self, **source_kw)
909+
else:
910+
source = merge.source
911+
912+
merge_on = merge.on._compiler_dispatch(self, **kw)
909913

910914
target_table = self.preparer.format_table(merge.target)
911915
return (
912916
f"MERGE INTO {target_table}\n"
913917
f" USING {source}\n"
914-
f" ON {merge.on}\n"
915-
f"{clauses if clauses else ''}"
918+
f" ON {merge_on}\n"
919+
f" {clauses if clauses else ''}"
916920
)
917921

918922
def visit_when_merge_matched_update(self, merge_matched_update, **kw):
@@ -921,7 +925,7 @@ def visit_when_merge_matched_update(self, merge_matched_update, **kw):
921925
if merge_matched_update.predicate is not None
922926
else ""
923927
)
924-
update_str = f"WHEN MATCHED{case_predicate} THEN\n" f"\tUPDATE"
928+
update_str = f"WHEN MATCHED{case_predicate} THEN\n UPDATE"
925929
if not merge_matched_update.set:
926930
return f"{update_str} *"
927931

@@ -950,7 +954,7 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
950954
if merge_unmatched.predicate is not None
951955
else ""
952956
)
953-
insert_str = f"WHEN NOT MATCHED{case_predicate} THEN\n" f"\tINSERT"
957+
insert_str = f"WHEN NOT MATCHED{case_predicate} THEN\n INSERT"
954958
if not merge_unmatched.set:
955959
return f"{insert_str} *"
956960

databend_sqlalchemy/dml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class Merge(UpdateBase):
7676
__visit_name__ = "merge"
7777
_bind = None
7878

79+
inherit_cache = False
80+
7981
def __init__(self, target, source, on):
8082
if not isinstance(source, (TableClause, Select, Subquery)):
8183
raise Exception(f"Invalid type for merge source: {source}")

tests/test_merge.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
from sqlalchemy import types as sqltypes
1212

1313
# from sqlalchemy.dialects.postgresql import insert
14-
from sqlalchemy.testing import config
14+
from sqlalchemy.testing import config, AssertsCompiledSQL
1515
from sqlalchemy.testing import fixtures
1616
from sqlalchemy.testing.assertions import assert_raises
1717
from sqlalchemy.testing.assertions import eq_
1818

1919
from databend_sqlalchemy import Merge
2020

2121

22-
class MergeIntoTest(fixtures.TablesTest):
22+
class MergeIntoTest(fixtures.TablesTest, AssertsCompiledSQL):
2323
__backend__ = True
2424
run_define_tables = "each"
2525

@@ -170,6 +170,11 @@ def test_when_not_matched_insert(self, connection):
170170
merge = Merge(users, users_xtra, users.c.id == users_xtra.c.id)
171171
merge.when_not_matched_then_insert()
172172

173+
self.assert_compile(
174+
merge,
175+
'MERGE INTO "users" USING (SELECT users_xtra.id AS id, users_xtra.name AS name, users_xtra.login_email AS login_email FROM users_xtra) AS users_xtra ON "users".id = users_xtra.id WHEN NOT MATCHED THEN INSERT *',
176+
)
177+
173178
result = connection.execute(merge)
174179
eq_(
175180
connection.execute(users.select().order_by(users.c.id)).fetchall(),

0 commit comments

Comments
 (0)