Skip to content

Commit 23df7cc

Browse files
committed
feat(Connection, numeric_to_float): add connection option to convert numeric datatype to Python float
1 parent 6119ac0 commit 23df7cc

File tree

7 files changed

+67
-1
lines changed

7 files changed

+67
-1
lines changed

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ Connection Parameters
314314
+-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+
315315
| max_prepared_statements | int | The maximum number of prepared statements that can be open at once | 1000 | No |
316316
+-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+
317+
| numeric_to_float | bool | Specifies if NUMERIC datatype values will be converted from decimal.Decimal to float. By default NUMERIC values are received as decimal.Decimal | False | No |
318+
+-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+
317319
| partner_sp_id | str | The Partner SP Id used for authentication with Ping | None | No |
318320
+-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+
319321
| password | str | The password to use for authentication | None | No |

redshift_connector/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def connect(
179179
endpoint_url: typing.Optional[str] = None,
180180
provider_name: typing.Optional[str] = None,
181181
scope: typing.Optional[str] = None,
182+
numeric_to_float: typing.Optional[bool] = False,
182183
) -> Connection:
183184
"""
184185
Establishes a :class:`Connection` to an Amazon Redshift cluster. This function validates user input, optionally authenticates using an identity provider plugin, then constructs a :class:`Connection` object.
@@ -272,6 +273,8 @@ def connect(
272273
The name of the Redshift Native Auth Provider.
273274
scope: Optional[str]
274275
Scope for BrowserAzureOauth2CredentialsProvider authentication.
276+
numeric_to_float: Optional[bool]
277+
Specifies if NUMERIC datatype values will be converted from ``decimal.Decimal`` to ``float``. By default NUMERIC values are received as ``decimal.Decimal``.
275278
Returns
276279
-------
277280
A Connection object associated with the specified Amazon Redshift cluster: :class:`Connection`
@@ -304,6 +307,7 @@ def connect(
304307
info.put("listen_port", listen_port)
305308
info.put("login_url", login_url)
306309
info.put("max_prepared_statements", max_prepared_statements)
310+
info.put("numeric_to_float", numeric_to_float)
307311
info.put("partner_sp_id", partner_sp_id)
308312
info.put("password", password)
309313
info.put("port", port)
@@ -382,6 +386,7 @@ def connect(
382386
credentials_provider=info.credentials_provider,
383387
provider_name=info.provider_name,
384388
web_identity_token=info.web_identity_token,
389+
numeric_to_float=info.numeric_to_float,
385390
)
386391

387392

redshift_connector/core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
make_divider_block,
7676
numeric_in,
7777
numeric_in_binary,
78+
numeric_to_float_binary,
79+
numeric_to_float_in,
7880
)
7981
from redshift_connector.utils import pg_types as PG_TYPES
8082
from redshift_connector.utils import py_types as PY_TYPES
@@ -419,6 +421,7 @@ def __init__(
419421
credentials_provider: typing.Optional[str] = None,
420422
provider_name: typing.Optional[str] = None,
421423
web_identity_token: typing.Optional[str] = None,
424+
numeric_to_float: bool = False,
422425
):
423426
"""
424427
Creates a :class:`Connection` to an Amazon Redshift cluster. For more information on establishing a connection to an Amazon Redshift cluster using `federated API access <https://aws.amazon.com/blogs/big-data/federated-api-access-to-amazon-redshift-using-an-amazon-redshift-connector-for-python/>`_ see our examples page.
@@ -461,6 +464,8 @@ def __init__(
461464
The name of the Redshift Native Auth Provider.
462465
web_identity_token: Optional[str]
463466
A web identity token used for authentication via Redshift Native IDP Integration
467+
numeric_to_float: bool
468+
Specifies if NUMERIC datatype values will be converted from ``decimal.Decimal`` to ``float``. By default NUMERIC values are received as ``decimal.Decimal``.
464469
"""
465470
self.merge_socket_read = True
466471

@@ -484,6 +489,7 @@ def __init__(
484489
self.py_types = deepcopy(PY_TYPES)
485490
self.pg_types = deepcopy(PG_TYPES)
486491
self._database_metadata_current_db_only: bool = database_metadata_current_db_only
492+
self.numeric_to_float: bool = numeric_to_float
487493

488494
# based on _client_protocol_version value, we must use different conversion functions
489495
# for receiving some datatypes
@@ -725,6 +731,10 @@ def _enable_protocol_based_conversion_funcs(self: "Connection"):
725731
self.pg_types[1028] = (FC_BINARY, array_recv_binary) # OID[]
726732
self.pg_types[1034] = (FC_BINARY, array_recv_binary) # ACLITEM[]
727733
self.pg_types[VARBYTE] = (FC_TEXT, text_recv) # VARBYTE
734+
735+
if self.numeric_to_float:
736+
self.pg_types[NUMERIC] = (FC_BINARY, numeric_to_float_binary)
737+
728738
else: # text protocol
729739
self.pg_types[NUMERIC] = (FC_TEXT, numeric_in)
730740
self.pg_types[TIME] = (FC_TEXT, time_in)
@@ -741,6 +751,9 @@ def _enable_protocol_based_conversion_funcs(self: "Connection"):
741751
self.pg_types[1034] = (FC_TEXT, array_recv_text) # ACLITEM[]
742752
self.pg_types[VARBYTE] = (FC_TEXT, varbytehex_recv) # VARBYTE
743753

754+
if self.numeric_to_float:
755+
self.pg_types[NUMERIC] = (FC_TEXT, numeric_to_float_in)
756+
744757
@property
745758
def _is_multi_databases_catalog_enable_in_server(self: "Connection") -> bool:
746759
if (b"datashare_enabled", str("on").encode()) in self.parameter_statuses:
@@ -1918,7 +1931,7 @@ def handle_DATA_ROW(self: "Connection", data: bytes, cursor: Cursor) -> None:
19181931
data_idx += 4
19191932
if vlen == -1:
19201933
row.append(None)
1921-
elif desc[0] == numeric_in_binary:
1934+
elif desc[0] in (numeric_in_binary, numeric_to_float_binary):
19221935
row.append(desc[0](data, data_idx, vlen, desc[1]))
19231936
data_idx += vlen
19241937
else:

redshift_connector/redshift_property.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(self: "RedshiftProperty", **kwargs):
113113
# The name of the Redshift Native Auth Provider
114114
self.provider_name: typing.Optional[str] = None
115115
self.scope: str = ""
116+
self.numeric_to_float: bool = False
116117

117118
else:
118119
for k, v in kwargs.items():

redshift_connector/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
int_array_recv,
3333
numeric_in,
3434
numeric_in_binary,
35+
numeric_to_float_binary,
36+
numeric_to_float_in,
3537
pg_types,
3638
py_types,
3739
q_pack,

redshift_connector/utils/type_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,28 @@ def numeric_in_binary(data: bytes, offset: int, length: int, scale: int) -> Deci
312312
return Decimal(raw_value).scaleb(-1 * scale)
313313

314314

315+
def numeric_to_float_binary(data: bytes, offset: int, length: int, scale: int) -> float:
    """
    Decode a binary-format NUMERIC column value into a Python ``float``.

    Parameters
    ----------
    data : bytes
        Buffer that holds the column value.
    offset : int
        Index in ``data`` at which the value starts.
    length : int
        Size of the value in bytes. Only 8 and 16 are accepted.
    scale : int
        Power of ten by which the unscaled integer is divided, i.e. the result is ``raw_value / 10 ** scale``.

    Returns
    -------
    The decoded value: float

    Raises
    ------
    Exception
        If ``length`` is neither 8 nor 16.
    """
    raw_value: int

    if length == 8:
        raw_value = q_unpack(data, offset)[0]
    elif length == 16:
        # The value is unpacked as two words; the first one carries the high-order 64 bits.
        temp: typing.Tuple[int, int] = qq_unpack(data, offset)
        raw_value = (temp[0] << 64) | temp[1]
    else:
        raise Exception("Malformed column value of type numeric received")

    # True division always produces a float, including when ``scale`` is 0
    # (``raw_value * 10 ** 0`` would stay an int), and it rounds only once
    # instead of first rounding an inexact negative power of ten such as 0.1.
    return raw_value / 10**scale
327+
328+
315329
def numeric_in(data: bytes, offset: int, length: int) -> Decimal:
    """
    Parse a text-format NUMERIC column value into a ``decimal.Decimal``.

    Parameters
    ----------
    data : bytes
        Buffer that holds the column value.
    offset : int
        Index in ``data`` at which the value starts.
    length : int
        Number of bytes that make up the value.

    Returns
    -------
    The parsed value: Decimal
    """
    end: int = offset + length
    text: str = data[offset:end].decode(_client_encoding)
    return Decimal(text)
317331

318332

333+
def numeric_to_float_in(data: bytes, offset: int, length: int) -> float:
    """
    Parse a text-format NUMERIC column value into a Python ``float``.

    Parameters
    ----------
    data : bytes
        Buffer that holds the column value.
    offset : int
        Index in ``data`` at which the value starts.
    length : int
        Number of bytes that make up the value.

    Returns
    -------
    The parsed value: float
    """
    end: int = offset + length
    text: str = data[offset:end].decode(_client_encoding)
    return float(text)
335+
336+
319337
# def uuid_recv(data: bytes, offset: int, length: int) -> UUID:
320338
# return UUID(bytes=data[offset:offset+length])
321339

test/integration/datatype/test_datatypes.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,28 @@ def test_abstime(db_kwargs, _input, client_protocol):
181181
cursor.execute("select '{}'::abstime".format(insert_val))
182182
res = cursor.fetchone()
183183
assert res[0] == exp_val
184+
185+
186+
# Pairs of (SQL expression producing a NUMERIC value, expected Python float).
numeric_vals: typing.List[typing.Tuple[str, float]] = [
    ("to_number('12,454.8-', 'S99G999D9')", -12454.8),
    ("to_number('8.1-', '9D9S')", -8.1),
    ("to_number('897.6', '999D9S')", 897.6),
]


@pytest.mark.parametrize("client_protocol", ClientProtocolVersion.list())
@pytest.mark.parametrize("_input", numeric_vals)
def test_numeric_to_float(db_kwargs, _input, client_protocol):
    """
    With ``numeric_to_float=True`` a NUMERIC value is returned as a ``float``
    close to the expected value, for every client protocol version.
    """
    insert_val, exp_val = _input

    # Work on a copy so the fixture dict is not mutated, and actually apply the
    # parametrized protocol; otherwise every case runs with the default protocol
    # and only one of the two conversion functions is exercised.
    conn_kwargs = dict(db_kwargs)
    conn_kwargs["numeric_to_float"] = True
    conn_kwargs["client_protocol_version"] = client_protocol

    with redshift_connector.connect(**conn_kwargs) as conn:
        with conn.cursor() as cursor:
            cursor.execute("select {}".format(insert_val))
            res = cursor.fetchone()
            assert isinstance(res[0], float)
            assert isclose(
                typing.cast(float, res[0]),
                typing.cast(float, exp_val),
                rel_tol=1e-05,
                abs_tol=1e-08,
            )

0 commit comments

Comments
 (0)