Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit ada9306

Browse files
authored
Merge branch 'master' into unreachable_branch_validate_adjust
2 parents e61815b + 0b74046 commit ada9306

18 files changed: +1189 lines added, -1004 lines removed

.github/workflows/ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jobs:
6868
DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse'
6969
DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica'
7070
DATADIFF_REDSHIFT_URI: '${{ secrets.DATADIFF_REDSHIFT_URI }}'
71+
MOTHERDUCK_TOKEN: '${{ secrets.MOTHERDUCK_TOKEN }}'
7172
run: |
7273
chmod +x tests/waiting_for_stack_up.sh
7374
./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=0 poetry run unittest-parallel -j 16

.github/workflows/ci_full.yml

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ jobs:
6464
DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica'
6565
# DATADIFF_BIGQUERY_URI: '${{ secrets.DATADIFF_BIGQUERY_URI }}'
6666
DATADIFF_REDSHIFT_URI: '${{ secrets.DATADIFF_REDSHIFT_URI }}'
67+
MOTHERDUCK_TOKEN: '${{ secrets.MOTHERDUCK_TOKEN }}'
6768
run: |
6869
chmod +x tests/waiting_for_stack_up.sh
6970
./tests/waiting_for_stack_up.sh && poetry run unittest-parallel -j 16

data_diff/databases/_connect.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from data_diff.databases.mssql import MsSQL
2727

2828

29-
@attrs.define(frozen=True)
29+
@attrs.frozen
3030
class MatchUriPath:
3131
database_cls: Type[Database]
3232

@@ -98,13 +98,11 @@ class Connect:
9898
"""Provides methods for connecting to a supported database using a URL or connection dict."""
9999

100100
database_by_scheme: Dict[str, Database]
101-
match_uri_path: Dict[str, MatchUriPath]
102101
conn_cache: MutableMapping[Hashable, Database]
103102

104103
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
105104
super().__init__()
106105
self.database_by_scheme = database_by_scheme
107-
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
108106
self.conn_cache = weakref.WeakValueDictionary()
109107

110108
def for_databases(self, *dbs) -> Self:
@@ -157,12 +155,10 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
157155
return self.connect_with_dict(conn_dict, thread_count, **kwargs)
158156

159157
try:
160-
matcher = self.match_uri_path[scheme]
158+
cls = self.database_by_scheme[scheme]
161159
except KeyError:
162160
raise NotImplementedError(f"Scheme '{scheme}' currently not supported")
163161

164-
cls = matcher.database_cls
165-
166162
if scheme == "databricks":
167163
assert not dsn.user
168164
kw = {}
@@ -175,6 +171,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
175171
kw["filepath"] = dsn.dbname
176172
kw["dbname"] = dsn.user
177173
else:
174+
matcher = MatchUriPath(cls)
178175
kw = matcher.match_path(dsn)
179176

180177
if scheme == "bigquery":
@@ -198,7 +195,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
198195

199196
kw = {k: v for k, v in kw.items() if v is not None}
200197

201-
if issubclass(cls, ThreadedDatabase):
198+
if isinstance(cls, type) and issubclass(cls, ThreadedDatabase):
202199
db = cls(thread_count=thread_count, **kw, **kwargs)
203200
else:
204201
db = cls(**kw, **kwargs)
@@ -209,11 +206,10 @@ def connect_with_dict(self, d, thread_count, **kwargs):
209206
d = dict(d)
210207
driver = d.pop("driver")
211208
try:
212-
matcher = self.match_uri_path[driver]
209+
cls = self.database_by_scheme[driver]
213210
except KeyError:
214211
raise NotImplementedError(f"Driver '{driver}' currently not supported")
215212

216-
cls = matcher.database_cls
217213
if issubclass(cls, ThreadedDatabase):
218214
db = cls(thread_count=thread_count, **d, **kwargs)
219215
else:

data_diff/databases/base.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -1093,11 +1093,7 @@ def _refine_coltypes(
10931093
list,
10941094
log_message=table_path,
10951095
)
1096-
if not samples_by_row:
1097-
raise ValueError(f"Table {table_path} is empty.")
1098-
1099-
samples_by_col = list(zip(*samples_by_row))
1100-
1096+
samples_by_col = list(zip(*samples_by_row)) if samples_by_row else [[]] * len(text_columns)
11011097
for col_name, samples in safezip(text_columns, samples_by_col):
11021098
uuid_samples = [s for s in samples if s and is_uuid(s)]
11031099

data_diff/databases/duckdb.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, ClassVar, Dict, Union, Type
22

33
import attrs
4+
from packaging.version import parse as parse_version
45

56
from data_diff.utils import match_regexps
67
from data_diff.abcs.database_types import (
@@ -27,6 +28,7 @@
2728
CHECKSUM_OFFSET,
2829
)
2930
from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS
31+
from data_diff.version import __version__
3032

3133

3234
@import_helper("duckdb")
@@ -148,9 +150,21 @@ def close(self):
148150
def create_connection(self):
149151
ddb = import_duckdb()
150152
try:
151-
return ddb.connect(self._args["filepath"])
153+
# custom_user_agent is only available in duckdb >= 0.9.2
154+
if parse_version(ddb.__version__) >= parse_version("0.9.2"):
155+
custom_user_agent = f"data-diff/v{__version__}"
156+
config = {"custom_user_agent": custom_user_agent}
157+
connection = ddb.connect(database=self._args["filepath"], config=config)
158+
custom_user_agent_results = connection.sql("PRAGMA USER_AGENT;").fetchall()
159+
custom_user_agent_filtered = custom_user_agent_results[0][0]
160+
assert custom_user_agent in custom_user_agent_filtered
161+
else:
162+
connection = ddb.connect(database=self._args["filepath"])
163+
return connection
152164
except ddb.OperationalError as e:
153165
raise ConnectError(*e.args) from e
166+
except AssertionError:
167+
raise ConnectError("Assertion failed: Custom user agent is invalid.") from None
154168

155169
def select_table_schema(self, path: DbPath) -> str:
156170
database, schema, table = self._normalize_table_path(path)

data_diff/databases/mssql.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,15 @@ def limit_select(
119119
) -> str:
120120
if offset:
121121
raise NotImplementedError("No support for OFFSET in query")
122-
123122
result = ""
124123
if not has_order_by:
125124
result += "ORDER BY 1"
126125

127126
result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
128-
return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {result}"
127+
128+
# mssql requires that subquery columns are all aliased, so
129+
# don't wrap in an outer select
130+
return f"{select_query} {result}"
129131

130132
def constant_values(self, rows) -> str:
131133
values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)

data_diff/databases/redshift.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,38 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
122122

123123
return schema_dict
124124

125+
def select_svv_columns_schema(self, path: DbPath) -> Dict[str, tuple]:
126+
database, schema, table = self._normalize_table_path(path)
127+
128+
db_clause = ""
129+
if database:
130+
db_clause = f" AND table_catalog = '{database.lower()}'"
131+
132+
return (
133+
f"""
134+
select
135+
distinct
136+
column_name,
137+
data_type,
138+
datetime_precision,
139+
numeric_precision,
140+
numeric_scale
141+
from
142+
svv_columns
143+
where table_name = '{table.lower()}' and table_schema = '{schema.lower()}'
144+
"""
145+
+ db_clause
146+
)
147+
148+
def query_svv_columns(self, path: DbPath) -> Dict[str, tuple]:
149+
rows = self.query(self.select_svv_columns_schema(path), list)
150+
if not rows:
151+
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
152+
153+
d = {r[0]: r for r in rows}
154+
assert len(d) == len(rows)
155+
return d
156+
125157
# when using a non-information_schema source, strip (N) from type(N) etc. to match
126158
# typical information_schema output
127159
def _normalize_schema_info(self, rows) -> Dict[str, tuple]:
@@ -150,7 +182,10 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
150182
try:
151183
return self.query_external_table_schema(path)
152184
except RuntimeError:
153-
return self.query_pg_get_cols(path)
185+
try:
186+
return self.query_pg_get_cols(path)
187+
except Exception:
188+
return self.query_svv_columns(path)
154189

155190
def _normalize_table_path(self, path: DbPath) -> DbPath:
156191
if len(path) == 1:

data_diff/databases/snowflake.py

+25-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, ClassVar, Union, List, Type
1+
import base64
2+
from typing import Any, ClassVar, Union, List, Type, Optional
23
import logging
34

45
import attrs
@@ -103,7 +104,7 @@ class Snowflake(Database):
103104

104105
_conn: Any
105106

106-
def __init__(self, *, schema: str, **kw):
107+
def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw):
107108
super().__init__()
108109
snowflake, serialization, default_backend = import_snowflake()
109110
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
@@ -113,20 +114,29 @@ def __init__(self, *, schema: str, **kw):
113114
logging.getLogger("snowflake.connector.network").disabled = True
114115

115116
assert '"' not in schema, "Schema name should not contain quotes!"
117+
if key_content and key:
118+
raise ConnectError("Only key value or key file path can be specified, not both")
119+
120+
key_bytes = None
121+
if key:
122+
with open(key, "rb") as f:
123+
key_bytes = f.read()
124+
if key_content:
125+
key_bytes = base64.b64decode(key_content)
126+
116127
# If a private key is used, read it from the specified path and pass it as "private_key" to the connector.
117-
if "key" in kw:
118-
with open(kw.get("key"), "rb") as key:
119-
if "password" in kw:
120-
raise ConnectError("Cannot use password and key at the same time")
121-
if kw.get("private_key_passphrase"):
122-
encoded_passphrase = kw.get("private_key_passphrase").encode()
123-
else:
124-
encoded_passphrase = None
125-
p_key = serialization.load_pem_private_key(
126-
key.read(),
127-
password=encoded_passphrase,
128-
backend=default_backend(),
129-
)
128+
if key_bytes:
129+
if "password" in kw:
130+
raise ConnectError("Cannot use password and key at the same time")
131+
if kw.get("private_key_passphrase"):
132+
encoded_passphrase = kw.get("private_key_passphrase").encode()
133+
else:
134+
encoded_passphrase = None
135+
p_key = serialization.load_pem_private_key(
136+
key_bytes,
137+
password=encoded_passphrase,
138+
backend=default_backend(),
139+
)
130140

131141
kw["private_key"] = p_key.private_bytes(
132142
encoding=serialization.Encoding.DER,

data_diff/hashdiff_tables.py

-8
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,6 @@ def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegmen
118118
if lowest.precision != col2.precision:
119119
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision)
120120

121-
elif isinstance(col1, ColType_UUID):
122-
if strict and not isinstance(col2, ColType_UUID):
123-
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
124-
125-
elif isinstance(col1, StringType):
126-
if strict and not isinstance(col2, StringType):
127-
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
128-
129121
for t in [table1, table2]:
130122
for c in t.relevant_columns:
131123
ctype = t._schema[c]

data_diff/joindiff_tables.py

-1
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,6 @@ def _count_diff_per_column(
343343
table1: Optional[TableSegment] = None,
344344
table2: Optional[TableSegment] = None,
345345
):
346-
logger.info(type(table1))
347346
logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}")
348347
is_diff_cols_counts = db.query(
349348
diff_rows.select(sum_(this[c]) for c in is_diff_cols),

data_diff/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.9.17"
1+
__version__ = "0.10.0rc0"

0 commit comments

Comments (0)