
Commit 97b372c

[xy] Create TEMP Snowflake table in data integration destination (mage-ai#4578)
* [xy] Create temp table in Snowflake destination.
* [xy] Pass temp_table flag.
* [xy] Update create command.
* [xy] Do not commit for temp table.
* [xy] Set table_type in write_pandas method.
* [xy] Restrict langchain package version.
* [xy] Restrict langchain_community version.
* [xy] Fix unit test.
1 parent cd91bc4 commit 97b372c

File tree: 5 files changed (+75 / -16 lines)

mage_integrations/mage_integrations/connections/sql/base.py

Lines changed: 24 additions & 2 deletions
@@ -19,15 +19,37 @@ def execute(
         self,
         query_strings: List[str],
         commit=False,
+        connection=None,
     ) -> List[List[tuple]]:
-        connection = self.build_connection()
+        """
+        Execute the provided SQL queries using the given connection, or create a new one if not
+        provided. If a new connection is created, it'll be closed automatically at the end.
+
+        Args:
+            query_strings (List[str]): List of SQL queries to execute.
+            commit (bool, optional): Whether to commit the transaction. Defaults to False.
+            connection (Optional[Any], optional): An existing connection to use. If None, a new
+                connection will be created. Defaults to None.
+
+        Returns:
+            List[List[Tuple]]: A list of result sets for each query executed. Each result set is
+                represented as a list of tuples.
+
+        Raises:
+            Any exceptions raised during the execution process may propagate upward.
+        """
+        new_connection_created = False
+        if connection is None:
+            connection = self.build_connection()
+            new_connection_created = True
 
         data = self.execute_with_connection(connection, query_strings)
 
         if commit:
             connection.commit()
 
-        self.close_connection(connection)
+        if new_connection_created:
+            self.close_connection(connection)
 
         return data
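The new optional connection argument lets a caller run several execute() calls in one database session and keep control of commit and close; execute() only closes connections it created itself. A minimal usage sketch under that assumption (the concrete connection subclass and the queries are illustrative, not from this commit):

# Hypothetical concrete subclass of this SQL base connection class.
wrapper = SomeSQLConnection(...)  # constructor arguments omitted
raw_conn = wrapper.build_connection()  # open one session explicitly

# Both calls share the session, so session-scoped state (e.g. TEMP tables) is visible to both.
wrapper.execute(['CREATE TEMPORARY TABLE t1 (id INT)'], connection=raw_conn)
wrapper.execute(['INSERT INTO t1 VALUES (1)'], commit=True, connection=raw_conn)

# execute() did not create raw_conn, so it does not close it; the caller does.
wrapper.close_connection(raw_conn)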

mage_integrations/mage_integrations/destinations/snowflake/__init__.py

Lines changed: 47 additions & 13 deletions
@@ -60,6 +60,7 @@ def build_create_table_commands(
         schema_name: str,
         stream: str,
         table_name: str,
+        temp_table: bool = False,
         database_name: str = None,
         unique_constraints: List[str] = None,
     ) -> List[str]:
@@ -69,14 +70,15 @@ def build_create_table_commands(
                     convert_column_type,
                     lambda item_type_converted: 'ARRAY',
                 ),
+                column_identifier=self.quote,
                 columns=schema['properties'].keys(),
                 full_table_name=self.full_table_name(
                     database_name,
                     schema_name,
                     table_name,
                 ),
+                create_temporary_table=temp_table,
                 unique_constraints=unique_constraints,
-                column_identifier=self.quote,
                 use_lowercase=self.use_lowercase,
             )
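The temp_table flag is forwarded to the shared command builder as create_temporary_table, so the generated DDL should target a temporary table rather than a permanent one. A hedged sketch of the expected effect, reusing the fixtures from the unit test further down (the exact SQL string depends on build_create_table_command, so the result shown is an assumption):

# Sketch only: 'destination' stands for an instance of the Snowflake destination class.
commands = destination.build_create_table_commands(
    schema=SCHEMA,                      # same fixtures as in test_snowflake.py
    schema_name=SCHEMA_NAME,
    stream=STREAM,
    table_name=f'temp_{TABLE_NAME}',
    temp_table=True,
    database_name=DATABASE_NAME,
)
# Assumed result, mirroring the test's non-temp expectation:
# ['CREATE TEMPORARY TABLE "test_db"."test"."temp_test_table" ("ID" VARCHAR, "_USER" VARIANT)']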

@@ -341,6 +343,8 @@ def write_dataframe_to_table(
         database: str,
         schema: str,
         table: str,
+        connection=None,
+        temp_table: bool = False,
     ) -> List[List[tuple]]:
         """
         Write a Pandas DataFrame to a table in a Snowflake database.
@@ -363,19 +367,31 @@ def write_dataframe_to_table(
         """
         self.logger.info(
             f'write_pandas to: {database}.{schema}.{table}')
-        snowflake_connection = self.build_connection()
-        connection = snowflake_connection.build_connection()
+
+        new_connection_created = False
+        snowflake_connection = None
+        if connection is None:
+            snowflake_connection = self.build_connection()
+            connection = snowflake_connection.build_connection()
+            new_connection_created = True
         if self.disable_double_quotes:
             df.columns = [col.upper() for col in df.columns]
+
+        kwargs = dict(
+            database=database.upper() if self.disable_double_quotes else database,
+            schema=schema.upper() if self.disable_double_quotes else schema,
+            auto_create_table=False,
+        )
+        if temp_table:
+            kwargs['table_type'] = 'temp'
         success, num_chunks, num_rows, output = write_pandas(
             connection,
             df,
             table.upper() if self.disable_double_quotes else table,
-            database=database.upper() if self.disable_double_quotes else database,
-            schema=schema.upper() if self.disable_double_quotes else schema,
-            auto_create_table=False,
+            **kwargs,
         )
-        snowflake_connection.close_connection(connection)
+        if new_connection_created and snowflake_connection is not None:
+            snowflake_connection.close_connection(connection)
         self.logger.info(
             f'write_pandas completed: {success}, {num_chunks} chunks, {num_rows} rows.')
         self.logger.info(f'write_pandas output: {output}')
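write_pandas from snowflake-connector-python accepts a table_type keyword in recent connector versions; the diff forwards 'temp' so the load targets a session-scoped temporary table instead of a permanent one. A standalone, hedged sketch of the same call shape (account details are placeholders and the DataFrame is illustrative):

import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

conn = snowflake.connector.connect(
    account='my_account', user='my_user', password='***',  # placeholders
    database='MY_DB', schema='PUBLIC', warehouse='MY_WH',
)
df = pd.DataFrame({'ID': ['1'], '_USER': ['{"name": "a"}']})

# The temp table is assumed to exist already in this session (auto_create_table=False),
# matching how the destination pre-creates it with a CREATE TEMP TABLE command.
success, num_chunks, num_rows, _ = write_pandas(
    conn,
    df,
    'TEMP_TEST_TABLE',
    database='MY_DB',
    schema='PUBLIC',
    auto_create_table=False,
    table_type='temp',
)
conn.close()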
@@ -456,17 +472,30 @@ def process_queries(
                 schema_name=schema,
                 stream=None,
                 table_name=f'temp_{table}',
+                temp_table=True,
                 database_name=database,
                 unique_constraints=unique_constraints,
             )
-
-            results += self.build_connection().execute(
-                drop_temp_table_command + create_temp_table_command, commit=True)
+            # Run commands in one Snowflake session to leverage TEMP table
+            snowflake_connection = self.build_connection()
+            connection = snowflake_connection.build_connection()
+            results += snowflake_connection.execute(
+                drop_temp_table_command + create_temp_table_command,
+                commit=False,
+                connection=connection,
+            )
 
             # Outputs of write_dataframe_to_table are for temporary table only, thus not added
             # to results
             # results += self.write_dataframe_to_table(df, database, schema, f'temp_{table}')
-            self.write_dataframe_to_table(df, database, schema, f'temp_{table}')
+            self.write_dataframe_to_table(
+                df,
+                database,
+                schema,
+                f'temp_{table}',
+                connection=connection,
+                temp_table=True,
+            )
             self.logger.info(
                 f'write_dataframe_to_table completed to: {full_table_name_temp}')
 
@@ -480,8 +509,13 @@ def process_queries(
 
             self.logger.info(f'Merging {full_table_name_temp} into {full_table_name}')
             self.logger.info(f'Dropping temporary table: {full_table_name_temp}')
-            results += self.build_connection().execute(
-                merge_command + drop_temp_table_command, commit=True)
+            results += snowflake_connection.execute(
+                merge_command + drop_temp_table_command,
+                commit=True,
+                connection=connection,
+            )
+            # Close connection after finishing running all commands
+            snowflake_connection.close_connection(connection)
             self.logger.info(f'Merged and dropped temporary table: {full_table_name_temp}')
         else:
             results += self.write_dataframe_to_table(df, database, schema, table)
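Taken together, the merge path now runs every statement against a single Snowflake session, which is what makes the TEMP table visible across the drop/create, DataFrame load, and merge steps. A condensed restatement of the flow from the diff above (error handling omitted):

# One wrapper, one underlying Snowflake session for the whole merge path.
snowflake_connection = self.build_connection()
connection = snowflake_connection.build_connection()

# 1. Drop and recreate the TEMP table inside this session; no commit needed yet.
results += snowflake_connection.execute(
    drop_temp_table_command + create_temp_table_command,
    commit=False,
    connection=connection,
)

# 2. Load the DataFrame into the TEMP table over the same session.
self.write_dataframe_to_table(
    df, database, schema, f'temp_{table}',
    connection=connection, temp_table=True,
)

# 3. Merge into the target table, drop the TEMP table, commit, then close the session.
results += snowflake_connection.execute(
    merge_command + drop_temp_table_command,
    commit=True,
    connection=connection,
)
snowflake_connection.close_connection(connection)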

mage_integrations/mage_integrations/tests/destinations/snowflake/test_snowflake.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def test_create_table_commands(self):
             SCHEMA_NAME,
             STREAM,
             TABLE_NAME,
-            DATABASE_NAME)
+            database_name=DATABASE_NAME)
         self.assertEqual(
             table_commands,
             ['CREATE TABLE "test_db"."test"."test_table" ("ID" VARCHAR, "_USER" VARIANT)']

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ google-cloud-storage~=2.5.0
 gspread==5.7.2
 kubernetes>=28.1.0
 langchain>=0.0.222; python_version >= '3.8'
+langchain_community<0.0.20
 mysql-connector-python~=8.2.0
 openai>=0.27.8, <1.0.0
 opentelemetry-exporter-prometheus==0.43b0

setup.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@ def readme():
     'ai': [
         'astor>=0.8.1',
         'langchain>=0.0.222',
+        'langchain_community<0.0.20',
         'openai>=0.27.8, <1.0.0',
     ],
     'azure': [
@@ -166,6 +167,7 @@ def readme():
         'kafka-python==2.0.2',
         'kubernetes>=28.1.0',
         'langchain>=0.0.222',
+        'langchain_community<0.0.20',
         'ldap3==2.9.1',
         'nats-py==2.6.0',
         'nkeys~=0.1.0',
