Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 7b7ad4e

Browse files
authored
Merge pull request #586 from dlawin/issue_483
add support for custom database config
2 parents 8fccfff + 15f4cdf commit 7b7ad4e

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

data_diff/dbt.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,13 @@ def _get_diff_vars(
138138

139139
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
140140

141-
prod_database = config_prod_database if config_prod_database else dev_database
141+
# "custom" dbt config database
142+
if model.config.database:
143+
prod_database = model.config.database
144+
elif config_prod_database:
145+
prod_database = config_prod_database
146+
else:
147+
prod_database = dev_database
142148

143149
# prod schema name differs from dev schema name
144150
if config_prod_schema:

tests/test_dbt.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,7 @@ def test_get_diff_vars_replace_custom_schema(self):
967967
mock_model.database = "a_dev_db"
968968
mock_model.schema_ = "a_custom_schema"
969969
mock_model.config.schema_ = mock_model.schema_
970+
mock_model.config.database = None
970971
mock_model.alias = "a_model_name"
971972
mock_tdatadiffmodelconfig = Mock()
972973
mock_tdatadiffmodelconfig.where_filter = "where"
@@ -999,6 +1000,7 @@ def test_get_diff_vars_static_custom_schema(self):
9991000
primary_keys = ["a_primary_key"]
10001001
mock_model.database = "a_dev_db"
10011002
mock_model.schema_ = "a_custom_schema"
1003+
mock_model.config.database = None
10021004
mock_model.config.schema_ = mock_model.schema_
10031005
mock_model.alias = "a_model_name"
10041006
mock_tdatadiffmodelconfig = Mock()
@@ -1031,6 +1033,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
10311033
mock_model.database = "a_dev_db"
10321034
mock_model.schema_ = "a_custom_schema"
10331035
mock_model.config.schema_ = None
1036+
mock_model.config.database = None
10341037
mock_model.alias = "a_model_name"
10351038
mock_tdatadiffmodelconfig = Mock()
10361039
mock_tdatadiffmodelconfig.where_filter = "where"
@@ -1060,6 +1063,7 @@ def test_get_diff_vars_match_dev_schema(self):
10601063
mock_model.database = "a_dev_db"
10611064
mock_model.schema_ = "a_schema"
10621065
mock_model.config.schema_ = None
1066+
mock_model.config.database = None
10631067
mock_model.alias = "a_model_name"
10641068
mock_tdatadiffmodelconfig = Mock()
10651069
mock_tdatadiffmodelconfig.where_filter = "where"
@@ -1107,6 +1111,7 @@ def test_get_diff_vars_meta_where(self):
11071111
mock_model.database = "a_dev_db"
11081112
mock_model.schema_ = "a_schema"
11091113
mock_model.config.schema_ = None
1114+
mock_model.config.database = None
11101115
mock_model.alias = "a_model_name"
11111116
mock_tdatadiffmodelconfig = Mock()
11121117
mock_tdatadiffmodelconfig.where_filter = "where"
@@ -1136,6 +1141,7 @@ def test_get_diff_vars_meta_unrelated(self):
11361141
mock_model.database = "a_dev_db"
11371142
mock_model.schema_ = "a_schema"
11381143
mock_model.config.schema_ = None
1144+
mock_model.config.database = None
11391145
mock_model.alias = "a_model_name"
11401146
mock_tdatadiffmodelconfig = Mock()
11411147
mock_tdatadiffmodelconfig.where_filter = "where"
@@ -1165,6 +1171,7 @@ def test_get_diff_vars_meta_none(self):
11651171
mock_model.database = "a_dev_db"
11661172
mock_model.schema_ = "a_schema"
11671173
mock_model.config.schema_ = None
1174+
mock_model.config.database = None
11681175
mock_model.alias = "a_model_name"
11691176
mock_tdatadiffmodelconfig = Mock()
11701177
mock_tdatadiffmodelconfig.where_filter = "where"
@@ -1176,7 +1183,6 @@ def test_get_diff_vars_meta_none(self):
11761183
mock_dbt_parser.threads = 0
11771184
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
11781185
mock_dbt_parser.requires_upper = False
1179-
where = None
11801186
mock_model.meta = None
11811187

11821188
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
@@ -1188,3 +1194,34 @@ def test_get_diff_vars_meta_none(self):
11881194
assert diff_vars.threads == mock_dbt_parser.threads
11891195
self.assertEqual(diff_vars.where_filter, mock_tdatadiffmodelconfig.where_filter)
11901196
mock_dbt_parser.get_pk_from_model.assert_called_once()
1197+
1198+
def test_get_diff_vars_custom_db(self):
1199+
mock_model = Mock()
1200+
prod_database = "a_prod_db"
1201+
primary_keys = ["a_primary_key"]
1202+
mock_model.database = "a_dev_db"
1203+
mock_model.schema_ = "a_schema"
1204+
mock_model.config.schema_ = None
1205+
mock_model.config.database = "custom_database"
1206+
mock_model.alias = "a_model_name"
1207+
mock_tdatadiffmodelconfig = Mock()
1208+
mock_tdatadiffmodelconfig.where_filter = "where"
1209+
mock_tdatadiffmodelconfig.include_columns = ["include"]
1210+
mock_tdatadiffmodelconfig.exclude_columns = ["exclude"]
1211+
mock_dbt_parser = Mock()
1212+
mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig
1213+
mock_dbt_parser.connection = {}
1214+
mock_dbt_parser.threads = 0
1215+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
1216+
mock_dbt_parser.requires_upper = False
1217+
mock_model.meta = None
1218+
1219+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
1220+
1221+
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
1222+
assert diff_vars.prod_path == [mock_model.config.database, mock_model.schema_, mock_model.alias]
1223+
assert diff_vars.primary_keys == primary_keys
1224+
assert diff_vars.connection == mock_dbt_parser.connection
1225+
assert diff_vars.threads == mock_dbt_parser.threads
1226+
self.assertEqual(diff_vars.where_filter, mock_tdatadiffmodelconfig.where_filter)
1227+
mock_dbt_parser.get_pk_from_model.assert_called_once()

0 commit comments

Comments
 (0)