This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 8f07c1c

squash add simple status
1 parent 40a785d commit 8f07c1c

File tree

5 files changed: +133 −53 lines changed

data_diff/__main__.py

+33 −8
@@ -6,16 +6,17 @@
 import json
 import logging
 from itertools import islice
-from typing import Optional
+from typing import Dict, Optional

 import rich
+from rich.logging import RichHandler
 import click

 from data_diff.sqeleton.schema import create_schema
 from data_diff.sqeleton.queries.api import current_timestamp

 from .dbt import dbt_diff
-from .utils import eval_name_template, remove_password_from_url, safezip, match_like
+from .utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler
 from .diff_tables import Algorithm
 from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
 from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
@@ -27,9 +28,6 @@
 from .version import __version__


-LOG_FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"
-DATE_FORMAT = "%H:%M:%S"
-
 COLOR_SCHEME = {
     "+": "green",
     "-": "red",
@@ -38,6 +36,28 @@
 set_entrypoint_name("CLI")


+def _get_log_handlers(is_dbt: Optional[bool] = False) -> Dict[str, logging.Handler]:
+    handlers = {}
+    date_format = "%H:%M:%S"
+    log_format_rich = "%(message)s"
+
+    # limits to 100 characters arbitrarily
+    log_format_status = "%(message).100s"
+    rich_handler = RichHandler(rich_tracebacks=True)
+    rich_handler.setFormatter(logging.Formatter(log_format_rich, datefmt=date_format))
+    rich_handler.setLevel(logging.WARN)
+    handlers["rich_handler"] = rich_handler
+
+    # only use log_status_handler in a terminal
+    if rich_handler.console.is_terminal and is_dbt:
+        log_status_handler = LogStatusHandler()
+        log_status_handler.setFormatter(logging.Formatter(log_format_status, datefmt=date_format))
+        log_status_handler.setLevel(logging.DEBUG)
+        handlers["log_status_handler"] = log_status_handler
+
+    return handlers
+
+
 def _remove_passwords_in_dict(d: dict):
     for k, v in d.items():
         if k == "password":
@@ -244,6 +264,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
     help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
 )
 def main(conf, run, **kw):
+    log_handlers = _get_log_handlers(kw["dbt"])
     if kw["table2"] is None and kw["database2"]:
         # Use the "database table table" form
         kw["table2"] = kw["database2"]
@@ -263,15 +284,18 @@ def main(conf, run, **kw):
         kw["debug"] = True

     if kw["debug"]:
-        logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
+        log_handlers["rich_handler"].setLevel(logging.DEBUG)
+        logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
         if kw.get("__conf__"):
             kw["__conf__"] = deepcopy(kw["__conf__"])
             _remove_passwords_in_dict(kw["__conf__"])
             logging.debug(f"Applied run configuration: {kw['__conf__']}")
     elif kw.get("verbose"):
-        logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
+        log_handlers["rich_handler"].setLevel(logging.INFO)
+        logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
     else:
-        logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
+        log_handlers["rich_handler"].setLevel(logging.WARNING)
+        logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))

     try:
         state = kw.pop("state", None)
@@ -285,6 +309,7 @@ def main(conf, run, **kw):
             project_dir_override = os.path.expanduser(project_dir_override)
         if kw["dbt"]:
             dbt_diff(
+                log_status_handler=log_handlers.get("log_status_handler"),
                 profiles_dir_override=profiles_dir_override,
                 project_dir_override=project_dir_override,
                 is_cloud=kw["cloud"],
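For reference, the wiring that `_get_log_handlers` and the `logging.basicConfig(..., handlers=...)` calls above rely on is standard Python logging: the root logger is opened to DEBUG and each handler then filters with its own level, so the RichHandler controls terminal verbosity while another handler can still receive DEBUG records. A minimal sketch of that pattern, not part of the commit (`debug_handler` is an illustrative stand-in):

```python
import logging

from rich.logging import RichHandler

# Terminal handler stays quiet unless warnings or errors occur.
rich_handler = RichHandler(rich_tracebacks=True)
rich_handler.setLevel(logging.WARNING)

# Hypothetical second handler that wants to see every record.
debug_handler = logging.StreamHandler()
debug_handler.setLevel(logging.DEBUG)

# Root logger is opened to DEBUG; each handler applies its own threshold.
logging.basicConfig(level=logging.DEBUG, handlers=[rich_handler, debug_handler])

logging.getLogger(__name__).debug("reaches debug_handler, filtered out by rich_handler")
```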

data_diff/cloud/datafold_api.py

+1 −1
@@ -246,7 +246,7 @@ def poll_data_diff_results(self, diff_id: int) -> TCloudApiDataDiffSummaryResult

         diff_url = f"{self.host}/datadiffs/{diff_id}/overview"
         while not summary_results:
-            logger.debug(f"Polling: {diff_url}")
+            logger.debug("Polling Datafold for results...")
             response = self.make_get_request(url=f"api/v1/datadiffs/{diff_id}/summary_results")
             response_json = response.json()
             if response_json["status"] == "success":

data_diff/dbt.py

+61 −43
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 import json
 import os
 import re
@@ -42,6 +43,7 @@
     run_as_daemon,
     truncate_error,
     print_version_info,
+    LogStatusHandler,
 )

 logger = getLogger(__name__)
@@ -67,6 +69,7 @@ def dbt_diff(
     dbt_selection: Optional[str] = None,
     json_output: bool = False,
     state: Optional[str] = None,
+    log_status_handler: Optional[LogStatusHandler] = None,
 ) -> None:
     print_version_info()
     diff_threads = []
@@ -88,7 +91,6 @@
     if not api:
         return
     org_meta = api.get_org_meta()
-
     if config.datasource_id is None:
         rich.print("[red]Data source ID not found in dbt_project.yml")
         raise DataDiffNoDatasourceIdError(
@@ -102,48 +104,54 @@
     else:
         dbt_parser.set_connection()

-    for model in models:
-        diff_vars = _get_diff_vars(dbt_parser, config, model)
-
-        # we won't always have a prod path when using state
-        # when the model DNE in prod manifest, skip the model diff
-        if (
-            state and len(diff_vars.prod_path) < 2
-        ):  # < 2 because some providers like databricks can legitimately have *only* 2
-            diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
-            diff_output_str += "[green]New model: nothing to diff![/] \n"
-            rich.print(diff_output_str)
-            continue
-
-        if diff_vars.primary_keys:
-            if is_cloud:
-                diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
-                diff_threads.append(diff_thread)
-            else:
-                _local_diff(diff_vars, json_output)
-        else:
-            if json_output:
-                print(
-                    json.dumps(
-                        jsonify_error(
-                            table1=diff_vars.prod_path,
-                            table2=diff_vars.dev_path,
-                            dbt_model=diff_vars.dbt_model,
-                            error="No primary key found. Add uniqueness tests, meta, or tags.",
-                        )
-                    ),
-                    flush=True,
-                )
+    with log_status_handler.status if log_status_handler else nullcontext():
+        for model in models:
+            if log_status_handler:
+                log_status_handler.set_prefix(f"Diffing {model.alias} \n")
+
+            diff_vars = _get_diff_vars(dbt_parser, config, model)
+
+            # we won't always have a prod path when using state
+            # when the model DNE in prod manifest, skip the model diff
+            if (
+                state and len(diff_vars.prod_path) < 2
+            ):  # < 2 because some providers like databricks can legitimately have *only* 2
+                diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
+                diff_output_str += "[green]New model: nothing to diff![/] \n"
+                rich.print(diff_output_str)
+                continue
+
+            if diff_vars.primary_keys:
+                if is_cloud:
+                    diff_thread = run_as_daemon(
+                        _cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
+                    )
+                    diff_threads.append(diff_thread)
+                else:
+                    _local_diff(diff_vars, json_output)
             else:
-                rich.print(
-                    _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
-                    + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
-                )
-
-    # wait for all threads
-    if diff_threads:
-        for thread in diff_threads:
-            thread.join()
+                if json_output:
+                    print(
+                        json.dumps(
+                            jsonify_error(
+                                table1=diff_vars.prod_path,
+                                table2=diff_vars.dev_path,
+                                dbt_model=diff_vars.dbt_model,
+                                error="No primary key found. Add uniqueness tests, meta, or tags.",
+                            )
+                        ),
+                        flush=True,
+                    )
+                else:
+                    rich.print(
+                        _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
+                        + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
+                    )

+    # wait for all threads
+    if diff_threads:
+        for thread in diff_threads:
+            thread.join()


 def _get_diff_vars(
@@ -345,7 +353,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
     return DatafoldAPI(api_key=api_key, host=datafold_host)


-def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_meta: TCloudApiOrgMeta) -> None:
+def _cloud_diff(
+    diff_vars: TDiffVars,
+    datasource_id: int,
+    api: DatafoldAPI,
+    org_meta: TCloudApiOrgMeta,
+    log_status_handler: Optional[LogStatusHandler] = None,
+) -> None:
+    if log_status_handler:
+        log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
     diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
     payload = TCloudApiDataDiff(
         data_source1_id=datasource_id,
@@ -414,6 +430,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
             diff_output_str += f"\n{diff_url}\n{no_differences_template()}\n"
             rich.print(diff_output_str)

+        if log_status_handler:
+            log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
     except BaseException as ex:  # Catch KeyboardInterrupt too
         error = ex
     finally:
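The `with log_status_handler.status if log_status_handler else nullcontext():` line above leans on two things: rich's `Status` is a context manager (entering it starts the live spinner, exiting stops it), and `contextlib.nullcontext()` is a no-op substitute when no handler was built. A minimal sketch of the pattern, not from this commit (`run_models` and its arguments are illustrative):

```python
from contextlib import nullcontext
from typing import List, Optional

from rich.status import Status


def run_models(models: List[str], status: Optional[Status] = None) -> None:
    # Enter the live status display only when one was provided;
    # nullcontext() is a no-op stand-in otherwise.
    with status if status is not None else nullcontext():
        for model in models:
            if status:
                status.update(f"Diffing {model}")
            # ... per-model work would happen here ...


if __name__ == "__main__":
    run_models(["orders", "customers"], status=Status("starting"))
```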

data_diff/utils.py

+37
@@ -10,6 +10,7 @@
 import requests
 from tabulate import tabulate
 from .version import __version__
+from rich.status import Status


 def safezip(*args):
@@ -211,3 +212,39 @@ def print_version_info() -> None:
         print(f"{base_version_string} (Update {latest_version} is available!)")
     else:
         print(base_version_string)
+
+
+class LogStatusHandler(logging.Handler):
+    """
+    This log handler can be used to update a rich.status every time a log is emitted.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.status = Status("")
+        self.prefix = ""
+        self.cloud_diff_status = {}
+
+    def emit(self, record):
+        log_entry = self.format(record)
+        if self.cloud_diff_status:
+            self._update_cloud_status(log_entry)
+        else:
+            self.status.update(self.prefix + log_entry)
+
+    def set_prefix(self, prefix_string):
+        self.prefix = prefix_string
+
+    def cloud_diff_started(self, model_name):
+        self.cloud_diff_status[model_name] = "[yellow]In Progress[/]"
+        self._update_cloud_status()
+
+    def cloud_diff_finished(self, model_name):
+        self.cloud_diff_status[model_name] = "[green]Finished [/]"
+        self._update_cloud_status()
+
+    def _update_cloud_status(self, log=None):
+        cloud_status_string = "\n"
+        for model_name, status in self.cloud_diff_status.items():
+            cloud_status_string += f"{status} {model_name}\n"
+        self.status.update(f"{cloud_status_string}{log or ''}")
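For orientation, a minimal usage sketch of the handler added above, not part of the commit (the logger name, model name, and messages are illustrative): attach the handler to a logger, enter its `status`, and each emitted record then replaces the single live line instead of scrolling.

```python
import logging
import time

from data_diff.utils import LogStatusHandler  # the class added in this commit

handler = LogStatusHandler()
handler.setFormatter(logging.Formatter("%(message).100s"))

logger = logging.getLogger("status_demo")  # demo logger name is illustrative
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

with handler.status:  # start the live status line
    handler.set_prefix("Diffing my_model \n")
    logger.debug("comparing row counts")   # rendered as the single status line
    time.sleep(1)
    logger.debug("comparing checksums")    # replaces the previous line in place
```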

tests/test_dbt.py

+1 −1
@@ -274,7 +274,7 @@ def test_diff_is_cloud(

     mock_initialize_api.assert_called_once()
     mock_api.get_data_source.assert_called_once_with(1)
-    mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta)
+    mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta, None)
     mock_local_diff.assert_not_called()
     mock_print.assert_called_once()
