Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 978d9b0

Browse files
authored
Merge pull request #614 from datafold/DX-721
log improvements and display diff status
2 parents 57c150d + 6838c46 commit 978d9b0

File tree

5 files changed

+133
-53
lines changed

5 files changed

+133
-53
lines changed

data_diff/__main__.py

+33-8
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
import json
77
import logging
88
from itertools import islice
9-
from typing import Optional
9+
from typing import Dict, Optional
1010

1111
import rich
12+
from rich.logging import RichHandler
1213
import click
1314

1415
from data_diff.sqeleton.schema import create_schema
1516
from data_diff.sqeleton.queries.api import current_timestamp
1617

1718
from .dbt import dbt_diff
18-
from .utils import eval_name_template, remove_password_from_url, safezip, match_like
19+
from .utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler
1920
from .diff_tables import Algorithm
2021
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
2122
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
@@ -27,9 +28,6 @@
2728
from .version import __version__
2829

2930

30-
LOG_FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"
31-
DATE_FORMAT = "%H:%M:%S"
32-
3331
COLOR_SCHEME = {
3432
"+": "green",
3533
"-": "red",
@@ -38,6 +36,28 @@
3836
set_entrypoint_name("CLI")
3937

4038

39+
def _get_log_handlers(is_dbt: Optional[bool] = False) -> Dict[str, logging.Handler]:
    """Create the logging handlers for the CLI, keyed by a short name.

    The returned dict always contains ``"rich_handler"`` (a ``RichHandler``
    starting at WARN level). ``"log_status_handler"`` is added only when the
    rich console reports it is attached to a terminal and ``is_dbt`` is truthy.
    """
    time_fmt = "%H:%M:%S"

    console_handler = RichHandler(rich_tracebacks=True)
    console_handler.setFormatter(logging.Formatter("%(message)s", datefmt=time_fmt))
    console_handler.setLevel(logging.WARN)
    result = {"rich_handler": console_handler}

    # The status handler is only useful for dbt runs in an interactive terminal.
    if not (console_handler.console.is_terminal and is_dbt):
        return result

    status_handler = LogStatusHandler()
    # ".100s" truncates each message to 100 characters (an arbitrary limit).
    status_handler.setFormatter(logging.Formatter("%(message).100s", datefmt=time_fmt))
    status_handler.setLevel(logging.DEBUG)
    result["log_status_handler"] = status_handler
    return result
59+
60+
4161
def _remove_passwords_in_dict(d: dict):
4262
for k, v in d.items():
4363
if k == "password":
@@ -244,6 +264,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
244264
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
245265
)
246266
def main(conf, run, **kw):
267+
log_handlers = _get_log_handlers(kw["dbt"])
247268
if kw["table2"] is None and kw["database2"]:
248269
# Use the "database table table" form
249270
kw["table2"] = kw["database2"]
@@ -263,15 +284,18 @@ def main(conf, run, **kw):
263284
kw["debug"] = True
264285

265286
if kw["debug"]:
266-
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
287+
log_handlers["rich_handler"].setLevel(logging.DEBUG)
288+
logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
267289
if kw.get("__conf__"):
268290
kw["__conf__"] = deepcopy(kw["__conf__"])
269291
_remove_passwords_in_dict(kw["__conf__"])
270292
logging.debug(f"Applied run configuration: {kw['__conf__']}")
271293
elif kw.get("verbose"):
272-
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
294+
log_handlers["rich_handler"].setLevel(logging.INFO)
295+
logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
273296
else:
274-
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
297+
log_handlers["rich_handler"].setLevel(logging.WARNING)
298+
logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
275299

276300
try:
277301
state = kw.pop("state", None)
@@ -285,6 +309,7 @@ def main(conf, run, **kw):
285309
project_dir_override = os.path.expanduser(project_dir_override)
286310
if kw["dbt"]:
287311
dbt_diff(
312+
log_status_handler=log_handlers.get("log_status_handler"),
288313
profiles_dir_override=profiles_dir_override,
289314
project_dir_override=project_dir_override,
290315
is_cloud=kw["cloud"],

data_diff/cloud/datafold_api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def poll_data_diff_results(self, diff_id: int) -> TCloudApiDataDiffSummaryResult
246246

247247
diff_url = f"{self.host}/datadiffs/{diff_id}/overview"
248248
while not summary_results:
249-
logger.debug(f"Polling: {diff_url}")
249+
logger.debug("Polling Datafold for results...")
250250
response = self.make_get_request(url=f"api/v1/datadiffs/{diff_id}/summary_results")
251251
response_json = response.json()
252252
if response_json["status"] == "success":

data_diff/dbt.py

+61-43
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import nullcontext
12
import json
23
import os
34
import re
@@ -42,6 +43,7 @@
4243
run_as_daemon,
4344
truncate_error,
4445
print_version_info,
46+
LogStatusHandler,
4547
)
4648

4749
logger = getLogger(__name__)
@@ -67,6 +69,7 @@ def dbt_diff(
6769
dbt_selection: Optional[str] = None,
6870
json_output: bool = False,
6971
state: Optional[str] = None,
72+
log_status_handler: Optional[LogStatusHandler] = None,
7073
where_flag: Optional[str] = None,
7174
) -> None:
7275
print_version_info()
@@ -89,7 +92,6 @@ def dbt_diff(
8992
if not api:
9093
return
9194
org_meta = api.get_org_meta()
92-
9395
if config.datasource_id is None:
9496
rich.print("[red]Data source ID not found in dbt_project.yml")
9597
raise DataDiffNoDatasourceIdError(
@@ -103,48 +105,54 @@ def dbt_diff(
103105
else:
104106
dbt_parser.set_connection()
105107

106-
for model in models:
107-
diff_vars = _get_diff_vars(dbt_parser, config, model, where_flag)
108-
109-
# we won't always have a prod path when using state
110-
# when the model DNE in prod manifest, skip the model diff
111-
if (
112-
state and len(diff_vars.prod_path) < 2
113-
): # < 2 because some providers like databricks can legitimately have *only* 2
114-
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
115-
diff_output_str += "[green]New model: nothing to diff![/] \n"
116-
rich.print(diff_output_str)
117-
continue
118-
119-
if diff_vars.primary_keys:
120-
if is_cloud:
121-
diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
122-
diff_threads.append(diff_thread)
123-
else:
124-
_local_diff(diff_vars, json_output)
125-
else:
126-
if json_output:
127-
print(
128-
json.dumps(
129-
jsonify_error(
130-
table1=diff_vars.prod_path,
131-
table2=diff_vars.dev_path,
132-
dbt_model=diff_vars.dbt_model,
133-
error="No primary key found. Add uniqueness tests, meta, or tags.",
134-
)
135-
),
136-
flush=True,
137-
)
108+
with log_status_handler.status if log_status_handler else nullcontext():
109+
for model in models:
110+
if log_status_handler:
111+
log_status_handler.set_prefix(f"Diffing {model.alias} \n")
112+
113+
diff_vars = _get_diff_vars(dbt_parser, config, model, where_flag)
114+
115+
# we won't always have a prod path when using state
116+
# when the model DNE in prod manifest, skip the model diff
117+
if (
118+
state and len(diff_vars.prod_path) < 2
119+
): # < 2 because some providers like databricks can legitimately have *only* 2
120+
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
121+
diff_output_str += "[green]New model: nothing to diff![/] \n"
122+
rich.print(diff_output_str)
123+
continue
124+
125+
if diff_vars.primary_keys:
126+
if is_cloud:
127+
diff_thread = run_as_daemon(
128+
_cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
129+
)
130+
diff_threads.append(diff_thread)
131+
else:
132+
_local_diff(diff_vars, json_output)
138133
else:
139-
rich.print(
140-
_diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
141-
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
142-
)
143-
144-
# wait for all threads
145-
if diff_threads:
146-
for thread in diff_threads:
147-
thread.join()
134+
if json_output:
135+
print(
136+
json.dumps(
137+
jsonify_error(
138+
table1=diff_vars.prod_path,
139+
table2=diff_vars.dev_path,
140+
dbt_model=diff_vars.dbt_model,
141+
error="No primary key found. Add uniqueness tests, meta, or tags.",
142+
)
143+
),
144+
flush=True,
145+
)
146+
else:
147+
rich.print(
148+
_diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
149+
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
150+
)
151+
152+
# wait for all threads
153+
if diff_threads:
154+
for thread in diff_threads:
155+
thread.join()
148156

149157

150158
def _get_diff_vars(
@@ -348,7 +356,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
348356
return DatafoldAPI(api_key=api_key, host=datafold_host)
349357

350358

351-
def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_meta: TCloudApiOrgMeta) -> None:
359+
def _cloud_diff(
360+
diff_vars: TDiffVars,
361+
datasource_id: int,
362+
api: DatafoldAPI,
363+
org_meta: TCloudApiOrgMeta,
364+
log_status_handler: Optional[LogStatusHandler] = None,
365+
) -> None:
366+
if log_status_handler:
367+
log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
352368
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
353369
payload = TCloudApiDataDiff(
354370
data_source1_id=datasource_id,
@@ -417,6 +433,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
417433
diff_output_str += f"\n{diff_url}\n{no_differences_template()}\n"
418434
rich.print(diff_output_str)
419435

436+
if log_status_handler:
437+
log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
420438
except BaseException as ex: # Catch KeyboardInterrupt too
421439
error = ex
422440
finally:

data_diff/utils.py

+37
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import requests
1111
from tabulate import tabulate
1212
from .version import __version__
13+
from rich.status import Status
1314

1415

1516
def safezip(*args):
@@ -211,3 +212,39 @@ def print_version_info() -> None:
211212
print(f"{base_version_string} (Update {latest_version} is available!)")
212213
else:
213214
print(base_version_string)
215+
216+
217+
class LogStatusHandler(logging.Handler):
    """
    Log handler that updates a rich.status every time a log record is emitted.

    Before any cloud diff is registered, the status shows ``prefix`` followed by
    the formatted log message. Once ``cloud_diff_started`` has been called, the
    status shows one "<label> <model name>" line per registered model, followed
    by the most recent log message (if any).

    The cloud-diff methods may be called from worker threads while other threads
    are logging, so every read and write of ``cloud_diff_status`` is done under
    the handler's own lock (the same re-entrant lock ``logging`` holds while it
    calls ``emit``).
    """

    def __init__(self):
        super().__init__()
        self.status = Status("")
        self.prefix = ""
        # model name -> rich-markup status label; empty until a cloud diff starts
        self.cloud_diff_status = {}

    def emit(self, record):
        """Render ``record`` into the status line."""
        try:
            log_entry = self.format(record)
            if self.cloud_diff_status:
                self._update_cloud_status(log_entry)
            else:
                self.status.update(self.prefix + log_entry)
        except Exception:
            # Standard logging convention: a failing handler reports through
            # handleError() instead of raising into the code that logged.
            self.handleError(record)

    def set_prefix(self, prefix_string):
        """Set the text shown before each log message (non-cloud mode only)."""
        self.prefix = prefix_string

    def cloud_diff_started(self, model_name):
        """Mark ``model_name`` as in progress and refresh the status."""
        self._set_cloud_status(model_name, "[yellow]In Progress[/]")

    def cloud_diff_finished(self, model_name):
        """Mark ``model_name`` as finished and refresh the status."""
        self._set_cloud_status(model_name, "[green]Finished [/]")

    def _set_cloud_status(self, model_name, label):
        """Record ``label`` for ``model_name`` and redraw, atomically."""
        self.acquire()
        try:
            self.cloud_diff_status[model_name] = label
            self._update_cloud_status()
        finally:
            self.release()

    def _update_cloud_status(self, log=None):
        """Redraw the per-model status lines, optionally followed by ``log``."""
        # The handler lock is re-entrant, so this is safe whether or not the
        # caller (emit via Handler.handle, or _set_cloud_status) already holds it.
        self.acquire()
        try:
            model_lines = "".join(
                f"{status} {model_name}\n" for model_name, status in self.cloud_diff_status.items()
            )
            self.status.update(f"\n{model_lines}{log or ''}")
        finally:
            self.release()

tests/test_dbt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_diff_is_cloud(
274274

275275
mock_initialize_api.assert_called_once()
276276
mock_api.get_data_source.assert_called_once_with(1)
277-
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta)
277+
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta, None)
278278
mock_local_diff.assert_not_called()
279279
mock_print.assert_called_once()
280280

0 commit comments

Comments
 (0)