@@ -967,6 +967,7 @@ def test_get_diff_vars_replace_custom_schema(self):
967
967
mock_model .database = "a_dev_db"
968
968
mock_model .schema_ = "a_custom_schema"
969
969
mock_model .config .schema_ = mock_model .schema_
970
+ mock_model .config .database = None
970
971
mock_model .alias = "a_model_name"
971
972
mock_tdatadiffmodelconfig = Mock ()
972
973
mock_tdatadiffmodelconfig .where_filter = "where"
@@ -999,6 +1000,7 @@ def test_get_diff_vars_static_custom_schema(self):
999
1000
primary_keys = ["a_primary_key" ]
1000
1001
mock_model .database = "a_dev_db"
1001
1002
mock_model .schema_ = "a_custom_schema"
1003
+ mock_model .config .database = None
1002
1004
mock_model .config .schema_ = mock_model .schema_
1003
1005
mock_model .alias = "a_model_name"
1004
1006
mock_tdatadiffmodelconfig = Mock ()
@@ -1031,6 +1033,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
1031
1033
mock_model .database = "a_dev_db"
1032
1034
mock_model .schema_ = "a_custom_schema"
1033
1035
mock_model .config .schema_ = None
1036
+ mock_model .config .database = None
1034
1037
mock_model .alias = "a_model_name"
1035
1038
mock_tdatadiffmodelconfig = Mock ()
1036
1039
mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1060,6 +1063,7 @@ def test_get_diff_vars_match_dev_schema(self):
1060
1063
mock_model .database = "a_dev_db"
1061
1064
mock_model .schema_ = "a_schema"
1062
1065
mock_model .config .schema_ = None
1066
+ mock_model .config .database = None
1063
1067
mock_model .alias = "a_model_name"
1064
1068
mock_tdatadiffmodelconfig = Mock ()
1065
1069
mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1107,6 +1111,7 @@ def test_get_diff_vars_meta_where(self):
1107
1111
mock_model .database = "a_dev_db"
1108
1112
mock_model .schema_ = "a_schema"
1109
1113
mock_model .config .schema_ = None
1114
+ mock_model .config .database = None
1110
1115
mock_model .alias = "a_model_name"
1111
1116
mock_tdatadiffmodelconfig = Mock ()
1112
1117
mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1136,6 +1141,7 @@ def test_get_diff_vars_meta_unrelated(self):
1136
1141
mock_model .database = "a_dev_db"
1137
1142
mock_model .schema_ = "a_schema"
1138
1143
mock_model .config .schema_ = None
1144
+ mock_model .config .database = None
1139
1145
mock_model .alias = "a_model_name"
1140
1146
mock_tdatadiffmodelconfig = Mock ()
1141
1147
mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1165,6 +1171,7 @@ def test_get_diff_vars_meta_none(self):
1165
1171
mock_model .database = "a_dev_db"
1166
1172
mock_model .schema_ = "a_schema"
1167
1173
mock_model .config .schema_ = None
1174
+ mock_model .config .database = None
1168
1175
mock_model .alias = "a_model_name"
1169
1176
mock_tdatadiffmodelconfig = Mock ()
1170
1177
mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1176,7 +1183,6 @@ def test_get_diff_vars_meta_none(self):
1176
1183
mock_dbt_parser .threads = 0
1177
1184
mock_dbt_parser .get_pk_from_model .return_value = primary_keys
1178
1185
mock_dbt_parser .requires_upper = False
1179
- where = None
1180
1186
mock_model .meta = None
1181
1187
1182
1188
diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
@@ -1188,3 +1194,34 @@ def test_get_diff_vars_meta_none(self):
1188
1194
assert diff_vars .threads == mock_dbt_parser .threads
1189
1195
self .assertEqual (diff_vars .where_filter , mock_tdatadiffmodelconfig .where_filter )
1190
1196
mock_dbt_parser .get_pk_from_model .assert_called_once ()
1197
+
1198
+ def test_get_diff_vars_custom_db (self ):
1199
+ mock_model = Mock ()
1200
+ prod_database = "a_prod_db"
1201
+ primary_keys = ["a_primary_key" ]
1202
+ mock_model .database = "a_dev_db"
1203
+ mock_model .schema_ = "a_schema"
1204
+ mock_model .config .schema_ = None
1205
+ mock_model .config .database = "custom_database"
1206
+ mock_model .alias = "a_model_name"
1207
+ mock_tdatadiffmodelconfig = Mock ()
1208
+ mock_tdatadiffmodelconfig .where_filter = "where"
1209
+ mock_tdatadiffmodelconfig .include_columns = ["include" ]
1210
+ mock_tdatadiffmodelconfig .exclude_columns = ["exclude" ]
1211
+ mock_dbt_parser = Mock ()
1212
+ mock_dbt_parser .get_datadiff_model_config .return_value = mock_tdatadiffmodelconfig
1213
+ mock_dbt_parser .connection = {}
1214
+ mock_dbt_parser .threads = 0
1215
+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
1216
+ mock_dbt_parser .requires_upper = False
1217
+ mock_model .meta = None
1218
+
1219
+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
1220
+
1221
+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
1222
+ assert diff_vars .prod_path == [mock_model .config .database , mock_model .schema_ , mock_model .alias ]
1223
+ assert diff_vars .primary_keys == primary_keys
1224
+ assert diff_vars .connection == mock_dbt_parser .connection
1225
+ assert diff_vars .threads == mock_dbt_parser .threads
1226
+ self .assertEqual (diff_vars .where_filter , mock_tdatadiffmodelconfig .where_filter )
1227
+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
0 commit comments