Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 55aa54c

Browse files
committed
Config initial
1 parent 92f5ed4 commit 55aa54c

File tree

4 files changed

+151
-38
lines changed

4 files changed

+151
-38
lines changed

data_diff/__main__.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
DEFAULT_BISECTION_THRESHOLD,
1111
DEFAULT_BISECTION_FACTOR,
1212
)
13-
from .databases.connect import connect_to_uri
13+
from .databases.connect import connect
1414
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
15+
from .config import apply_config
1516

1617
import rich
1718
import click
@@ -26,19 +27,19 @@
2627

2728

2829
@click.command()
29-
@click.argument("db1_uri")
30-
@click.argument("table1_name")
31-
@click.argument("db2_uri")
32-
@click.argument("table2_name")
33-
@click.option("-k", "--key-column", default="id", help="Name of primary key column")
30+
@click.argument("database1", required=False)
31+
@click.argument("table1", required=False)
32+
@click.argument("database2", required=False)
33+
@click.argument("table2", required=False)
34+
@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.")
3435
@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column")
3536
@click.option("-c", "--columns", default=[], multiple=True, help="Names of extra columns to compare")
3637
@click.option("-l", "--limit", default=None, help="Maximum number of differences to find")
37-
@click.option("--bisection-factor", default=DEFAULT_BISECTION_FACTOR, help="Segments per iteration")
38+
@click.option("--bisection-factor", default=None, help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.")
3839
@click.option(
3940
"--bisection-threshold",
40-
default=DEFAULT_BISECTION_THRESHOLD,
41-
help="Minimal bisection threshold. Below it, data-diff will download the data and compare it locally.",
41+
default=None,
42+
help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.",
4243
)
4344
@click.option(
4445
"--min-age",
@@ -57,16 +58,23 @@
5758
@click.option(
5859
"-j",
5960
"--threads",
60-
default="1",
61+
default=None,
6162
help="Number of worker threads to use per database. Default=1. "
6263
"A higher number will increase performance, but take more capacity from your database. "
6364
"'serial' guarantees a single-threaded execution of the algorithm (useful for debugging).",
6465
)
65-
def main(
66-
db1_uri,
67-
table1_name,
68-
db2_uri,
69-
table2_name,
66+
@click.option("--conf", default=None, help="Path to a configuration.toml file, to provide a default configuration, and a list of possible runs.")
67+
@click.option("--run", default=None, help="Name of run-configuration to run. If used, CLI arguments for database and table must be omitted.")
68+
def main(conf, run, **kw):
69+
if conf:
70+
kw = apply_config(conf, run, kw)
71+
return _main(**kw)
72+
73+
def _main(
74+
database1,
75+
table1,
76+
database2,
77+
table2,
7078
key_column,
7179
update_column,
7280
columns,
@@ -82,35 +90,50 @@ def main(
8290
threads,
8391
keep_column_case,
8492
json_output,
93+
threads1=None,
94+
threads2=None,
95+
__conf__=None,
8596
):
86-
if limit and stats:
87-
print("Error: cannot specify a limit when using the -s/--stats switch")
88-
return
97+
8998
if interactive:
9099
debug = True
91100

92101
if debug:
93102
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
103+
if __conf__:
104+
logging.debug(f"Applied run configuration: {__conf__}")
94105
elif verbose:
95106
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
96107

108+
if limit and stats:
109+
logging.error("Cannot specify a limit when using the -s/--stats switch")
110+
return
111+
112+
key_column = key_column or 'id'
113+
if bisection_factor is None:
114+
bisection_factor = DEFAULT_BISECTION_FACTOR
115+
if bisection_threshold is None:
116+
bisection_threshold = DEFAULT_BISECTION_THRESHOLD
117+
97118
threaded = True
98-
if threads is not None:
99-
if threads.lower() == "serial":
100-
threaded = False
101-
threads = 1
102-
else:
103-
try:
104-
threads = int(threads)
105-
except ValueError:
106-
logging.error("Error: threads must be a number, 'auto', or 'serial'.")
107-
return
108-
if threads < 1:
109-
logging.error("Error: threads must be >= 1")
110-
return
111-
112-
db1 = connect_to_uri(db1_uri, threads)
113-
db2 = connect_to_uri(db2_uri, threads)
119+
if threads is None:
120+
threads = 1
121+
elif isinstance(threads, str) and threads.lower() == "serial":
122+
assert not (threads1 or threads2)
123+
threaded = False
124+
threads = 1
125+
else:
126+
try:
127+
threads = int(threads)
128+
except ValueError:
129+
logging.error("Error: threads must be a number, or 'serial'.")
130+
return
131+
if threads < 1:
132+
logging.error("Error: threads must be >= 1")
133+
return
134+
135+
db1 = connect(database1, threads1 or threads)
136+
db2 = connect(database2, threads2 or threads)
114137

115138
if interactive:
116139
db1.enable_interactive()
@@ -128,8 +151,8 @@ def main(
128151
logging.error("Error while parsing age expression: %s" % e)
129152
return
130153

131-
table1 = TableSegment(db1, db1.parse_table_name(table1_name), key_column, update_column, columns, **options)
132-
table2 = TableSegment(db2, db2.parse_table_name(table2_name), key_column, update_column, columns, **options)
154+
table1_seg = TableSegment(db1, db1.parse_table_name(table1), key_column, update_column, columns, **options)
155+
table2_seg = TableSegment(db2, db2.parse_table_name(table2), key_column, update_column, columns, **options)
133156

134157
differ = TableDiffer(
135158
bisection_factor=bisection_factor,
@@ -138,7 +161,7 @@ def main(
138161
max_threadpool_size=threads and threads * 2,
139162
debug=debug,
140163
)
141-
diff_iter = differ.diff_tables(table1, table2)
164+
diff_iter = differ.diff_tables(table1_seg, table2_seg)
142165

143166
if limit:
144167
diff_iter = islice(diff_iter, int(limit))

data_diff/config.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import toml
2+
import logging
3+
4+
class ConfigParseError(Exception):
5+
pass
6+
7+
8+
def is_uri(s: str) -> bool:
9+
return '://' in s
10+
11+
12+
13+
def _load_config(path):
14+
with open(path) as f:
15+
return toml.load(f)
16+
17+
def apply_config(path, run_name, kw):
18+
# Load config
19+
config = _load_config(path)
20+
databases = config.pop('database', {})
21+
runs = config.pop('run', {})
22+
if config:
23+
raise ConfigParseError(f"Unknown option(s): {config}")
24+
25+
# Init run_args
26+
run_args = runs.get('default') or {}
27+
if run_name:
28+
if run_name not in runs:
29+
raise ConfigParseError(f"Cannot find run '{run_name}' in configuration '{path}'.")
30+
run_args.update(runs[run_name])
31+
else:
32+
run_name = 'default'
33+
34+
# Process databases + tables
35+
for index in '12':
36+
args = run_args.pop(index, {})
37+
for attr in ('database', 'table'):
38+
if attr not in args:
39+
raise ConfigParseError(f"Running 'run.{run_name}': Connection #{index} in missing attribute '{attr}'.")
40+
41+
database = args.pop('database')
42+
table = args.pop('table')
43+
threads = args.pop('threads', None)
44+
if args:
45+
raise ConfigParseError(f"Unexpected attributes for connection #{index}: {args}")
46+
47+
if not is_uri(database):
48+
if database not in databases:
49+
raise ConfigParseError(f"Database '{database}' not found in list of databases. Available: {list(databases)}.")
50+
database = dict(databases[database])
51+
assert isinstance(database, dict)
52+
if 'driver' not in database:
53+
raise ConfigParseError(f"Database '{database}' did not specify a driver.")
54+
55+
run_args[f'database{index}'] = database
56+
run_args[f'table{index}'] = table
57+
if threads is not None:
58+
run_args[f'threads{index}'] = int(threads)
59+
60+
# Update keywords
61+
new_kw = dict(kw) # Set defaults
62+
new_kw.update(run_args) # Apply config
63+
new_kw.update({k:v for k, v in kw.items() if v}) # Apply non-empty defaults
64+
65+
new_kw['__conf__'] = run_args
66+
67+
return new_kw

data_diff/databases/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def __init__(self, thread_count=1):
249249
self._init_error = None
250250
self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn)
251251
self.thread_local = threading.local()
252+
logger.info(f"[{self.name}] Starting a threadpool, size={thread_count}.")
252253

253254
def set_conn(self):
254255
assert not hasattr(self.thread_local, "conn")

data_diff/databases/connect.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,32 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
128128
kw["host"] = dsn.host
129129
kw["port"] = dsn.port
130130
kw["user"] = dsn.user
131-
kw["password"] = dsn.password
131+
if dsn.password:
132+
kw["password"] = dsn.password
132133
kw = {k: v for k, v in kw.items() if v is not None}
133134

134135
if issubclass(cls, ThreadedDatabase):
135136
return cls(thread_count=thread_count, **kw)
136137

137138
return cls(**kw)
139+
140+
def connect_with_dict(d, thread_count):
141+
d = dict(d)
142+
driver = d.pop('driver')
143+
try:
144+
matcher = MATCH_URI_PATH[driver]
145+
except KeyError:
146+
raise NotImplementedError(f"Driver {driver} currently not supported")
147+
148+
cls = matcher.database_cls
149+
if issubclass(cls, ThreadedDatabase):
150+
return cls(thread_count=thread_count, **d)
151+
152+
return cls(**d)
153+
154+
def connect(x, thread_count):
155+
if isinstance(x, str):
156+
return connect_to_uri(x, thread_count)
157+
elif isinstance(x, dict):
158+
return connect_with_dict(x, thread_count)
159+
raise RuntimeError(x)

0 commit comments

Comments
 (0)