Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 1305871

Browse files
authored
Merge branch 'master' into DX-810
2 parents 5614725 + 0d624fa commit 1305871

File tree

12 files changed

+636
-342
lines changed

12 files changed

+636
-342
lines changed

data_diff/__main__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,14 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
234234
"-s",
235235
default=None,
236236
metavar="PATH",
237-
help="select dbt resources to compare using dbt selection syntax",
237+
help="select dbt resources to compare using dbt selection syntax.",
238+
)
239+
@click.option(
240+
"--state",
241+
"-s",
242+
default=None,
243+
metavar="PATH",
244+
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
238245
)
239246
def main(conf, run, **kw):
240247
if kw["table2"] is None and kw["database2"]:
@@ -267,6 +274,9 @@ def main(conf, run, **kw):
267274
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
268275

269276
try:
277+
state = kw.pop("state", None)
278+
if state:
279+
state = os.path.expanduser(state)
270280
profiles_dir_override = kw.pop("dbt_profiles_dir", None)
271281
if profiles_dir_override:
272282
profiles_dir_override = os.path.expanduser(profiles_dir_override)
@@ -279,11 +289,12 @@ def main(conf, run, **kw):
279289
project_dir_override=project_dir_override,
280290
is_cloud=kw["cloud"],
281291
dbt_selection=kw["select"],
292+
state=state,
282293
)
283294
else:
284-
return _data_diff(dbt_project_dir=project_dir_override,
285-
dbt_profiles_dir=profiles_dir_override,
286-
**kw)
295+
return _data_diff(
296+
dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw
297+
)
287298
except Exception as e:
288299
logging.error(e)
289300
if kw["debug"]:
@@ -324,6 +335,7 @@ def _data_diff(
324335
dbt_profiles_dir,
325336
dbt_project_dir,
326337
select,
338+
state,
327339
threads1=None,
328340
threads2=None,
329341
__conf__=None,

data_diff/cloud/data_source.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def _validate_temp_schema(temp_schema: str):
5252

5353

5454
def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
55-
diff_vars = dbt_parser.get_datadiff_variables()
56-
config_prod_database = diff_vars.get("prod_database")
57-
config_prod_schema = diff_vars.get("prod_schema")
55+
config = dbt_parser.get_datadiff_config()
56+
config_prod_database = config.prod_database
57+
config_prod_schema = config.prod_schema
5858
if config_prod_database is not None and config_prod_schema is not None:
5959
temp_schema = f"{config_prod_database}.{config_prod_schema}"
6060
if db_type == "snowflake":

data_diff/dbt.py

Lines changed: 109 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
import os
2+
import re
23
import time
34
import webbrowser
4-
from typing import List, Optional, Dict
5+
from typing import List, Optional, Dict, Tuple, Union
56
import keyring
6-
77
import pydantic
88
import rich
9-
from rich.prompt import Confirm
9+
from rich.prompt import Confirm, Prompt
10+
11+
from data_diff.errors import DataDiffCustomSchemaNoConfigError, DataDiffDbtProjectVarsNotFoundError
1012

1113
from . import connect_to_table, diff_tables, Algorithm
1214
from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source
13-
from .dbt_parser import DbtParser, PROJECT_FILE
15+
from .dbt_parser import DbtParser, PROJECT_FILE, TDatadiffConfig
1416
from .tracking import (
17+
bool_ask_for_email,
18+
create_email_signup_event_json,
1519
set_entrypoint_name,
1620
set_dbt_user_id,
1721
set_dbt_version,
@@ -52,24 +56,21 @@ def dbt_diff(
5256
project_dir_override: Optional[str] = None,
5357
is_cloud: bool = False,
5458
dbt_selection: Optional[str] = None,
59+
state: Optional[str] = None,
5560
) -> None:
5661
print_version_info()
5762
diff_threads = []
5863
set_entrypoint_name("CLI-dbt")
59-
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
64+
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
6065
models = dbt_parser.get_models(dbt_selection)
61-
datadiff_variables = dbt_parser.get_datadiff_variables()
62-
config_prod_database = datadiff_variables.get("prod_database")
63-
config_prod_schema = datadiff_variables.get("prod_schema")
64-
config_prod_custom_schema = datadiff_variables.get("prod_custom_schema")
65-
datasource_id = datadiff_variables.get("datasource_id")
66-
set_dbt_user_id(dbt_parser.dbt_user_id)
67-
set_dbt_version(dbt_parser.dbt_version)
68-
set_dbt_project_id(dbt_parser.dbt_project_id)
69-
70-
if datadiff_variables.get("custom_schemas") is not None:
71-
logger.warning(
72-
"vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source"
66+
config = dbt_parser.get_datadiff_config()
67+
_initialize_events(dbt_parser.dbt_user_id, dbt_parser.dbt_version, dbt_parser.dbt_project_id)
68+
69+
70+
if not state and not (config.prod_database or config.prod_schema):
71+
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
72+
raise DataDiffDbtProjectVarsNotFoundError(
73+
f"""vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n\nOr specify a production manifest using the `--state` flag."""
7374
)
7475

7576
if is_cloud:
@@ -79,13 +80,13 @@ def dbt_diff(
7980
return
8081
org_meta = api.get_org_meta()
8182

82-
if datasource_id is None:
83+
if config.datasource_id is None:
8384
rich.print("[red]Data source ID not found in dbt_project.yml")
8485
is_create_data_source = Confirm.ask("Would you like to create a new data source?")
8586
if is_create_data_source:
86-
datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
87+
config.datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
8788
rich.print(f'To use the data source in next runs, please, update your "{PROJECT_FILE}" with a block:')
88-
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {datasource_id}\n")
89+
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {config.datasource_id}\n")
8990
rich.print(
9091
"Read more about Datafold vars in docs: "
9192
"https://docs.datafold.com/os_diff/dbt_integration/#configure-a-data-source\n"
@@ -96,21 +97,29 @@ def dbt_diff(
9697
"\nvars:\n data_diff:\n datasource_id: 1234"
9798
)
9899

99-
data_source = api.get_data_source(datasource_id)
100+
data_source = api.get_data_source(config.datasource_id)
100101
dbt_parser.set_casing_policy_for(connection_type=data_source.type)
101102
rich.print("[green][bold]\nDiffs in progress...[/][/]\n")
102103

103104
else:
104105
dbt_parser.set_connection()
105106

106107
for model in models:
107-
diff_vars = _get_diff_vars(
108-
dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model
109-
)
108+
diff_vars = _get_diff_vars(dbt_parser, config, model)
109+
110+
# we won't always have a prod path when using state
111+
# when the model DNE in prod manifest, skip the model diff
112+
if (
113+
state and len(diff_vars.prod_path) < 2
114+
): # < 2 because some providers like databricks can legitimately have *only* 2
115+
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
116+
diff_output_str += "[green]New model: nothing to diff![/] \n"
117+
rich.print(diff_output_str)
118+
continue
110119

111120
if diff_vars.primary_keys:
112121
if is_cloud:
113-
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta)
122+
diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
114123
diff_threads.append(diff_thread)
115124
else:
116125
_local_diff(diff_vars)
@@ -128,41 +137,19 @@ def dbt_diff(
128137

129138
def _get_diff_vars(
130139
dbt_parser: "DbtParser",
131-
config_prod_database: Optional[str],
132-
config_prod_schema: Optional[str],
133-
config_prod_custom_schema: Optional[str],
140+
config: TDatadiffConfig,
134141
model,
135142
) -> TDiffVars:
136143
dev_database = model.database
137144
dev_schema = model.schema_
138145

139146
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
140147

141-
# "custom" dbt config database
142-
if model.config.database:
143-
prod_database = model.config.database
144-
elif config_prod_database:
145-
prod_database = config_prod_database
146-
else:
147-
prod_database = dev_database
148-
149-
# prod schema name differs from dev schema name
150-
if config_prod_schema:
151-
custom_schema = model.config.schema_
152-
153-
# the model has a custom schema config(schema='some_schema')
154-
if custom_schema:
155-
if not config_prod_custom_schema:
156-
raise ValueError(
157-
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n"
158-
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
159-
)
160-
prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema)
161-
# no custom schema, use the default
162-
else:
163-
prod_schema = config_prod_schema
148+
# prod path is constructed via configuration or the prod manifest via --state
149+
if dbt_parser.prod_manifest_obj:
150+
prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
164151
else:
165-
prod_schema = dev_schema
152+
prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema)
166153

167154
if dbt_parser.requires_upper:
168155
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x]
@@ -186,6 +173,45 @@ def _get_diff_vars(
186173
)
187174

188175

176+
def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple[str, str]:
177+
# "custom" dbt config database
178+
if model.config.database:
179+
prod_database = model.config.database
180+
elif config.prod_database:
181+
prod_database = config.prod_database
182+
else:
183+
prod_database = dev_database
184+
185+
# prod schema name differs from dev schema name
186+
if config.prod_schema:
187+
custom_schema = model.config.schema_
188+
189+
# the model has a custom schema config(schema='some_schema')
190+
if custom_schema:
191+
if not config.prod_custom_schema:
192+
raise DataDiffCustomSchemaNoConfigError(
193+
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value or utilize the `--state` flag!\n\n"
194+
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
195+
)
196+
prod_schema = config.prod_custom_schema.replace("<custom_schema>", custom_schema)
197+
# no custom schema, use the default
198+
else:
199+
prod_schema = config.prod_schema
200+
else:
201+
prod_schema = dev_schema
202+
return prod_database, prod_schema
203+
204+
205+
def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]:
206+
prod_database = None
207+
prod_schema = None
208+
prod_model = prod_manifest.nodes.get(model.unique_id, None)
209+
if prod_model:
210+
prod_database = prod_model.database
211+
prod_schema = prod_model.schema_
212+
return prod_database, prod_schema
213+
214+
189215
def _local_diff(diff_vars: TDiffVars) -> None:
190216
dev_qualified_str = ".".join(diff_vars.dev_path)
191217
prod_qualified_str = ".".join(diff_vars.prod_path)
@@ -389,3 +415,34 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
389415

390416
def _diff_output_base(dev_path: str, prod_path: str) -> str:
391417
return f"\n[green]{prod_path} <> {dev_path}[/] \n"
418+
419+
420+
def _initialize_events(dbt_user_id: Optional[str], dbt_version: Optional[str], dbt_project_id: Optional[str]) -> None:
421+
set_dbt_user_id(dbt_user_id)
422+
set_dbt_version(dbt_version)
423+
set_dbt_project_id(dbt_project_id)
424+
_email_signup()
425+
426+
427+
def _email_signup() -> None:
428+
email_regex = r'^[\w\.\+-]+@[\w\.-]+\.\w+$'
429+
prompt = "\nWould you like to be notified when a new data-diff version is available?\n\nEnter email or leave blank to opt out (we'll only ask once).\n"
430+
431+
if bool_ask_for_email():
432+
while True:
433+
email_input = Prompt.ask(
434+
prompt=prompt,
435+
default="",
436+
show_default=False,
437+
)
438+
email = email_input.strip()
439+
440+
if email == "" or re.match(email_regex, email):
441+
break
442+
443+
prompt = ""
444+
rich.print("[red]Invalid email. Please enter a valid email or leave it blank to opt out.[/]")
445+
446+
if email:
447+
event_json = create_email_signup_event_json(email)
448+
run_as_daemon(send_event_json, event_json)

0 commit comments

Comments
 (0)