Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit dfe3390

Browse files
authored
Merge pull request #757 from pik94/prevent-type-overflow
Prevent type overflow
2 parents c409c81 + 842481f commit dfe3390

17 files changed

+109
-34
lines changed

data_diff/databases/base.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,19 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
199199
class BaseDialect(abc.ABC):
200200
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
201201
SUPPORTS_INDEXES: ClassVar[bool] = False
202+
PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
202203
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
203204

204205
PLACEHOLDER_TABLE = None # Used for Oracle
205206

207+
# Some databases do not support long strings, so concatenation might lead to type overflow
208+
209+
_prevent_overflow_when_concat: bool = False
210+
211+
def enable_preventing_type_overflow(self) -> None:
# Turns on the per-instance flag that render_concat checks; once set, render_concat
# wraps each expression in md5_as_hex(to_string(...)) instead of coalesce(to_string(...), '<null>').
212+
logger.info("Preventing type overflow when concatenation is enabled")
213+
self._prevent_overflow_when_concat = True
214+
206215
def parse_table_name(self, name: str) -> DbPath:
207216
"Parse the given table name into a DbPath"
208217
return parse_table_name(name)
@@ -392,10 +401,19 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str:
392401
return f"sum({md5})"
393402

394403
def render_concat(self, c: Compiler, elem: Concat) -> str:
404+
if self._prevent_overflow_when_concat:
405+
items = [
406+
f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}"
407+
for expr in elem.exprs
408+
]
409+
395410
# We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
396-
items = [
397-
f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs
398-
]
411+
else:
412+
items = [
413+
f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')"
414+
for expr in elem.exprs
415+
]
416+
399417
assert items
400418
if len(items) == 1:
401419
return items[0]
@@ -769,6 +787,10 @@ def set_timezone_to_utc(self) -> str:
769787
def md5_as_int(self, s: str) -> str:
770788
"Provide SQL for computing md5 and returning an int"
771789

790+
@abstractmethod
791+
def md5_as_hex(self, s: str) -> str:
792+
"""Provide SQL for computing the md5 of the SQL expression `s` and returning it as a hex string"""
793+
772794
@abstractmethod
773795
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
774796
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -885,13 +907,16 @@ class Database(abc.ABC):
885907
Instantiated using :meth:`~data_diff.connect`
886908
"""
887909

910+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect
911+
888912
SUPPORTS_ALPHANUMS: ClassVar[bool] = True
889913
SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
890914
CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []
891915

892916
default_schema: Optional[str] = None
893917
_interactive: bool = False
894918
is_closed: bool = False
919+
_dialect: BaseDialect = None
895920

896921
@property
897922
def name(self):
@@ -1120,10 +1145,13 @@ def close(self):
11201145
return super().close()
11211146

11221147
@property
1123-
@abstractmethod
11241148
def dialect(self) -> BaseDialect:
11251149
"The dialect of the database. Used internally by Database, and also available publicly."
11261150

1151+
if not self._dialect:
1152+
self._dialect = self.DIALECT_CLASS()
1153+
return self._dialect
1154+
11271155
@property
11281156
@abstractmethod
11291157
def CONNECT_URI_HELP(self) -> str:

data_diff/databases/bigquery.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Any, List, Union
2+
from typing import Any, ClassVar, List, Union, Type
33

44
import attrs
55

@@ -134,6 +134,9 @@ def parse_table_name(self, name: str) -> DbPath:
134134
def md5_as_int(self, s: str) -> str:
135135
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}"
136136

137+
def md5_as_hex(self, s: str) -> str:
    """Provide SQL for computing the MD5 of the SQL expression ``s`` as a hex string.

    BigQuery's ``md5()`` yields BYTES rather than text, so the digest is
    wrapped in ``to_hex()`` -- the same pairing ``md5_as_int`` uses -- to
    produce a STRING that can be concatenated with other string values.

    Args:
        s: An already-rendered SQL expression to hash.

    Returns:
        The SQL text of the hex-encoded MD5 expression.
    """
    return f"to_hex(md5({s}))"
139+
137140
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
138141
if coltype.rounds:
139142
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -179,9 +182,9 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str:
179182

180183
@attrs.define(frozen=False, init=False, kw_only=True)
181184
class BigQuery(Database):
185+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
182186
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
183187
CONNECT_URI_PARAMS = ["dataset"]
184-
dialect = Dialect()
185188

186189
project: str
187190
dataset: str

data_diff/databases/clickhouse.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional, Type
1+
from typing import Any, ClassVar, Dict, Optional, Type
22

33
import attrs
44

@@ -105,6 +105,9 @@ def md5_as_int(self, s: str) -> str:
105105
f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}"
106106
)
107107

108+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; this emits the same hex(MD5(...)) pair that md5_as_int slices.
109+
return f"hex(MD5({s}))"
110+
108111
def normalize_number(self, value: str, coltype: FractionalType) -> str:
109112
# If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
110113
# For example:
@@ -164,7 +167,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
164167

165168
@attrs.define(frozen=False, init=False, kw_only=True)
166169
class Clickhouse(ThreadedDatabase):
167-
dialect = Dialect()
170+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
168171
CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>"
169172
CONNECT_URI_PARAMS = ["database?"]
170173

data_diff/databases/databricks.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Any, Dict, Sequence
2+
from typing import Any, ClassVar, Dict, Sequence, Type
33
import logging
44

55
import attrs
@@ -82,6 +82,9 @@ def parse_table_name(self, name: str) -> DbPath:
8282
def md5_as_int(self, s: str) -> str:
8383
return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}"
8484

85+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; md5_as_int feeds the same md5(...) result to conv(..., 16, 10), i.e. treats it as base-16 text.
86+
return f"md5({s})"
87+
8588
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
8689
"""Databricks timestamp contains no more than 6 digits in precision"""
8790

@@ -104,7 +107,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
104107

105108
@attrs.define(frozen=False, init=False, kw_only=True)
106109
class Databricks(ThreadedDatabase):
107-
dialect = Dialect()
110+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
108111
CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>"
109112
CONNECT_URI_PARAMS = ["catalog", "schema"]
110113

data_diff/databases/duckdb.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Union
1+
from typing import Any, ClassVar, Dict, Union, Type
22

33
import attrs
44

@@ -100,6 +100,9 @@ def current_timestamp(self) -> str:
100100
def md5_as_int(self, s: str) -> str:
101101
return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT - {CHECKSUM_OFFSET}"
102102

103+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; md5_as_int prefixes a slice of the same md5(...) result with '0x', i.e. treats it as hex digits.
104+
return f"md5({s})"
105+
103106
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
104107
# It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers.
105108
if coltype.rounds and coltype.precision > 0:
@@ -116,7 +119,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
116119

117120
@attrs.define(frozen=False, init=False, kw_only=True)
118121
class DuckDB(Database):
119-
dialect = Dialect()
122+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
120123
SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it
121124
CONNECT_URI_HELP = "duckdb://<dbname>@<filepath>"
122125
CONNECT_URI_PARAMS = ["database", "dbpath"]

data_diff/databases/mssql.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, ClassVar, Dict, Optional, Type
22

33
import attrs
44

@@ -38,7 +38,7 @@ def import_mssql():
3838
class Dialect(BaseDialect):
3939
name = "MsSQL"
4040
ROUNDS_ON_PREC_LOSS = True
41-
SUPPORTS_PRIMARY_KEY = True
41+
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
4242
SUPPORTS_INDEXES = True
4343
TYPE_CLASSES = {
4444
# Timestamps
@@ -151,10 +151,13 @@ def normalize_number(self, value: str, coltype: NumericType) -> str:
151151
def md5_as_int(self, s: str) -> str:
152152
return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}"
153153

154+
def md5_as_hex(self, s: str) -> str:
    """Provide SQL for computing the MD5 of the SQL expression ``s`` as a hex string.

    ``HashBytes`` returns varbinary, not text. The digest is therefore
    converted with binary style 2 (hex digits without the ``0x`` prefix),
    mirroring the conversion ``md5_as_int`` performs, so the result is a
    32-character string that is safe to concatenate.

    Args:
        s: An already-rendered SQL expression to hash.

    Returns:
        The SQL text of the hex-encoded MD5 expression.
    """
    return f"CONVERT(VARCHAR(32), HashBytes('MD5', {s}), 2)"
156+
154157

155158
@attrs.define(frozen=False, init=False, kw_only=True)
156159
class MsSQL(ThreadedDatabase):
157-
dialect = Dialect()
160+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
158161
CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
159162
CONNECT_URI_PARAMS = ["database", "schema"]
160163

data_diff/databases/mysql.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, ClassVar, Dict, Type
22

33
import attrs
44

@@ -40,7 +40,7 @@ def import_mysql():
4040
class Dialect(BaseDialect):
4141
name = "MySQL"
4242
ROUNDS_ON_PREC_LOSS = True
43-
SUPPORTS_PRIMARY_KEY = True
43+
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
4444
SUPPORTS_INDEXES = True
4545
TYPE_CLASSES = {
4646
# Dates
@@ -101,6 +101,9 @@ def set_timezone_to_utc(self) -> str:
101101
def md5_as_int(self, s: str) -> str:
102102
return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}"
103103

104+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; md5_as_int feeds the same md5(...) result to conv(..., 16, 10), i.e. treats it as base-16 text.
105+
return f"md5({s})"
106+
104107
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
105108
if coltype.rounds:
106109
return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")
@@ -117,7 +120,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
117120

118121
@attrs.define(frozen=False, init=False, kw_only=True)
119122
class MySQL(ThreadedDatabase):
120-
dialect = Dialect()
123+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
121124
SUPPORTS_ALPHANUMS = False
122125
SUPPORTS_UNIQUE_CONSTAINT = True
123126
CONNECT_URI_HELP = "mysql://<user>:<password>@<host>/<database>"

data_diff/databases/oracle.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, ClassVar, Dict, List, Optional, Type
22

33
import attrs
44

@@ -43,7 +43,7 @@ class Dialect(
4343
BaseDialect,
4444
):
4545
name = "Oracle"
46-
SUPPORTS_PRIMARY_KEY = True
46+
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
4747
SUPPORTS_INDEXES = True
4848
TYPE_CLASSES: Dict[str, type] = {
4949
"NUMBER": Decimal,
@@ -137,6 +137,9 @@ def md5_as_int(self, s: str) -> str:
137137
# TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
138138
return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}"
139139

140+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; md5_as_int parses a slice of the same standard_hash(..., 'MD5') result with an 'xxx…' hex format mask.
# NOTE(review): presumably the value is implicitly rendered as hex text when concatenated — confirm.
141+
return f"standard_hash({s}, 'MD5')"
142+
140143
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
141144
# Cast is necessary for correct MD5 (trimming not enough)
142145
return f"CAST(TRIM({value}) AS VARCHAR(36))"
@@ -161,7 +164,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
161164

162165
@attrs.define(frozen=False, init=False, kw_only=True)
163166
class Oracle(ThreadedDatabase):
164-
dialect = Dialect()
167+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
165168
CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>"
166169
CONNECT_URI_PARAMS = ["database?"]
167170

data_diff/databases/postgresql.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def import_postgresql():
4242
class PostgresqlDialect(BaseDialect):
4343
name = "PostgreSQL"
4444
ROUNDS_ON_PREC_LOSS = True
45-
SUPPORTS_PRIMARY_KEY = True
45+
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
4646
SUPPORTS_INDEXES = True
4747

4848
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
@@ -98,6 +98,9 @@ def type_repr(self, t) -> str:
9898
def md5_as_int(self, s: str) -> str:
9999
return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}"
100100

101+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; md5_as_int prefixes a slice of the same md5(...) result with 'x' and casts it to bit, i.e. treats it as hex digits.
102+
return f"md5({s})"
103+
101104
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
102105
if coltype.rounds:
103106
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
@@ -119,7 +122,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
119122

120123
@attrs.define(frozen=False, init=False, kw_only=True)
121124
class PostgreSQL(ThreadedDatabase):
122-
dialect = PostgresqlDialect()
125+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect
123126
SUPPORTS_UNIQUE_CONSTAINT = True
124127
CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>"
125128
CONNECT_URI_PARAMS = ["database?"]

data_diff/databases/presto.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import partial
22
import re
3-
from typing import Any
3+
from typing import Any, ClassVar, Type
44

55
import attrs
66

@@ -128,6 +128,9 @@ def current_timestamp(self) -> str:
128128
def md5_as_int(self, s: str) -> str:
129129
return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}"
130130

131+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; this emits the same to_hex(md5(to_utf8(...))) chain that md5_as_int slices and parses with from_base(..., 16).
132+
return f"to_hex(md5(to_utf8({s})))"
133+
131134
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
132135
# Trim doesn't work on CHAR type
133136
return f"TRIM(CAST({value} AS VARCHAR))"
@@ -150,7 +153,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
150153

151154
@attrs.define(frozen=False, init=False, kw_only=True)
152155
class Presto(Database):
153-
dialect = Dialect()
156+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
154157
CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>"
155158
CONNECT_URI_PARAMS = ["catalog", "schema"]
156159

data_diff/databases/redshift.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TimestampTZ,
1313
)
1414
from data_diff.databases.postgresql import (
15+
BaseDialect,
1516
PostgreSQL,
1617
MD5_HEXDIGITS,
1718
CHECKSUM_HEXDIGITS,
@@ -47,6 +48,9 @@ def type_repr(self, t) -> str:
4748
def md5_as_int(self, s: str) -> str:
4849
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}"
4950

51+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression; md5_as_int passes a slice of the same md5(...) result to strtol(..., 16), i.e. treats it as base-16 text.
52+
return f"md5({s})"
53+
5054
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
5155
if coltype.rounds:
5256
timestamp = f"{value}::timestamp(6)"
@@ -76,7 +80,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
7680

7781
@attrs.define(frozen=False, init=False, kw_only=True)
7882
class Redshift(PostgreSQL):
79-
dialect = Dialect()
83+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
8084
CONNECT_URI_HELP = "redshift://<user>:<password>@<host>/<database>"
8185
CONNECT_URI_PARAMS = ["database?"]
8286

data_diff/databases/snowflake.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Union, List
1+
from typing import Any, ClassVar, Union, List, Type
22
import logging
33

44
import attrs
@@ -76,6 +76,9 @@ def type_repr(self, t) -> str:
7676
def md5_as_int(self, s: str) -> str:
7777
return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}"
7878

79+
def md5_as_hex(self, s: str) -> str:
# `s` is an already-rendered SQL expression, interpolated into md5(...).
# NOTE(review): md5_as_int in this dialect uses md5_number_lower64 instead, so nothing here shows md5() yields hex text — confirm.
80+
return f"md5({s})"
81+
7982
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
8083
if coltype.rounds:
8184
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))"
@@ -93,7 +96,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
9396

9497
@attrs.define(frozen=False, init=False, kw_only=True)
9598
class Snowflake(Database):
96-
dialect = Dialect()
99+
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
97100
CONNECT_URI_HELP = "snowflake://<user>:<password>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>"
98101
CONNECT_URI_PARAMS = ["database", "schema"]
99102
CONNECT_URI_KWPARAMS = ["warehouse"]

0 commit comments

Comments
 (0)