Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 63948c5

Browse files
authored
Merge branch 'master' into redacted_log
2 parents fc36c70 + 4afeaf0 commit 63948c5

File tree

9 files changed

+160
-31
lines changed

9 files changed

+160
-31
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ $ data-diff \
143143
If a database is not on the list, we'd still love to support it. Open an issue
144144
to discuss it.
145145

146+
Note: Because URLs allow many special characters, and may collide with the syntax of your command-line,
147+
it's recommended to surround them with quotes. Alternatively, you may provide them in a TOML file via the `--config` option.
148+
149+
146150
# How to install
147151

148152
Requires Python 3.7+ with pip.

data_diff/__main__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,8 @@ def _main(
136136
return
137137

138138
key_column = key_column or "id"
139-
if bisection_factor is None:
140-
bisection_factor = DEFAULT_BISECTION_FACTOR
141-
if bisection_threshold is None:
142-
bisection_threshold = DEFAULT_BISECTION_THRESHOLD
139+
bisection_factor = DEFAULT_BISECTION_FACTOR if bisection_factor is None else int(bisection_factor)
140+
bisection_threshold = DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else int(bisection_threshold)
143141

144142
threaded = True
145143
if threads is None:

data_diff/databases/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
205205
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
206206
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
207207
if not samples_by_row:
208-
logger.warning(f"Table {table_path} is empty.")
209-
return
208+
raise ValueError(f"Table {table_path} is empty.")
210209

211210
samples_by_col = list(zip(*samples_by_row))
212211

data_diff/databases/oracle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Oracle(ThreadedDatabase):
2727
ROUNDS_ON_PREC_LOSS = True
2828

2929
def __init__(self, *, host, database, thread_count, **kw):
30-
self.kwargs = dict(dsn="%s/%s" % (host, database), **kw)
30+
self.kwargs = dict(dsn="%s/%s" % (host, database) if database else host, **kw)
3131

3232
self.default_schema = kw.get("user")
3333

data_diff/diff_tables.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import defaultdict
99
from typing import List, Tuple, Iterator, Optional
1010
import logging
11-
from concurrent.futures import ThreadPoolExecutor
11+
from concurrent.futures import ThreadPoolExecutor, as_completed
1212

1313
from runtype import dataclass
1414

@@ -315,17 +315,16 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
315315
('-', columns) for items in table2 but not in table1
316316
Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
317317
"""
318+
# Validate options
318319
if self.bisection_factor >= self.bisection_threshold:
319320
raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")
320321
if self.bisection_factor < 2:
321322
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")
322323

324+
# Query and validate schema
323325
table1, table2 = self._threaded_call("with_schema", [table1, table2])
324326
self._validate_and_adjust_columns(table1, table2)
325327

326-
key_ranges = self._threaded_call("query_key_range", [table1, table2])
327-
mins, maxs = zip(*key_ranges)
328-
329328
key_type = table1._schema[table1.key_column]
330329
key_type2 = table2._schema[table2.key_column]
331330
if not isinstance(key_type, IKey):
@@ -334,23 +333,42 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
334333
raise NotImplementedError(f"Cannot use column of type {key_type2} as a key")
335334
assert key_type.python_type is key_type2.python_type
336335

337-
# We add 1 because our ranges are exclusive of the end (like in Python)
338-
try:
339-
min_key = min(map(key_type.python_type, mins))
340-
max_key = max(map(key_type.python_type, maxs)) + 1
341-
except (TypeError, ValueError) as e:
342-
raise type(e)(f"Cannot apply {key_type} to {mins}, {maxs}.") from e
336+
# Query min/max values
337+
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])
343338

344-
table1 = table1.new(min_key=min_key, max_key=max_key)
345-
table2 = table2.new(min_key=min_key, max_key=max_key)
339+
# Start with the first completed value, so we don't waste time waiting
340+
min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges))
341+
342+
table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
346343

347344
logger.info(
348345
f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. "
349346
f"key-range: {table1.min_key}..{table2.max_key}, "
350347
f"size: {table2.max_key-table1.min_key}"
351348
)
352349

353-
return self._bisect_and_diff_tables(table1, table2)
350+
# Bisect (split) the table into segments, and diff them recursively.
351+
yield from self._bisect_and_diff_tables(table1, table2)
352+
353+
# Now we check for the second min-max, to diff the portions we "missed".
354+
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))
355+
356+
if min_key2 < min_key1:
357+
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
358+
yield from self._bisect_and_diff_tables(*pre_tables)
359+
360+
if max_key2 > max_key1:
361+
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
362+
yield from self._bisect_and_diff_tables(*post_tables)
363+
364+
def _parse_key_range_result(self, key_type, key_range):
365+
mn, mx = key_range
366+
cls = key_type.python_type
367+
# We add 1 because our ranges are exclusive of the end (like in Python)
368+
try:
369+
return cls(mn), cls(mx) + 1
370+
except (TypeError, ValueError) as e:
371+
raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e
354372

355373
def _validate_and_adjust_columns(self, table1, table2):
356374
for c in table1._relevant_columns:
@@ -474,12 +492,26 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
474492
if checksum1 != checksum2:
475493
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2))
476494

477-
def _thread_map(self, func, iter):
495+
def _thread_map(self, func, iterable):
496+
if not self.threaded:
497+
return map(func, iterable)
498+
499+
with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
500+
return task_pool.map(func, iterable)
501+
502+
def _threaded_call(self, func, iterable):
503+
"Calls a method for each object in iterable."
504+
return list(self._thread_map(methodcaller(func), iterable))
505+
506+
def _thread_as_completed(self, func, iterable):
478507
if not self.threaded:
479-
return map(func, iter)
508+
return map(func, iterable)
480509

481-
task_pool = ThreadPoolExecutor(max_workers=self.max_threadpool_size)
482-
return task_pool.map(func, iter)
510+
with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
511+
futures = [task_pool.submit(func, item) for item in iterable]
512+
for future in as_completed(futures):
513+
yield future.result()
483514

484-
def _threaded_call(self, func, iter):
485-
return list(self._thread_map(methodcaller(func), iter))
515+
def _threaded_call_as_completed(self, func, iterable):
516+
"Calls a method for each object in iterable. Returned in order of completion."
517+
return self._thread_as_completed(methodcaller(func), iterable)

data_diff/sql.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from runtype import dataclass
88

9+
from .utils import join_iter
10+
911
from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime, ArithUUID
1012

1113

@@ -15,6 +17,8 @@ class Sql:
1517

1618
SqlOrStr = Union[Sql, str]
1719

20+
CONCAT_SEP = "|"
21+
1822

1923
@dataclass
2024
class Compiler:
@@ -122,7 +126,8 @@ class Checksum(Sql):
122126
def compile(self, c: Compiler):
123127
if len(self.exprs) > 1:
124128
compiled_exprs = [f"coalesce({c.compile(expr)}, '<null>')" for expr in self.exprs]
125-
expr = c.database.concat(compiled_exprs)
129+
separated = list(join_iter(f"'|'", compiled_exprs))
130+
expr = c.database.concat(separated)
126131
else:
127132
# No need to coalesce - safe to assume that key cannot be null
128133
(expr,) = self.exprs

data_diff/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,11 @@ def remove_password_from_url(url: str, replace_with: str="***") -> str:
6969
netloc = _join_if_any("@", filter(None, [account, host]))
7070
replaced = parsed._replace(netloc=netloc)
7171
return replaced.geturl()
72+
73+
def join_iter(joiner: Any, iterable: iter) -> iter:
74+
it = iter(iterable)
75+
yield next(it)
76+
for i in it:
77+
yield joiner
78+
yield i
79+

tests/test_cli.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import logging
2+
import unittest
3+
import preql
4+
import arrow
5+
import subprocess
6+
import sys
7+
8+
from data_diff import diff_tables, connect_to_table
9+
10+
from .common import TEST_MYSQL_CONN_STRING
11+
12+
13+
def run_datadiff_cli(*args):
14+
try:
15+
stdout = subprocess.check_output([sys.executable, "-m", "data_diff"] + list(args), stderr=subprocess.PIPE)
16+
except subprocess.CalledProcessError as e:
17+
logging.error(e.stderr)
18+
raise
19+
return stdout.splitlines()
20+
21+
22+
class TestCLI(unittest.TestCase):
23+
def setUp(self) -> None:
24+
self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
25+
self.preql(
26+
r"""
27+
table test_cli {
28+
datetime: datetime
29+
comment: string
30+
}
31+
commit()
32+
33+
func add(date, comment) {
34+
new test_cli(date, comment)
35+
}
36+
"""
37+
)
38+
self.now = now = arrow.get(self.preql.now())
39+
self.preql.add(now, "now")
40+
self.preql.add(now, self.now.shift(seconds=-10))
41+
self.preql.add(now, self.now.shift(seconds=-7))
42+
self.preql.add(now, self.now.shift(seconds=-6))
43+
44+
self.preql(
45+
r"""
46+
const table test_cli_2 = test_cli
47+
commit()
48+
"""
49+
)
50+
51+
self.preql.add(self.now.shift(seconds=-3), "3 seconds ago")
52+
self.preql.commit()
53+
54+
def tearDown(self) -> None:
55+
self.preql.run_statement("drop table if exists test_cli")
56+
self.preql.run_statement("drop table if exists test_cli_2")
57+
self.preql.commit()
58+
self.preql.close()
59+
60+
return super().tearDown()
61+
62+
def test_basic(self):
63+
diff = run_datadiff_cli(TEST_MYSQL_CONN_STRING, "test_cli", TEST_MYSQL_CONN_STRING, "test_cli_2")
64+
assert len(diff) == 1
65+
66+
def test_options(self):
67+
diff = run_datadiff_cli(
68+
TEST_MYSQL_CONN_STRING,
69+
"test_cli",
70+
TEST_MYSQL_CONN_STRING,
71+
"test_cli_2",
72+
"--bisection-factor",
73+
"16",
74+
"--bisection-threshold",
75+
"10000",
76+
"--limit",
77+
"5",
78+
"-t",
79+
"datetime",
80+
"--max-age",
81+
"1h",
82+
)
83+
assert len(diff) == 1

tests/test_diff_tables.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def test_get_values(self):
254254
table = self.table.with_schema()
255255

256256
self.assertEqual(1, table.count())
257-
concatted = str(id_) + time
257+
concatted = str(id_) + "|" + time
258258
self.assertEqual(str_to_checksum(concatted), table.count_and_checksum()[1])
259259

260260
def test_diff_small_tables(self):
@@ -405,7 +405,7 @@ def test_string_keys(self):
405405
f"INSERT INTO {self.table_src} VALUES ('unexpected', '<-- this bad value should not break us')", None
406406
)
407407

408-
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
408+
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
409409

410410

411411
@test_per_database
@@ -592,7 +592,7 @@ def setUp(self):
592592

593593
def test_right_table_empty(self):
594594
differ = TableDiffer()
595-
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
595+
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
596596

597597
def test_left_table_empty(self):
598598
queries = [
@@ -605,4 +605,4 @@ def test_left_table_empty(self):
605605
_commit(self.connection)
606606

607607
differ = TableDiffer()
608-
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
608+
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))

0 commit comments

Comments
 (0)