@@ -60,6 +60,7 @@ def build_create_table_commands(
60
60
schema_name : str ,
61
61
stream : str ,
62
62
table_name : str ,
63
+ temp_table : bool = False ,
63
64
database_name : str = None ,
64
65
unique_constraints : List [str ] = None ,
65
66
) -> List [str ]:
@@ -69,14 +70,15 @@ def build_create_table_commands(
69
70
convert_column_type ,
70
71
lambda item_type_converted : 'ARRAY' ,
71
72
),
73
+ column_identifier = self .quote ,
72
74
columns = schema ['properties' ].keys (),
73
75
full_table_name = self .full_table_name (
74
76
database_name ,
75
77
schema_name ,
76
78
table_name ,
77
79
),
80
+ create_temporary_table = temp_table ,
78
81
unique_constraints = unique_constraints ,
79
- column_identifier = self .quote ,
80
82
use_lowercase = self .use_lowercase ,
81
83
)
82
84
@@ -341,6 +343,8 @@ def write_dataframe_to_table(
341
343
database : str ,
342
344
schema : str ,
343
345
table : str ,
346
+ connection = None ,
347
+ temp_table : bool = False ,
344
348
) -> List [List [tuple ]]:
345
349
"""
346
350
Write a Pandas DataFrame to a table in a Snowflake database.
@@ -363,19 +367,31 @@ def write_dataframe_to_table(
363
367
"""
364
368
self .logger .info (
365
369
f'write_pandas to: { database } .{ schema } .{ table } ' )
366
- snowflake_connection = self .build_connection ()
367
- connection = snowflake_connection .build_connection ()
370
+
371
+ new_connection_created = False
372
+ snowflake_connection = None
373
+ if connection is None :
374
+ snowflake_connection = self .build_connection ()
375
+ connection = snowflake_connection .build_connection ()
376
+ new_connection_created
368
377
if self .disable_double_quotes :
369
378
df .columns = [col .upper () for col in df .columns ]
379
+
380
+ kwargs = dict (
381
+ database = database .upper () if self .disable_double_quotes else database ,
382
+ schema = schema .upper () if self .disable_double_quotes else schema ,
383
+ auto_create_table = False ,
384
+ )
385
+ if temp_table :
386
+ kwargs ['table_type' ] = 'temp'
370
387
success , num_chunks , num_rows , output = write_pandas (
371
388
connection ,
372
389
df ,
373
390
table .upper () if self .disable_double_quotes else table ,
374
- database = database .upper () if self .disable_double_quotes else database ,
375
- schema = schema .upper () if self .disable_double_quotes else schema ,
376
- auto_create_table = False ,
391
+ ** kwargs ,
377
392
)
378
- snowflake_connection .close_connection (connection )
393
+ if new_connection_created and snowflake_connection is not None :
394
+ snowflake_connection .close_connection (connection )
379
395
self .logger .info (
380
396
f'write_pandas completed: { success } , { num_chunks } chunks, { num_rows } rows.' )
381
397
self .logger .info (f'write_pandas output: { output } ' )
@@ -456,17 +472,30 @@ def process_queries(
456
472
schema_name = schema ,
457
473
stream = None ,
458
474
table_name = f'temp_{ table } ' ,
475
+ temp_table = True ,
459
476
database_name = database ,
460
477
unique_constraints = unique_constraints ,
461
478
)
462
-
463
- results += self .build_connection ().execute (
464
- drop_temp_table_command + create_temp_table_command , commit = True )
479
+ # Run commands in one Snowflake session to leverage TEMP table
480
+ snowflake_connection = self .build_connection ()
481
+ connection = snowflake_connection .build_connection ()
482
+ results += snowflake_connection .execute (
483
+ drop_temp_table_command + create_temp_table_command ,
484
+ commit = False ,
485
+ connection = connection ,
486
+ )
465
487
466
488
# Outputs of write_dataframe_to_table are for temporary table only, thus not added
467
489
# to results
468
490
# results += self.write_dataframe_to_table(df, database, schema, f'temp_{table}')
469
- self .write_dataframe_to_table (df , database , schema , f'temp_{ table } ' )
491
+ self .write_dataframe_to_table (
492
+ df ,
493
+ database ,
494
+ schema ,
495
+ f'temp_{ table } ' ,
496
+ connection = connection ,
497
+ temp_table = True ,
498
+ )
470
499
self .logger .info (
471
500
f'write_dataframe_to_table completed to: { full_table_name_temp } ' )
472
501
@@ -480,8 +509,13 @@ def process_queries(
480
509
481
510
self .logger .info (f'Merging { full_table_name_temp } into { full_table_name } ' )
482
511
self .logger .info (f'Dropping temporary table: { full_table_name_temp } ' )
483
- results += self .build_connection ().execute (
484
- merge_command + drop_temp_table_command , commit = True )
512
+ results += snowflake_connection .execute (
513
+ merge_command + drop_temp_table_command ,
514
+ commit = True ,
515
+ connection = connection ,
516
+ )
517
+ # Close connection after finishing running all commands
518
+ snowflake_connection .close_connection (connection )
485
519
self .logger .info (f'Merged and dropped temporary table: { full_table_name_temp } ' )
486
520
else :
487
521
results += self .write_dataframe_to_table (df , database , schema , table )
0 commit comments