1
+ from contextlib import nullcontext
1
2
import json
2
3
import os
3
4
import re
42
43
run_as_daemon ,
43
44
truncate_error ,
44
45
print_version_info ,
46
+ LogStatusHandler ,
45
47
)
46
48
47
49
logger = getLogger (__name__ )
@@ -67,6 +69,7 @@ def dbt_diff(
67
69
dbt_selection : Optional [str ] = None ,
68
70
json_output : bool = False ,
69
71
state : Optional [str ] = None ,
72
+ log_status_handler : Optional [LogStatusHandler ] = None ,
70
73
where_flag : Optional [str ] = None ,
71
74
) -> None :
72
75
print_version_info ()
@@ -89,7 +92,6 @@ def dbt_diff(
89
92
if not api :
90
93
return
91
94
org_meta = api .get_org_meta ()
92
-
93
95
if config .datasource_id is None :
94
96
rich .print ("[red]Data source ID not found in dbt_project.yml" )
95
97
raise DataDiffNoDatasourceIdError (
@@ -103,48 +105,54 @@ def dbt_diff(
103
105
else :
104
106
dbt_parser .set_connection ()
105
107
106
- for model in models :
107
- diff_vars = _get_diff_vars (dbt_parser , config , model , where_flag )
108
-
109
- # we won't always have a prod path when using state
110
- # when the model DNE in prod manifest, skip the model diff
111
- if (
112
- state and len (diff_vars .prod_path ) < 2
113
- ): # < 2 because some providers like databricks can legitimately have *only* 2
114
- diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
115
- diff_output_str += "[green]New model: nothing to diff![/] \n "
116
- rich .print (diff_output_str )
117
- continue
118
-
119
- if diff_vars .primary_keys :
120
- if is_cloud :
121
- diff_thread = run_as_daemon (_cloud_diff , diff_vars , config .datasource_id , api , org_meta )
122
- diff_threads .append (diff_thread )
123
- else :
124
- _local_diff (diff_vars , json_output )
125
- else :
126
- if json_output :
127
- print (
128
- json .dumps (
129
- jsonify_error (
130
- table1 = diff_vars .prod_path ,
131
- table2 = diff_vars .dev_path ,
132
- dbt_model = diff_vars .dbt_model ,
133
- error = "No primary key found. Add uniqueness tests, meta, or tags." ,
134
- )
135
- ),
136
- flush = True ,
137
- )
108
+ with log_status_handler .status if log_status_handler else nullcontext ():
109
+ for model in models :
110
+ if log_status_handler :
111
+ log_status_handler .set_prefix (f"Diffing { model .alias } \n " )
112
+
113
+ diff_vars = _get_diff_vars (dbt_parser , config , model , where_flag )
114
+
115
+ # we won't always have a prod path when using state
116
+ # when the model DNE in prod manifest, skip the model diff
117
+ if (
118
+ state and len (diff_vars .prod_path ) < 2
119
+ ): # < 2 because some providers like databricks can legitimately have *only* 2
120
+ diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
121
+ diff_output_str += "[green]New model: nothing to diff![/] \n "
122
+ rich .print (diff_output_str )
123
+ continue
124
+
125
+ if diff_vars .primary_keys :
126
+ if is_cloud :
127
+ diff_thread = run_as_daemon (
128
+ _cloud_diff , diff_vars , config .datasource_id , api , org_meta , log_status_handler
129
+ )
130
+ diff_threads .append (diff_thread )
131
+ else :
132
+ _local_diff (diff_vars , json_output )
138
133
else :
139
- rich .print (
140
- _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
141
- + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
142
- )
143
-
144
- # wait for all threads
145
- if diff_threads :
146
- for thread in diff_threads :
147
- thread .join ()
134
+ if json_output :
135
+ print (
136
+ json .dumps (
137
+ jsonify_error (
138
+ table1 = diff_vars .prod_path ,
139
+ table2 = diff_vars .dev_path ,
140
+ dbt_model = diff_vars .dbt_model ,
141
+ error = "No primary key found. Add uniqueness tests, meta, or tags." ,
142
+ )
143
+ ),
144
+ flush = True ,
145
+ )
146
+ else :
147
+ rich .print (
148
+ _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
149
+ + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
150
+ )
151
+
152
+ # wait for all threads
153
+ if diff_threads :
154
+ for thread in diff_threads :
155
+ thread .join ()
148
156
149
157
150
158
def _get_diff_vars (
@@ -348,7 +356,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
348
356
return DatafoldAPI (api_key = api_key , host = datafold_host )
349
357
350
358
351
- def _cloud_diff (diff_vars : TDiffVars , datasource_id : int , api : DatafoldAPI , org_meta : TCloudApiOrgMeta ) -> None :
359
+ def _cloud_diff (
360
+ diff_vars : TDiffVars ,
361
+ datasource_id : int ,
362
+ api : DatafoldAPI ,
363
+ org_meta : TCloudApiOrgMeta ,
364
+ log_status_handler : Optional [LogStatusHandler ] = None ,
365
+ ) -> None :
366
+ if log_status_handler :
367
+ log_status_handler .cloud_diff_started (diff_vars .dev_path [- 1 ])
352
368
diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
353
369
payload = TCloudApiDataDiff (
354
370
data_source1_id = datasource_id ,
@@ -417,6 +433,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
417
433
diff_output_str += f"\n { diff_url } \n { no_differences_template ()} \n "
418
434
rich .print (diff_output_str )
419
435
436
+ if log_status_handler :
437
+ log_status_handler .cloud_diff_finished (diff_vars .dev_path [- 1 ])
420
438
except BaseException as ex : # Catch KeyboardInterrupt too
421
439
error = ex
422
440
finally :
0 commit comments