Skip to content

Commit f0b21a2

Browse files
committed
refactor(cursor): use prepared statements in class methods
1 parent d041f2b commit f0b21a2

File tree

2 files changed

+152
-32
lines changed

2 files changed

+152
-32
lines changed

redshift_connector/cursor.py

+145-32
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import typing
23
from collections import deque
34
from itertools import count, islice
@@ -185,7 +186,7 @@ def executemany(self: "Cursor", operation, param_sets) -> "Cursor":
185186
self._row_count = -1 if -1 in rowcounts else sum(rowcounts)
186187
return self
187188

188-
def fetchone(self: "Cursor") -> typing.Optional["Cursor"]:
189+
def fetchone(self: "Cursor") -> typing.Optional[typing.List]:
189190
"""Fetch the next row of a query result set.
190191
191192
This method is part of the `DBAPI 2.0 specification
@@ -196,7 +197,7 @@ def fetchone(self: "Cursor") -> typing.Optional["Cursor"]:
196197
are available.
197198
"""
198199
try:
199-
return typing.cast("Cursor", next(self))
200+
return next(self)
200201
except StopIteration:
201202
return None
202203
except TypeError:
@@ -271,7 +272,7 @@ def setoutputsize(self: "Cursor", size, column=None):
271272
"""
272273
pass
273274

274-
def __next__(self: "Cursor"):
275+
def __next__(self: "Cursor") -> typing.List:
275276
try:
276277
return self._cached_rows.popleft()
277278
except IndexError:
@@ -311,16 +312,48 @@ def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> typing.
311312
return None
312313
return pandas.DataFrame(result, columns=columns)
313314

315+
def __is_valid_table(self: "Cursor", table: str) -> bool:
316+
split_table_name: typing.List[str] = table.split(".")
317+
318+
if len(split_table_name) > 2:
319+
return False
320+
321+
q: str = "select 1 from information_schema.tables where table_name = ?"
322+
323+
temp = self.paramstyle
324+
self.paramstyle = "qmark"
325+
326+
try:
327+
if len(split_table_name) == 2:
328+
q += " and table_schema = ?"
329+
self.execute(q, (split_table_name[1], split_table_name[0]))
330+
else:
331+
self.execute(q, (split_table_name[0],))
332+
except:
333+
raise
334+
finally:
335+
# reset paramstyle to its original value
336+
self.paramstyle = temp
337+
338+
result = self.fetchone()
339+
340+
return result[0] == 1 if result is not None else False
341+
314342
def write_dataframe(self: "Cursor", df: "pandas.DataFrame", table: str) -> None:
315343
"""write same structure dataframe into Redshift database"""
316344
try:
317345
import pandas
318346
except ModuleNotFoundError:
319347
raise ModuleNotFoundError(MISSING_MODULE_ERROR_MSG.format(module="pandas"))
320348

349+
if not self.__is_valid_table(table):
350+
raise InterfaceError("Invalid table name passed to write_dataframe: {}".format(table))
351+
sanitized_table_name: str = self.__sanitize_str(table)
321352
arrays: "numpy.ndarray" = df.values
322353
placeholder: str = ", ".join(["%s"] * len(arrays[0]))
323-
sql: str = "insert into {table} values ({placeholder})".format(table=table, placeholder=placeholder)
354+
sql: str = "insert into {table} values ({placeholder})".format(
355+
table=sanitized_table_name, placeholder=placeholder
356+
)
324357
if len(arrays) == 1:
325358
self.execute(sql, arrays[0])
326359
elif len(arrays) > 1:
@@ -361,16 +394,33 @@ def get_procedures(
361394
" LEFT JOIN pg_catalog.pg_namespace pn ON (c.relnamespace=pn.oid AND pn.nspname='pg_catalog') "
362395
" WHERE p.pronamespace=n.oid "
363396
)
397+
query_args: typing.List[str] = []
364398
if schema_pattern is not None:
365-
sql += " AND n.nspname LIKE {schema}".format(schema=self.__escape_quotes(schema_pattern))
399+
sql += " AND n.nspname LIKE ?"
400+
query_args.append(self.__sanitize_str(schema_pattern))
366401
else:
367402
sql += "and pg_function_is_visible(p.prooid)"
368403

369404
if procedure_name_pattern is not None:
370-
sql += " AND p.proname LIKE {procedure}".format(procedure=self.__escape_quotes(procedure_name_pattern))
405+
sql += " AND p.proname LIKE ?"
406+
query_args.append(self.__sanitize_str(procedure_name_pattern))
371407
sql += " ORDER BY PROCEDURE_SCHEM, PROCEDURE_NAME, p.prooid::text "
372408

373-
self.execute(sql)
409+
if len(query_args) > 0:
410+
# temporarily use qmark paramstyle
411+
temp = self.paramstyle
412+
self.paramstyle = "qmark"
413+
414+
try:
415+
self.execute(sql, tuple(query_args))
416+
except:
417+
raise
418+
finally:
419+
# restore the original value of paramstyle
420+
self.paramstyle = temp
421+
else:
422+
self.execute(sql)
423+
374424
procedures: tuple = self.fetchall()
375425
return procedures
376426

@@ -383,11 +433,25 @@ def get_schemas(
383433
" OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_' "
384434
" OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_')) "
385435
)
436+
query_args: typing.List[str] = []
386437
if schema_pattern is not None:
387-
sql += " AND nspname LIKE {schema}".format(schema=self.__escape_quotes(schema_pattern))
438+
sql += " AND nspname LIKE ?"
439+
query_args.append(self.__sanitize_str(schema_pattern))
388440
sql += " ORDER BY TABLE_SCHEM"
389441

390-
self.execute(sql)
442+
if len(query_args) == 1:
443+
# temporarily use qmark paramstyle
444+
temp = self.paramstyle
445+
self.paramstyle = "qmark"
446+
try:
447+
self.execute(sql, tuple(query_args))
448+
except:
449+
raise
450+
finally:
451+
self.paramstyle = temp
452+
else:
453+
self.execute(sql)
454+
391455
schemas: tuple = self.fetchall()
392456
return schemas
393457

@@ -418,13 +482,28 @@ def get_primary_keys(
418482
"i.indisprimary AND "
419483
"ct.relnamespace = n.oid "
420484
)
485+
query_args: typing.List[str] = []
421486
if schema is not None:
422-
sql += " AND n.nspname = {schema}".format(schema=self.__escape_quotes(schema))
487+
sql += " AND n.nspname = ?"
488+
query_args.append(self.__sanitize_str(schema))
423489
if table is not None:
424-
sql += " AND ct.relname = {table}".format(table=self.__escape_quotes(table))
490+
sql += " AND ct.relname = ?"
491+
query_args.append(self.__sanitize_str(table))
425492

426493
sql += " ORDER BY table_name, pk_name, key_seq"
427-
self.execute(sql)
494+
495+
if len(query_args) > 0:
496+
# temporarily use qmark paramstyle
497+
temp = self.paramstyle
498+
self.paramstyle = "qmark"
499+
try:
500+
self.execute(sql, tuple(query_args))
501+
except:
502+
raise
503+
finally:
504+
self.paramstyle = temp
505+
else:
506+
self.execute(sql)
428507
keys: tuple = self.fetchall()
429508
return keys
430509

@@ -437,15 +516,30 @@ def get_tables(
437516
) -> tuple:
438517
"""Returns the unique public tables which are user-defined within the system"""
439518
sql: str = ""
519+
sql_args: typing.Tuple[str, ...] = tuple()
440520
schema_pattern_type: str = self.__schema_pattern_match(schema_pattern)
441521
if schema_pattern_type == "LOCAL_SCHEMA_QUERY":
442-
sql = self.__build_local_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
522+
sql, sql_args = self.__build_local_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
443523
elif schema_pattern_type == "NO_SCHEMA_UNIVERSAL_QUERY":
444-
sql = self.__build_universal_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
524+
sql, sql_args = self.__build_universal_schema_tables_query(
525+
catalog, schema_pattern, table_name_pattern, types
526+
)
445527
elif schema_pattern_type == "EXTERNAL_SCHEMA_QUERY":
446-
sql = self.__build_external_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
528+
sql, sql_args = self.__build_external_schema_tables_query(
529+
catalog, schema_pattern, table_name_pattern, types
530+
)
447531

448-
self.execute(sql)
532+
if len(sql_args) > 0:
533+
temp = self.paramstyle
534+
self.paramstyle = "qmark"
535+
try:
536+
self.execute(sql, sql_args)
537+
except:
538+
raise
539+
finally:
540+
self.paramstyle = temp
541+
else:
542+
self.execute(sql)
449543
tables: tuple = self.fetchall()
450544
return tables
451545

@@ -455,7 +549,7 @@ def __build_local_schema_tables_query(
455549
schema_pattern: typing.Optional[str],
456550
table_name_pattern: typing.Optional[str],
457551
types: list,
458-
) -> str:
552+
) -> typing.Tuple[str, typing.Tuple[str, ...]]:
459553
sql: str = (
460554
"SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT, n.nspname AS TABLE_SCHEM, c.relname AS TABLE_NAME, "
461555
" CASE n.nspname ~ '^pg_' OR n.nspname = 'information_schema' "
@@ -502,32 +596,41 @@ def __build_local_schema_tables_query(
502596
" LEFT JOIN pg_catalog.pg_namespace dn ON (dn.oid=dc.relnamespace AND dn.nspname='pg_catalog') "
503597
" WHERE c.relnamespace = n.oid "
504598
)
505-
filter_clause: str = self.__get_table_filter_clause(
599+
filter_clause, filter_args = self.__get_table_filter_clause(
506600
catalog, schema_pattern, table_name_pattern, types, "LOCAL_SCHEMA_QUERY"
507601
)
508602
orderby: str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
509603

510-
return sql + filter_clause + orderby
604+
return sql + filter_clause + orderby, filter_args
511605

512606
def __get_table_filter_clause(
513607
self: "Cursor",
514608
catalog: typing.Optional[str],
515609
schema_pattern: typing.Optional[str],
516610
table_name_pattern: typing.Optional[str],
517-
types: list,
611+
types: typing.List[str],
518612
schema_pattern_type: str,
519-
) -> str:
613+
) -> typing.Tuple[str, typing.Tuple[str, ...]]:
520614
filter_clause: str = ""
521615
use_schemas: str = "SCHEMAS"
616+
query_args: typing.List[str] = []
522617
if schema_pattern is not None:
523-
filter_clause += " AND TABLE_SCHEM LIKE {schema}".format(schema=self.__escape_quotes(schema_pattern))
618+
filter_clause += " AND TABLE_SCHEM LIKE ?"
619+
query_args.append(self.__sanitize_str(schema_pattern))
524620
if table_name_pattern is not None:
525-
filter_clause += " AND TABLE_NAME LIKE {table}".format(table=self.__escape_quotes(table_name_pattern))
621+
filter_clause += " AND TABLE_NAME LIKE ?"
622+
query_args.append(self.__sanitize_str(table_name_pattern))
526623
if len(types) > 0:
527624
if schema_pattern_type == "LOCAL_SCHEMA_QUERY":
528625
filter_clause += " AND (false "
529626
orclause: str = ""
530627
for type in types:
628+
if type not in table_type_clauses.keys():
629+
raise InterfaceError(
630+
"Invalid type: {} provided. types may only contain: {}".format(
631+
type, table_type_clauses.keys()
632+
)
633+
)
531634
clauses = table_type_clauses[type]
532635
if len(clauses) > 0:
533636
cluase = clauses[use_schemas]
@@ -538,21 +641,28 @@ def __get_table_filter_clause(
538641
filter_clause += " AND TABLE_TYPE IN ( "
539642
length = len(types)
540643
for type in types:
541-
filter_clause += self.__escape_quotes(type)
644+
if type not in table_type_clauses.keys():
645+
raise InterfaceError(
646+
"Invalid type: {} provided. types may only contain: {}".format(
647+
type, table_type_clauses.keys()
648+
)
649+
)
650+
filter_clause += "?"
651+
query_args.append(self.__sanitize_str(type))
542652
length -= 1
543653
if length > 0:
544654
filter_clause += ", "
545655
filter_clause += ") "
546656

547-
return filter_clause
657+
return filter_clause, tuple(query_args)
548658

549659
def __build_universal_schema_tables_query(
550660
self: "Cursor",
551661
catalog: typing.Optional[str],
552662
schema_pattern: typing.Optional[str],
553663
table_name_pattern: typing.Optional[str],
554664
types: list,
555-
) -> str:
665+
) -> typing.Tuple[str, typing.Tuple[str, ...]]:
556666
sql: str = (
557667
"SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
558668
" table_schema AS TABLE_SCHEM,"
@@ -583,20 +693,20 @@ def __build_universal_schema_tables_query(
583693
" FROM svv_tables)"
584694
" WHERE true "
585695
)
586-
filter_clause: str = self.__get_table_filter_clause(
696+
filter_clause, filter_args = self.__get_table_filter_clause(
587697
catalog, schema_pattern, table_name_pattern, types, "NO_SCHEMA_UNIVERSAL_QUERY"
588698
)
589699
orderby: str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
590700
sql += filter_clause + orderby
591-
return sql
701+
return sql, filter_args
592702

593703
def __build_external_schema_tables_query(
594704
self: "Cursor",
595705
catalog: typing.Optional[str],
596706
schema_pattern: typing.Optional[str],
597707
table_name_pattern: typing.Optional[str],
598708
types: list,
599-
) -> str:
709+
) -> typing.Tuple[str, typing.Tuple[str, ...]]:
600710
sql: str = (
601711
"SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
602712
" schemaname AS table_schem,"
@@ -611,12 +721,12 @@ def __build_external_schema_tables_query(
611721
" FROM svv_external_tables)"
612722
" WHERE true "
613723
)
614-
filter_clause: str = self.__get_table_filter_clause(
724+
filter_clause, filter_args = self.__get_table_filter_clause(
615725
catalog, schema_pattern, table_name_pattern, types, "EXTERNAL_SCHEMA_QUERY"
616726
)
617727
orderby: str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
618728
sql += filter_clause + orderby
619-
return sql
729+
return sql, filter_args
620730

621731
def get_columns(
622732
self: "Cursor",
@@ -1477,5 +1587,8 @@ def __schema_pattern_match(self: "Cursor", schema_pattern: typing.Optional[str])
14771587
else:
14781588
return "NO_SCHEMA_UNIVERSAL_QUERY"
14791589

1590+
def __sanitize_str(self: "Cursor", s: str) -> str:
1591+
return re.sub(r"[-;/'\"\n\r ]", "", s)
1592+
14801593
def __escape_quotes(self: "Cursor", s: str) -> str:
1481-
return "'{s}'".format(s=s)
1594+
return "'{s}'".format(s=self.__sanitize_str(s))

test/integration/test_query.py

+7
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,10 @@ def test_handle_COMMAND_COMPLETE_closed_ps(con, mocker):
367367
# begin transaction, drop table t1, create table t1
368368
assert spy.called
369369
assert spy.call_count == 3
370+
371+
372+
@pytest.mark.parametrize("_input", ["NO_SCHEMA_UNIVERSAL_QUERY", "EXTERNAL_SCHEMA_QUERY", "LOCAL_SCHEMA_QUERY"])
373+
def test___get_table_filter_clause_throws_for_bad_type(con, _input):
374+
with con.cursor() as cursor:
375+
with pytest.raises(redshift_connector.InterfaceError):
376+
cursor.get_tables(schema_pattern=_input, types=["garbage"])

0 commit comments

Comments
 (0)