1
+ from contextlib import nullcontext
1
2
import json
2
3
import os
3
4
import re
42
43
run_as_daemon ,
43
44
truncate_error ,
44
45
print_version_info ,
46
+ LogStatusHandler ,
45
47
)
46
48
47
49
logger = getLogger (__name__ )
@@ -67,6 +69,7 @@ def dbt_diff(
67
69
dbt_selection : Optional [str ] = None ,
68
70
json_output : bool = False ,
69
71
state : Optional [str ] = None ,
72
+ log_status_handler : Optional [LogStatusHandler ] = None ,
70
73
) -> None :
71
74
print_version_info ()
72
75
diff_threads = []
@@ -88,7 +91,6 @@ def dbt_diff(
88
91
if not api :
89
92
return
90
93
org_meta = api .get_org_meta ()
91
-
92
94
if config .datasource_id is None :
93
95
rich .print ("[red]Data source ID not found in dbt_project.yml" )
94
96
raise DataDiffNoDatasourceIdError (
@@ -102,48 +104,54 @@ def dbt_diff(
102
104
else :
103
105
dbt_parser .set_connection ()
104
106
105
- for model in models :
106
- diff_vars = _get_diff_vars (dbt_parser , config , model )
107
-
108
- # we won't always have a prod path when using state
109
- # when the model DNE in prod manifest, skip the model diff
110
- if (
111
- state and len (diff_vars .prod_path ) < 2
112
- ): # < 2 because some providers like databricks can legitimately have *only* 2
113
- diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
114
- diff_output_str += "[green]New model: nothing to diff![/] \n "
115
- rich .print (diff_output_str )
116
- continue
117
-
118
- if diff_vars .primary_keys :
119
- if is_cloud :
120
- diff_thread = run_as_daemon (_cloud_diff , diff_vars , config .datasource_id , api , org_meta )
121
- diff_threads .append (diff_thread )
122
- else :
123
- _local_diff (diff_vars , json_output )
124
- else :
125
- if json_output :
126
- print (
127
- json .dumps (
128
- jsonify_error (
129
- table1 = diff_vars .prod_path ,
130
- table2 = diff_vars .dev_path ,
131
- dbt_model = diff_vars .dbt_model ,
132
- error = "No primary key found. Add uniqueness tests, meta, or tags." ,
133
- )
134
- ),
135
- flush = True ,
136
- )
107
+ with log_status_handler .status if log_status_handler else nullcontext ():
108
+ for model in models :
109
+ if log_status_handler :
110
+ log_status_handler .set_prefix (f"Diffing { model .alias } \n " )
111
+
112
+ diff_vars = _get_diff_vars (dbt_parser , config , model )
113
+
114
+ # we won't always have a prod path when using state
115
+ # when the model DNE in prod manifest, skip the model diff
116
+ if (
117
+ state and len (diff_vars .prod_path ) < 2
118
+ ): # < 2 because some providers like databricks can legitimately have *only* 2
119
+ diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
120
+ diff_output_str += "[green]New model: nothing to diff![/] \n "
121
+ rich .print (diff_output_str )
122
+ continue
123
+
124
+ if diff_vars .primary_keys :
125
+ if is_cloud :
126
+ diff_thread = run_as_daemon (
127
+ _cloud_diff , diff_vars , config .datasource_id , api , org_meta , log_status_handler
128
+ )
129
+ diff_threads .append (diff_thread )
130
+ else :
131
+ _local_diff (diff_vars , json_output )
137
132
else :
138
- rich .print (
139
- _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
140
- + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
141
- )
142
-
143
- # wait for all threads
144
- if diff_threads :
145
- for thread in diff_threads :
146
- thread .join ()
133
+ if json_output :
134
+ print (
135
+ json .dumps (
136
+ jsonify_error (
137
+ table1 = diff_vars .prod_path ,
138
+ table2 = diff_vars .dev_path ,
139
+ dbt_model = diff_vars .dbt_model ,
140
+ error = "No primary key found. Add uniqueness tests, meta, or tags." ,
141
+ )
142
+ ),
143
+ flush = True ,
144
+ )
145
+ else :
146
+ rich .print (
147
+ _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
148
+ + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
149
+ )
150
+
151
+ # wait for all threads
152
+ if diff_threads :
153
+ for thread in diff_threads :
154
+ thread .join ()
147
155
148
156
149
157
def _get_diff_vars (
@@ -345,7 +353,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
345
353
return DatafoldAPI (api_key = api_key , host = datafold_host )
346
354
347
355
348
- def _cloud_diff (diff_vars : TDiffVars , datasource_id : int , api : DatafoldAPI , org_meta : TCloudApiOrgMeta ) -> None :
356
+ def _cloud_diff (
357
+ diff_vars : TDiffVars ,
358
+ datasource_id : int ,
359
+ api : DatafoldAPI ,
360
+ org_meta : TCloudApiOrgMeta ,
361
+ log_status_handler : Optional [LogStatusHandler ] = None ,
362
+ ) -> None :
363
+ if log_status_handler :
364
+ log_status_handler .cloud_diff_started (diff_vars .dev_path [- 1 ])
349
365
diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
350
366
payload = TCloudApiDataDiff (
351
367
data_source1_id = datasource_id ,
@@ -414,6 +430,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
414
430
diff_output_str += f"\n { diff_url } \n { no_differences_template ()} \n "
415
431
rich .print (diff_output_str )
416
432
433
+ if log_status_handler :
434
+ log_status_handler .cloud_diff_finished (diff_vars .dev_path [- 1 ])
417
435
except BaseException as ex : # Catch KeyboardInterrupt too
418
436
error = ex
419
437
finally :
0 commit comments