Skip to content

Commit 32eb2da

Browse files
committed
refactor(Connection, type_utils): add py_types, pg_types to Connection class, declare typecode constants at package level
1 parent 229dfd2 commit 32eb2da

File tree

3 files changed

+254
-97
lines changed

3 files changed

+254
-97
lines changed

redshift_connector/__init__.py

+86-18
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,50 @@
3939
)
4040
from redshift_connector.redshift_property import RedshiftProperty
4141
from redshift_connector.utils import DriverInfo
42+
from redshift_connector.utils.type_utils import (
43+
BIGINT,
44+
BIGINTEGER,
45+
BOOLEAN,
46+
BOOLEAN_ARRAY,
47+
BYTES,
48+
CHAR,
49+
CHAR_ARRAY,
50+
DATE,
51+
DATETIME,
52+
DECIMAL,
53+
DECIMAL_ARRAY,
54+
FLOAT,
55+
FLOAT_ARRAY,
56+
GEOMETRY,
57+
INET,
58+
INT2VECTOR,
59+
INTEGER,
60+
INTEGER_ARRAY,
61+
INTERVAL,
62+
JSON,
63+
JSONB,
64+
MACADDR,
65+
NAME,
66+
NAME_ARRAY,
67+
NULLTYPE,
68+
NUMBER,
69+
OID,
70+
ROWID,
71+
SMALLINT,
72+
STRING,
73+
SUPER,
74+
TEXT,
75+
TEXT_ARRAY,
76+
TIME,
77+
TIMESTAMP,
78+
TIMESTAMPTZ,
79+
TIMETZ,
80+
UNKNOWN,
81+
UUID_TYPE,
82+
VARCHAR,
83+
VARCHAR_ARRAY,
84+
XID,
85+
)
4286

4387
from .version import __version__
4488

@@ -293,24 +337,6 @@ def connect(
293337
String property stating the type of parameter marker formatting expected by the interface; This value defaults to "format", in which parameters are marked in this format "WHERE name=%s"
294338
"""
295339

296-
# I have no idea what this would be used for by a client app. Should it be
297-
# TEXT, VARCHAR, CHAR? It will only compare against row_description's
298-
# type_code if it is this one type. It is the varchar type oid for now, this
299-
# appears to match expectations in the DB API 2.0 compliance test suite.
300-
301-
STRING: int = 1043
302-
"""String type oid."""
303-
304-
305-
NUMBER: int = 1700
306-
"""Numeric type oid"""
307-
308-
DATETIME: int = 1114
309-
"""Timestamp type oid"""
310-
311-
ROWID: int = 26
312-
"""ROWID type oid"""
313-
314340
__all__: typing.Any = [
315341
"Warning",
316342
"DataError",
@@ -343,4 +369,46 @@ def connect(
343369
"PGText",
344370
"PGVarchar",
345371
"__version__",
372+
"BIGINT",
373+
"BIGINTEGER",
374+
"BOOLEAN",
375+
"BOOLEAN_ARRAY",
376+
"BYTES",
377+
"CHAR",
378+
"CHAR_ARRAY",
379+
"DATE",
380+
"DATETIME",
381+
"DECIMAL",
382+
"DECIMAL_ARRAY",
383+
"FLOAT",
384+
"FLOAT_ARRAY",
385+
"GEOMETRY",
386+
"INET",
387+
"INT2VECTOR",
388+
"INTEGER",
389+
"INTEGER_ARRAY",
390+
"INTERVAL",
391+
"JSON",
392+
"JSONB",
393+
"MACADDR",
394+
"NAME",
395+
"NAME_ARRAY",
396+
"NULLTYPE",
397+
"NUMBER",
398+
"OID",
399+
"ROWID",
400+
"STRING",
401+
"SMALLINT",
402+
"SUPER",
403+
"TEXT",
404+
"TEXT_ARRAY",
405+
"TIME",
406+
"TIMESTAMP",
407+
"TIMESTAMPTZ",
408+
"TIMETZ",
409+
"UNKNOWN",
410+
"UUID_TYPE",
411+
"VARCHAR",
412+
"VARCHAR_ARRAY",
413+
"XID",
346414
]

redshift_connector/core.py

+56-36
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,33 @@
7373
int_array_recv,
7474
numeric_in,
7575
numeric_in_binary,
76-
pg_types,
77-
py_types,
76+
)
77+
from redshift_connector.utils import pg_types as PG_TYPES
78+
from redshift_connector.utils import py_types as PY_TYPES
79+
from redshift_connector.utils import (
7880
q_pack,
7981
time_in,
8082
time_recv_binary,
8183
timetz_in,
8284
timetz_recv_binary,
8385
walk_array,
8486
)
87+
from redshift_connector.utils.type_utils import (
88+
BIGINT,
89+
DATE,
90+
INTEGER,
91+
INTEGER_ARRAY,
92+
NUMERIC,
93+
REAL_ARRAY,
94+
SMALLINT,
95+
SMALLINT_ARRAY,
96+
TEXT_ARRAY,
97+
TIME,
98+
TIMESTAMP,
99+
TIMESTAMPTZ,
100+
TIMETZ,
101+
VARCHAR_ARRAY,
102+
)
85103

86104
if TYPE_CHECKING:
87105
from ssl import SSLSocket
@@ -451,6 +469,8 @@ def __init__(
451469
self._run_cursor: Cursor = Cursor(self, paramstyle="named")
452470
self._client_protocol_version: int = client_protocol_version
453471
self._database = database
472+
self.py_types = deepcopy(PY_TYPES)
473+
self.pg_types = deepcopy(PG_TYPES)
454474
self._database_metadata_current_db_only: bool = database_metadata_current_db_only
455475

456476
# based on _client_protocol_version value, we must use different conversion functions
@@ -647,31 +667,31 @@ def __init__(
647667

648668
def _enable_protocol_based_conversion_funcs(self: "Connection"):
649669
if self._client_protocol_version == ClientProtocolVersion.BINARY.value:
650-
pg_types[1700] = (FC_BINARY, numeric_in_binary)
651-
pg_types[1082] = (FC_BINARY, date_recv_binary)
652-
pg_types[1083] = (FC_BINARY, time_recv_binary)
653-
pg_types[1266] = (FC_BINARY, timetz_recv_binary)
654-
pg_types[1002] = (FC_BINARY, array_recv_binary) # CHAR[]
655-
pg_types[1005] = (FC_BINARY, array_recv_binary) # INT2[]
656-
pg_types[1007] = (FC_BINARY, array_recv_binary) # INT4[]
657-
pg_types[1009] = (FC_BINARY, array_recv_binary) # TEXT[]
658-
pg_types[1015] = (FC_BINARY, array_recv_binary) # VARCHAR[]
659-
pg_types[1021] = (FC_BINARY, array_recv_binary) # FLOAT4[]
660-
pg_types[1028] = (FC_BINARY, array_recv_binary) # OID[]
661-
pg_types[1034] = (FC_BINARY, array_recv_binary) # ACLITEM[]
670+
self.pg_types[NUMERIC] = (FC_BINARY, numeric_in_binary)
671+
self.pg_types[DATE] = (FC_BINARY, date_recv_binary)
672+
self.pg_types[TIME] = (FC_BINARY, time_recv_binary)
673+
self.pg_types[TIMETZ] = (FC_BINARY, timetz_recv_binary)
674+
self.pg_types[1002] = (FC_BINARY, array_recv_binary) # CHAR[]
675+
self.pg_types[SMALLINT_ARRAY] = (FC_BINARY, array_recv_binary) # INT2[]
676+
self.pg_types[INTEGER_ARRAY] = (FC_BINARY, array_recv_binary) # INT4[]
677+
self.pg_types[TEXT_ARRAY] = (FC_BINARY, array_recv_binary) # TEXT[]
678+
self.pg_types[VARCHAR_ARRAY] = (FC_BINARY, array_recv_binary) # VARCHAR[]
679+
self.pg_types[REAL_ARRAY] = (FC_BINARY, array_recv_binary) # FLOAT4[]
680+
self.pg_types[1028] = (FC_BINARY, array_recv_binary) # OID[]
681+
self.pg_types[1034] = (FC_BINARY, array_recv_binary) # ACLITEM[]
662682
else: # text protocol
663-
pg_types[1700] = (FC_TEXT, numeric_in)
664-
pg_types[1083] = (FC_TEXT, time_in)
665-
pg_types[1082] = (FC_TEXT, date_in)
666-
pg_types[1266] = (FC_TEXT, timetz_in)
667-
pg_types[1002] = (FC_TEXT, array_recv_text) # CHAR[]
668-
pg_types[1005] = (FC_TEXT, int_array_recv) # INT2[]
669-
pg_types[1007] = (FC_TEXT, int_array_recv) # INT4[]
670-
pg_types[1009] = (FC_TEXT, array_recv_text) # TEXT[]
671-
pg_types[1015] = (FC_TEXT, array_recv_text) # VARCHAR[]
672-
pg_types[1021] = (FC_TEXT, float_array_recv) # FLOAT4[]
673-
pg_types[1028] = (FC_TEXT, int_array_recv) # OID[]
674-
pg_types[1034] = (FC_TEXT, array_recv_text) # ACLITEM[]
683+
self.pg_types[NUMERIC] = (FC_TEXT, numeric_in)
684+
self.pg_types[TIME] = (FC_TEXT, time_in)
685+
self.pg_types[DATE] = (FC_TEXT, date_in)
686+
self.pg_types[TIMETZ] = (FC_TEXT, timetz_in)
687+
self.pg_types[1002] = (FC_TEXT, array_recv_text) # CHAR[]
688+
self.pg_types[SMALLINT_ARRAY] = (FC_TEXT, int_array_recv) # INT2[]
689+
self.pg_types[INTEGER_ARRAY] = (FC_TEXT, int_array_recv) # INT4[]
690+
self.pg_types[TEXT_ARRAY] = (FC_TEXT, array_recv_text) # TEXT[]
691+
self.pg_types[VARCHAR_ARRAY] = (FC_TEXT, array_recv_text) # VARCHAR[]
692+
self.pg_types[REAL_ARRAY] = (FC_TEXT, float_array_recv) # FLOAT4[]
693+
self.pg_types[1028] = (FC_TEXT, int_array_recv) # OID[]
694+
self.pg_types[1034] = (FC_TEXT, array_recv_text) # ACLITEM[]
675695

676696
@property
677697
def _is_multi_databases_catalog_enable_in_server(self: "Connection") -> bool:
@@ -946,31 +966,31 @@ def handle_BACKEND_KEY_DATA(self: "Connection", data: bytes, ps) -> None:
946966

947967
def inspect_datetime(self: "Connection", value: Datetime):
948968
if value.tzinfo is None:
949-
return py_types[1114] # timestamp
969+
return self.py_types[TIMESTAMP] # timestamp
950970
else:
951-
return py_types[1184] # send as timestamptz
971+
return self.py_types[TIMESTAMPTZ] # send as timestamptz
952972

953973
def inspect_int(self: "Connection", value: int):
954974
if min_int2 < value < max_int2:
955-
return py_types[21]
975+
return self.py_types[SMALLINT]
956976
if min_int4 < value < max_int4:
957-
return py_types[23]
977+
return self.py_types[INTEGER]
958978
if min_int8 < value < max_int8:
959-
return py_types[20]
960-
return py_types[Decimal]
979+
return self.py_types[BIGINT]
980+
return self.py_types[Decimal]
961981

962982
def make_params(self: "Connection", values):
963983
params = []
964984
for value in values:
965985
typ = type(value)
966986
try:
967-
params.append(py_types[typ])
987+
params.append(self.py_types[typ])
968988
except KeyError:
969989
try:
970990
params.append(self.inspect_funcs[typ](value))
971991
except KeyError as e:
972992
param = None
973-
for k, v in py_types.items():
993+
for k, v in self.py_types.items():
974994
try:
975995
if isinstance(value, typing.cast(type, k)):
976996
param = v
@@ -1033,7 +1053,7 @@ def handle_ROW_DESCRIPTION(self: "Connection", data, cursor: Cursor) -> None:
10331053
idx += 2
10341054

10351055
cursor.ps["row_desc"].append(field)
1036-
field["pg8000_fc"], field["func"] = pg_types[field["type_oid"]]
1056+
field["pg8000_fc"], field["func"] = self.pg_types[field["type_oid"]]
10371057

10381058
def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
10391059
"""
@@ -1158,7 +1178,7 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
11581178

11591179
# We've got row_desc that allows us to identify what we're
11601180
# going to get back from this statement.
1161-
output_fc = tuple(pg_types[f["type_oid"]][0] for f in ps["row_desc"])
1181+
output_fc = tuple(self.pg_types[f["type_oid"]][0] for f in ps["row_desc"])
11621182

11631183
ps["input_funcs"] = tuple(f["func"] for f in ps["row_desc"])
11641184
# Byte1('B') - Identifies the Bind command.

0 commit comments

Comments
 (0)