Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit ce52e8b

Browse files
authored
Merge pull request #511 from dave-connors-3/allow-dbt-selectors
Allow dbt selectors
2 parents 6cf709b + d10bf39 commit ce52e8b

File tree

6 files changed

+212
-88
lines changed

6 files changed

+212
-88
lines changed

data_diff/__main__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
228228
metavar="PATH",
229229
help="Which directory to look in for the dbt_project.yml file. Default is the current working directory and its parents.",
230230
)
231+
@click.option(
232+
"--select",
233+
"-s",
234+
default=None,
235+
metavar="PATH",
236+
help="select dbt resources to compare using dbt selection syntax",
237+
)
231238
def main(conf, run, **kw):
232239
if kw["table2"] is None and kw["database2"]:
233240
# Use the "database table table" form
@@ -264,6 +271,7 @@ def main(conf, run, **kw):
264271
profiles_dir_override=kw["dbt_profiles_dir"],
265272
project_dir_override=kw["dbt_project_dir"],
266273
is_cloud=kw["cloud"],
274+
dbt_selection=kw["select"],
267275
)
268276
else:
269277
return _data_diff(**kw)
@@ -306,6 +314,7 @@ def _data_diff(
306314
cloud,
307315
dbt_profiles_dir,
308316
dbt_project_dir,
317+
select,
309318
threads1=None,
310319
threads2=None,
311320
__conf__=None,

data_diff/dbt.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,6 @@
1818
logger = getLogger(__name__)
1919

2020

21-
def import_dbt():
22-
try:
23-
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
24-
from dbt.config.renderer import ProfileRenderer
25-
import yaml
26-
except ImportError:
27-
raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")
28-
29-
return parse_run_results, parse_manifest, ProfileRenderer, yaml
30-
31-
3221
from .tracking import (
3322
set_entrypoint_name,
3423
set_dbt_user_id,
@@ -54,12 +43,15 @@ class DiffVars:
5443

5544

5645
def dbt_diff(
57-
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
46+
profiles_dir_override: Optional[str] = None,
47+
project_dir_override: Optional[str] = None,
48+
is_cloud: bool = False,
49+
dbt_selection: Optional[str] = None,
5850
) -> None:
5951
diff_threads = []
6052
set_entrypoint_name("CLI-dbt")
6153
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
62-
models = dbt_parser.get_models()
54+
models = dbt_parser.get_models(dbt_selection)
6355
datadiff_variables = dbt_parser.get_datadiff_variables()
6456
config_prod_database = datadiff_variables.get("prod_database")
6557
config_prod_schema = datadiff_variables.get("prod_schema")

data_diff/dbt_parser.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections import defaultdict
22
import json
3+
import os
34
from pathlib import Path
4-
from typing import List, Dict, Tuple, Set
5+
from typing import List, Dict, Tuple, Set, Optional
56

67
from packaging.version import parse as parse_version
78

@@ -12,23 +13,34 @@
1213
logger = getLogger(__name__)
1314

1415

15-
def import_dbt():
16+
def import_dbt_dependencies():
    """Import the optional dbt-related packages on demand.

    Returns:
        A 5-tuple ``(parse_run_results, parse_manifest, ProfileRenderer, yaml,
        dbt_runner)``. ``dbt_runner`` is a ``dbtRunner`` instance when
        ``dbt.cli.main.dbtRunner`` can be imported, and ``None`` otherwise.

    Raises:
        RuntimeError: if any of the required packages cannot be imported.
    """
    try:
        from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
        from dbt.config.renderer import ProfileRenderer
        import yaml
    except ImportError:
        raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")

    # dbtRunner only exists in dbt 1.5+; it powers model selection, so its
    # absence is tolerated here and reported later by the caller that needs it.
    try:
        from dbt.cli.main import dbtRunner as runner_factory
    except ImportError:
        runner_factory = None

    dbt_runner = runner_factory() if runner_factory is not None else None

    dependencies = (parse_run_results, parse_manifest, ProfileRenderer, yaml, dbt_runner)
    return dependencies
2436

2537

2638
RUN_RESULTS_PATH = "target/run_results.json"
2739
MANIFEST_PATH = "target/manifest.json"
2840
PROJECT_FILE = "dbt_project.yml"
2941
PROFILES_FILE = "profiles.yml"
3042
LOWER_DBT_V = "1.0.0"
31-
UPPER_DBT_V = "1.4.7"
43+
UPPER_DBT_V = "1.6.0"
3244

3345

3446
# https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L6
@@ -49,7 +61,13 @@ def legacy_profiles_dir() -> Path:
4961

5062
class DbtParser:
5163
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
52-
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
64+
(
65+
self.parse_run_results,
66+
self.parse_manifest,
67+
self.ProfileRenderer,
68+
self.yaml,
69+
self.dbt_runner,
70+
) = import_dbt_dependencies()
5371
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
5472
self.project_dir = Path(project_dir_override or default_project_dir())
5573
self.connection = None
@@ -68,7 +86,60 @@ def get_datadiff_variables(self) -> dict:
6886
vars = get_from_dict_with_raise(self.project_dict, "vars", error_message)
6987
return get_from_dict_with_raise(vars, "data_diff", error_message)
7088

71-
def get_models(self):
89+
def get_models(self, dbt_selection: Optional[str] = None):
    """Return the dbt models to diff.

    Args:
        dbt_selection: dbt selection string from ``--select``. When falsy,
            the models are taken from ``get_run_results_models()`` instead.

    Returns:
        Whatever ``get_dbt_selection_models`` (selection given) or
        ``get_run_results_models`` (no selection) returns.

    Raises:
        Exception: if a selection is given but the manifest's dbt version is
            older than 1.5, or no dbt runner is available in this environment.
    """
    # Parsed up front, before any branching, so an unparsable version
    # surfaces regardless of whether a selection was supplied.
    manifest_version = parse_version(self.dbt_version)

    if not dbt_selection:
        return self.get_run_results_models()

    if (manifest_version.major, manifest_version.minor) < (1, 5):
        raise Exception(
            f"Use of the `--select` feature requires dbt >= 1.5. Found dbt manifest: v{manifest_version}"
        )

    # edge case if running data-diff from a separate env than dbt (likely local development)
    if not self.dbt_runner:
        raise Exception(
            "data-diff is using a dbt-core version < 1.5, update the environment's dbt-core version via pip install 'dbt-core>=1.5' in order to use `--select`"
        )

    return self.get_dbt_selection_models(dbt_selection)
106+
107+
def get_dbt_selection_models(self, dbt_selection: str) -> List:
    """Resolve a dbt selection string to manifest nodes via ``dbt ls``.

    Args:
        dbt_selection: value forwarded to ``dbt ls --select``.

    Returns:
        One entry per selected model, looked up in ``self.manifest_obj.nodes``
        by ``unique_id`` (an entry is ``None`` if the id is not in the manifest).

    Raises:
        BaseException: the exception reported by the dbt runner, re-raised
            unchanged when the invocation failed with one.
        Exception: if the selection matched no models, or the invocation
            failed without reporting an exception.
    """
    # log level and format settings needed to prevent dbt from printing to stdout
    # ls command is used to get the list of model unique_ids
    cli_args = [
        "--log-format",
        "json",
        "--log-level",
        "none",
        "ls",
        "--select",
        dbt_selection,
        "--resource-type",
        "model",
        "--output",
        "json",
        "--output-keys",
        "unique_id",
        "--project-dir",
        # CLI arguments are strings; project_dir is a pathlib.Path.
        str(self.project_dir),
    ]
    results = self.dbt_runner.invoke(cli_args)

    # A failed invocation has no result, so this must be checked before the
    # "no models" case; otherwise the real error is masked by a misleading
    # "No dbt models found" message.
    if not results.success and results.exception:
        raise results.exception

    if not results.result:
        raise Exception(f"No dbt models found for `--select {dbt_selection}`")

    if not results.success:
        logger.debug(str(results))
        raise Exception("Encountered an error while finding `--select` models")

    # Each result entry is a JSON document containing the requested output key.
    unique_ids = [json.loads(entry)["unique_id"] for entry in results.result]
    return [self.manifest_obj.nodes.get(unique_id) for unique_id in unique_ids]
141+
142+
def get_run_results_models(self):
72143
with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
73144
logger.info(f"Parsing file {RUN_RESULTS_PATH}")
74145
run_results_dict = json.load(run_results)
@@ -80,11 +151,11 @@ def get_models(self):
80151
self.profiles_dir = legacy_profiles_dir()
81152

82153
if dbt_version < parse_version(LOWER_DBT_V):
83-
raise Exception(
84-
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V}"
85-
)
154+
raise Exception(f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V}")
86155
elif dbt_version >= parse_version(UPPER_DBT_V):
87-
logger.warning(f"{dbt_version} is a recent version of dbt and may not be fully tested with data-diff! \nPlease report any issues to https://github.com/datafold/data-diff/issues")
156+
logger.warning(
157+
f"{dbt_version} is a recent version of dbt and may not be fully tested with data-diff! \nPlease report any issues to https://github.com/datafold/data-diff/issues"
158+
)
88159

89160
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
90161
models = [self.manifest_obj.nodes.get(x) for x in success_models]

0 commit comments

Comments
 (0)