1
+ import re
1
2
import typing
2
3
from collections import deque
3
4
from itertools import count , islice
@@ -185,7 +186,7 @@ def executemany(self: "Cursor", operation, param_sets) -> "Cursor":
185
186
self ._row_count = - 1 if - 1 in rowcounts else sum (rowcounts )
186
187
return self
187
188
188
- def fetchone (self : "Cursor" ) -> typing .Optional ["Cursor" ]:
189
+ def fetchone (self : "Cursor" ) -> typing .Optional [typing . List ]:
189
190
"""Fetch the next row of a query result set.
190
191
191
192
This method is part of the `DBAPI 2.0 specification
@@ -196,7 +197,7 @@ def fetchone(self: "Cursor") -> typing.Optional["Cursor"]:
196
197
are available.
197
198
"""
198
199
try :
199
- return typing . cast ( "Cursor" , next (self ) )
200
+ return next (self )
200
201
except StopIteration :
201
202
return None
202
203
except TypeError :
@@ -271,7 +272,7 @@ def setoutputsize(self: "Cursor", size, column=None):
271
272
"""
272
273
pass
273
274
274
- def __next__ (self : "Cursor" ):
275
+ def __next__ (self : "Cursor" ) -> typing . List :
275
276
try :
276
277
return self ._cached_rows .popleft ()
277
278
except IndexError :
@@ -311,16 +312,48 @@ def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> typing.
311
312
return None
312
313
return pandas .DataFrame (result , columns = columns )
313
314
315
def __is_valid_table(self: "Cursor", table: str) -> bool:
    """Check whether ``table`` is listed in ``information_schema.tables``.

    ``table`` is either a bare table name or ``schema.table``. A value that
    contains more than one dot is rejected immediately, without running a
    query.

    The lookup is issued through ``self.execute`` with ``?`` placeholders, so
    ``self.paramstyle`` is switched to ``"qmark"`` for the duration of the
    call and put back afterwards, even when ``execute`` raises. Any exception
    raised by ``execute`` propagates to the caller unchanged.

    :param table: table name, optionally prefixed with a schema and a dot.
    :returns: ``True`` when the lookup yields a row whose first column equals
        1, ``False`` otherwise.
    """
    split_table_name: typing.List[str] = table.split(".")

    if len(split_table_name) > 2:
        return False

    q: str = "select 1 from information_schema.tables where table_name = ?"

    # The query text uses "?" markers, so qmark must be active while it runs.
    temp = self.paramstyle
    self.paramstyle = "qmark"

    try:
        if len(split_table_name) == 2:
            # "schema.table": the table name is bound first, then the schema.
            q += " and table_schema = ?"
            self.execute(q, (split_table_name[1], split_table_name[0]))
        else:
            self.execute(q, (split_table_name[0],))
    finally:
        # reset paramstyle to its original value
        self.paramstyle = temp

    result = self.fetchone()

    return result is not None and result[0] == 1
341
+
314
342
def write_dataframe (self : "Cursor" , df : "pandas.DataFrame" , table : str ) -> None :
315
343
"""write same structure dataframe into Redshift database"""
316
344
try :
317
345
import pandas
318
346
except ModuleNotFoundError :
319
347
raise ModuleNotFoundError (MISSING_MODULE_ERROR_MSG .format (module = "pandas" ))
320
348
349
+ if not self .__is_valid_table (table ):
350
+ raise InterfaceError ("Invalid table name passed to write_dataframe: {}" .format (table ))
351
+ sanitized_table_name : str = self .__sanitize_str (table )
321
352
arrays : "numpy.ndarray" = df .values
322
353
placeholder : str = ", " .join (["%s" ] * len (arrays [0 ]))
323
- sql : str = "insert into {table} values ({placeholder})" .format (table = table , placeholder = placeholder )
354
+ sql : str = "insert into {table} values ({placeholder})" .format (
355
+ table = sanitized_table_name , placeholder = placeholder
356
+ )
324
357
if len (arrays ) == 1 :
325
358
self .execute (sql , arrays [0 ])
326
359
elif len (arrays ) > 1 :
@@ -361,16 +394,33 @@ def get_procedures(
361
394
" LEFT JOIN pg_catalog.pg_namespace pn ON (c.relnamespace=pn.oid AND pn.nspname='pg_catalog') "
362
395
" WHERE p.pronamespace=n.oid "
363
396
)
397
+ query_args : typing .List [str ] = []
364
398
if schema_pattern is not None :
365
- sql += " AND n.nspname LIKE {schema}" .format (schema = self .__escape_quotes (schema_pattern ))
399
+ sql += " AND n.nspname LIKE ?"
400
+ query_args .append (self .__sanitize_str (schema_pattern ))
366
401
else :
367
402
sql += "and pg_function_is_visible(p.prooid)"
368
403
369
404
if procedure_name_pattern is not None :
370
- sql += " AND p.proname LIKE {procedure}" .format (procedure = self .__escape_quotes (procedure_name_pattern ))
405
+ sql += " AND p.proname LIKE ?"
406
+ query_args .append (self .__sanitize_str (procedure_name_pattern ))
371
407
sql += " ORDER BY PROCEDURE_SCHEM, PROCEDURE_NAME, p.prooid::text "
372
408
373
- self .execute (sql )
409
+ if len (query_args ) > 0 :
410
+ # temporarily use qmark paramstyle
411
+ temp = self .paramstyle
412
+ self .paramstyle = "qmark"
413
+
414
+ try :
415
+ self .execute (sql , tuple (query_args ))
416
+ except :
417
+ raise
418
+ finally :
419
+ # reset the original value of paramstyle
420
+ self .paramstyle = temp
421
+ else :
422
+ self .execute (sql )
423
+
374
424
procedures : tuple = self .fetchall ()
375
425
return procedures
376
426
@@ -383,11 +433,25 @@ def get_schemas(
383
433
" OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_' "
384
434
" OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_')) "
385
435
)
436
+ query_args : typing .List [str ] = []
386
437
if schema_pattern is not None :
387
- sql += " AND nspname LIKE {schema}" .format (schema = self .__escape_quotes (schema_pattern ))
438
+ sql += " AND nspname LIKE ?"
439
+ query_args .append (self .__sanitize_str (schema_pattern ))
388
440
sql += " ORDER BY TABLE_SCHEM"
389
441
390
- self .execute (sql )
442
+ if len (query_args ) == 1 :
443
+ # temporarily use qmark paramstyle
444
+ temp = self .paramstyle
445
+ self .paramstyle = "qmark"
446
+ try :
447
+ self .execute (sql , tuple (query_args ))
448
+ except :
449
+ raise
450
+ finally :
451
+ self .paramstyle = temp
452
+ else :
453
+ self .execute (sql )
454
+
391
455
schemas : tuple = self .fetchall ()
392
456
return schemas
393
457
@@ -418,13 +482,28 @@ def get_primary_keys(
418
482
"i.indisprimary AND "
419
483
"ct.relnamespace = n.oid "
420
484
)
485
+ query_args : typing .List [str ] = []
421
486
if schema is not None :
422
- sql += " AND n.nspname = {schema}" .format (schema = self .__escape_quotes (schema ))
487
+ sql += " AND n.nspname = ?"
488
+ query_args .append (self .__sanitize_str (schema ))
423
489
if table is not None :
424
- sql += " AND ct.relname = {table}" .format (table = self .__escape_quotes (table ))
490
+ sql += " AND ct.relname = ?"
491
+ query_args .append (self .__sanitize_str (table ))
425
492
426
493
sql += " ORDER BY table_name, pk_name, key_seq"
427
- self .execute (sql )
494
+
495
+ if len (query_args ) > 0 :
496
+ # temporarily use qmark paramstyle
497
+ temp = self .paramstyle
498
+ self .paramstyle = "qmark"
499
+ try :
500
+ self .execute (sql , tuple (query_args ))
501
+ except :
502
+ raise
503
+ finally :
504
+ self .paramstyle = temp
505
+ else :
506
+ self .execute (sql )
428
507
keys : tuple = self .fetchall ()
429
508
return keys
430
509
@@ -437,15 +516,30 @@ def get_tables(
437
516
) -> tuple :
438
517
"""Returns the unique public tables which are user-defined within the system"""
439
518
sql : str = ""
519
+ sql_args : typing .Tuple [str , ...] = tuple ()
440
520
schema_pattern_type : str = self .__schema_pattern_match (schema_pattern )
441
521
if schema_pattern_type == "LOCAL_SCHEMA_QUERY" :
442
- sql = self .__build_local_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
522
+ sql , sql_args = self .__build_local_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
443
523
elif schema_pattern_type == "NO_SCHEMA_UNIVERSAL_QUERY" :
444
- sql = self .__build_universal_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
524
+ sql , sql_args = self .__build_universal_schema_tables_query (
525
+ catalog , schema_pattern , table_name_pattern , types
526
+ )
445
527
elif schema_pattern_type == "EXTERNAL_SCHEMA_QUERY" :
446
- sql = self .__build_external_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
528
+ sql , sql_args = self .__build_external_schema_tables_query (
529
+ catalog , schema_pattern , table_name_pattern , types
530
+ )
447
531
448
- self .execute (sql )
532
+ if len (sql_args ) > 0 :
533
+ temp = self .paramstyle
534
+ self .paramstyle = "qmark"
535
+ try :
536
+ self .execute (sql , sql_args )
537
+ except :
538
+ raise
539
+ finally :
540
+ self .paramstyle = temp
541
+ else :
542
+ self .execute (sql )
449
543
tables : tuple = self .fetchall ()
450
544
return tables
451
545
@@ -455,7 +549,7 @@ def __build_local_schema_tables_query(
455
549
schema_pattern : typing .Optional [str ],
456
550
table_name_pattern : typing .Optional [str ],
457
551
types : list ,
458
- ) -> str :
552
+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
459
553
sql : str = (
460
554
"SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT, n.nspname AS TABLE_SCHEM, c.relname AS TABLE_NAME, "
461
555
" CASE n.nspname ~ '^pg_' OR n.nspname = 'information_schema' "
@@ -502,32 +596,41 @@ def __build_local_schema_tables_query(
502
596
" LEFT JOIN pg_catalog.pg_namespace dn ON (dn.oid=dc.relnamespace AND dn.nspname='pg_catalog') "
503
597
" WHERE c.relnamespace = n.oid "
504
598
)
505
- filter_clause : str = self .__get_table_filter_clause (
599
+ filter_clause , filter_args = self .__get_table_filter_clause (
506
600
catalog , schema_pattern , table_name_pattern , types , "LOCAL_SCHEMA_QUERY"
507
601
)
508
602
orderby : str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
509
603
510
- return sql + filter_clause + orderby
604
+ return sql + filter_clause + orderby , filter_args
511
605
512
606
def __get_table_filter_clause (
513
607
self : "Cursor" ,
514
608
catalog : typing .Optional [str ],
515
609
schema_pattern : typing .Optional [str ],
516
610
table_name_pattern : typing .Optional [str ],
517
- types : list ,
611
+ types : typing . List [ str ] ,
518
612
schema_pattern_type : str ,
519
- ) -> str :
613
+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
520
614
filter_clause : str = ""
521
615
use_schemas : str = "SCHEMAS"
616
+ query_args : typing .List [str ] = []
522
617
if schema_pattern is not None :
523
- filter_clause += " AND TABLE_SCHEM LIKE {schema}" .format (schema = self .__escape_quotes (schema_pattern ))
618
+ filter_clause += " AND TABLE_SCHEM LIKE ?"
619
+ query_args .append (self .__sanitize_str (schema_pattern ))
524
620
if table_name_pattern is not None :
525
- filter_clause += " AND TABLE_NAME LIKE {table}" .format (table = self .__escape_quotes (table_name_pattern ))
621
+ filter_clause += " AND TABLE_NAME LIKE ?"
622
+ query_args .append (self .__sanitize_str (table_name_pattern ))
526
623
if len (types ) > 0 :
527
624
if schema_pattern_type == "LOCAL_SCHEMA_QUERY" :
528
625
filter_clause += " AND (false "
529
626
orclause : str = ""
530
627
for type in types :
628
+ if type not in table_type_clauses .keys ():
629
+ raise InterfaceError (
630
+ "Invalid type: {} provided. types may only contain: {}" .format (
631
+ type , table_type_clauses .keys ()
632
+ )
633
+ )
531
634
clauses = table_type_clauses [type ]
532
635
if len (clauses ) > 0 :
533
636
cluase = clauses [use_schemas ]
@@ -538,21 +641,28 @@ def __get_table_filter_clause(
538
641
filter_clause += " AND TABLE_TYPE IN ( "
539
642
length = len (types )
540
643
for type in types :
541
- filter_clause += self .__escape_quotes (type )
644
+ if type not in table_type_clauses .keys ():
645
+ raise InterfaceError (
646
+ "Invalid type: {} provided. types may only contain: {}" .format (
647
+ type , table_type_clauses .keys ()
648
+ )
649
+ )
650
+ filter_clause += "?"
651
+ query_args .append (self .__sanitize_str (type ))
542
652
length -= 1
543
653
if length > 0 :
544
654
filter_clause += ", "
545
655
filter_clause += ") "
546
656
547
- return filter_clause
657
+ return filter_clause , tuple ( query_args )
548
658
549
659
def __build_universal_schema_tables_query (
550
660
self : "Cursor" ,
551
661
catalog : typing .Optional [str ],
552
662
schema_pattern : typing .Optional [str ],
553
663
table_name_pattern : typing .Optional [str ],
554
664
types : list ,
555
- ) -> str :
665
+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
556
666
sql : str = (
557
667
"SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
558
668
" table_schema AS TABLE_SCHEM,"
@@ -583,20 +693,20 @@ def __build_universal_schema_tables_query(
583
693
" FROM svv_tables)"
584
694
" WHERE true "
585
695
)
586
- filter_clause : str = self .__get_table_filter_clause (
696
+ filter_clause , filter_args = self .__get_table_filter_clause (
587
697
catalog , schema_pattern , table_name_pattern , types , "NO_SCHEMA_UNIVERSAL_QUERY"
588
698
)
589
699
orderby : str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
590
700
sql += filter_clause + orderby
591
- return sql
701
+ return sql , filter_args
592
702
593
703
def __build_external_schema_tables_query (
594
704
self : "Cursor" ,
595
705
catalog : typing .Optional [str ],
596
706
schema_pattern : typing .Optional [str ],
597
707
table_name_pattern : typing .Optional [str ],
598
708
types : list ,
599
- ) -> str :
709
+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
600
710
sql : str = (
601
711
"SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
602
712
" schemaname AS table_schem,"
@@ -611,12 +721,12 @@ def __build_external_schema_tables_query(
611
721
" FROM svv_external_tables)"
612
722
" WHERE true "
613
723
)
614
- filter_clause : str = self .__get_table_filter_clause (
724
+ filter_clause , filter_args = self .__get_table_filter_clause (
615
725
catalog , schema_pattern , table_name_pattern , types , "EXTERNAL_SCHEMA_QUERY"
616
726
)
617
727
orderby : str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
618
728
sql += filter_clause + orderby
619
- return sql
729
+ return sql , filter_args
620
730
621
731
def get_columns (
622
732
self : "Cursor" ,
@@ -1477,5 +1587,8 @@ def __schema_pattern_match(self: "Cursor", schema_pattern: typing.Optional[str])
1477
1587
else :
1478
1588
return "NO_SCHEMA_UNIVERSAL_QUERY"
1479
1589
1590
def __sanitize_str(self: "Cursor", s: str) -> str:
    """Return ``s`` with a fixed set of characters removed.

    The characters dropped are: hyphen, semicolon, forward slash, single
    quote, double quote, newline, carriage return and space. All other
    characters are kept in their original order.

    :param s: the string to clean.
    :returns: ``s`` without any of the characters listed above.
    """
    stripped_chars = re.compile(r"[-;/'\"\n\r ]")
    return stripped_chars.sub("", s)
1592
+
1480
1593
def __escape_quotes(self: "Cursor", s: str) -> str:
    """Return the sanitized form of ``s`` wrapped in single quotes.

    ``s`` is first passed through ``self.__sanitize_str`` and the result is
    placed between a pair of ``'`` characters.

    :param s: the string to sanitize and quote.
    :returns: the single-quoted, sanitized string.
    """
    cleaned: str = self.__sanitize_str(s)
    return "'{s}'".format(s=cleaned)
0 commit comments