1
1
import logging
2
2
import typing as t
3
+ from types import MappingProxyType
3
4
4
5
from dagster import (
5
6
AssetExecutionContext ,
6
7
ConfigurableResource ,
7
8
MaterializeResult ,
8
9
)
10
+ from dagster ._core .errors import DagsterInvalidPropertyError
9
11
from sqlmesh import Model
10
12
from sqlmesh .core .context import Context as SQLMeshContext
11
13
from sqlmesh .core .snapshot import Snapshot , SnapshotInfoLike , SnapshotTableInfo
12
14
from sqlmesh .utils .dag import DAG
13
15
from sqlmesh .utils .date import TimeLike
16
+ from sqlmesh .utils .errors import SQLMeshError
14
17
15
18
from dagster_sqlmesh .controller .base import (
16
19
DEFAULT_CONTEXT_FACTORY ,
@@ -113,20 +116,41 @@ def event_name(self):
113
116
return self ._event .__class__ .__name__
114
117
115
118
119
+ class GenericSQLMeshError (Exception ):
120
+ pass
121
+
122
+
123
+ class FailedModelError (Exception ):
124
+ def __init__ (self , model_name : str , message : str | None ) -> None :
125
+ super ().__init__ (message )
126
+ self .model_name = model_name
127
+ self .message = message
128
+
129
+
130
+ class PlanOrRunFailedError (Exception ):
131
+ def __init__ (self , stage : str , message : str , errors : list [Exception ]) -> None :
132
+ super ().__init__ (message )
133
+ self .stage = stage
134
+ self .errors = errors
135
+
136
+
116
137
class DagsterSQLMeshEventHandler :
117
138
def __init__ (
118
139
self ,
119
140
context : AssetExecutionContext ,
120
141
models_map : dict [str , Model ],
121
142
dag : DAG [t .Any ],
122
143
prefix : str ,
144
+ is_testing : bool = False ,
123
145
) -> None :
124
146
self ._models_map = models_map
125
147
self ._prefix = prefix
126
148
self ._context = context
127
149
self ._logger = context .log
128
150
self ._tracker = MaterializationTracker (dag .sorted [:], self ._logger )
129
151
self ._stage = "plan"
152
+ self ._errors : list [Exception ] = []
153
+ self ._is_testing = is_testing
130
154
131
155
def process_events (self , event : console .ConsoleEvent ) -> None :
132
156
self .report_event (event )
@@ -150,14 +174,17 @@ def notify_success(
150
174
# If the model is not in models_map, we can skip any notification
151
175
if model :
152
176
output_key = sqlmesh_model_name_to_key (model .name )
153
- asset_key = self ._context .asset_key_for_output (output_key )
154
- yield MaterializeResult (
155
- asset_key = asset_key ,
156
- metadata = {
157
- "updated" : update_status ,
158
- "duration_ms" : 0 ,
159
- },
160
- )
177
+ if not self ._is_testing :
178
+ # Stupidly dagster when testing cannot use the following
179
+ # method so we must specifically skip this when testing
180
+ asset_key = self ._context .asset_key_for_output (output_key )
181
+ yield MaterializeResult (
182
+ asset_key = asset_key ,
183
+ metadata = {
184
+ "updated" : update_status ,
185
+ "duration_ms" : 0 ,
186
+ },
187
+ )
161
188
notify = self ._tracker .notify_queue_next ()
162
189
163
190
def report_event (self , event : console .ConsoleEvent ) -> None :
@@ -210,19 +237,22 @@ def report_event(self, event: console.ConsoleEvent) -> None:
210
237
if success :
211
238
log_context .info ("sqlmesh ran successfully" )
212
239
else :
213
- log_context .error ("sqlmesh failed" )
214
- raise Exception ("sqlmesh failed during run" )
240
+ log_context .error ("sqlmesh failed. check collected errors" )
215
241
case console .LogError (message = message ):
216
242
log_context .error (
217
243
f"sqlmesh reported an error: { message } " ,
218
244
)
219
- case console .LogFailedModels (models = models ):
220
- if len (models ) != 0 :
245
+ self ._errors .append (GenericSQLMeshError (message ))
246
+ case console .LogFailedModels (errors = errors ):
247
+ if len (errors ) != 0 :
221
248
failed_models = "\n " .join (
222
- [f"{ model !s} \n { model .__cause__ !s} " for model in models ]
249
+ [f"{ error . node !s} \n { error .__cause__ !s} " for error in errors ]
223
250
)
224
251
log_context .error (f"sqlmesh failed models: { failed_models } " )
225
- raise Exception ("sqlmesh has failed models" )
252
+ for error in errors :
253
+ self ._errors .append (
254
+ FailedModelError (error .node , str (error .__cause__ ))
255
+ )
226
256
case console .UpdatePromotionProgress (snapshot = snapshot , promoted = promoted ):
227
257
log_context .info (
228
258
"Promotion progress update" ,
@@ -263,9 +293,18 @@ def log(
263
293
def update_stage (self , stage : str ):
264
294
self ._stage = stage
265
295
296
+ @property
297
+ def stage (self ) -> str :
298
+ return self ._stage
299
+
300
+ @property
301
+ def errors (self ) -> list [Exception ]:
302
+ return self ._errors [:]
303
+
266
304
267
305
class SQLMeshResource (ConfigurableResource ):
268
306
config : SQLMeshContextConfig
307
+ is_testing : bool = False
269
308
270
309
def run (
271
310
self ,
@@ -293,25 +332,16 @@ def run(
293
332
with controller .instance (environment ) as mesh :
294
333
dag = mesh .models_dag ()
295
334
296
- select_models = []
297
-
298
335
models = mesh .models ()
299
336
models_map = models .copy ()
300
337
all_available_models = set (
301
338
[model .fqn for model , _ in mesh .non_external_models_dag ()]
302
339
)
303
- if context .selected_output_names :
304
- models_map = {}
305
- for key , model in models .items ():
306
- if (
307
- sqlmesh_model_name_to_key (model .name )
308
- in context .selected_output_names
309
- ):
310
- models_map [key ] = model
311
- select_models .append (model .name )
312
- selected_models_set = set (models_map .keys ())
313
-
314
- if all_available_models == selected_models_set :
340
+ selected_models_set , models_map , select_models = (
341
+ self ._get_selected_models_from_context (context , models )
342
+ )
343
+
344
+ if all_available_models == selected_models_set or select_models is None :
315
345
logger .info ("all models selected" )
316
346
317
347
# Setting this to none to allow sqlmesh to select all models and
@@ -321,24 +351,61 @@ def run(
321
351
logger .info (f"selected models: { select_models } " )
322
352
323
353
event_handler = DagsterSQLMeshEventHandler (
324
- context , models_map , dag , "sqlmesh: "
354
+ context , models_map , dag , "sqlmesh: " , is_testing = self . is_testing
325
355
)
326
356
327
- for event in mesh .plan_and_run (
328
- start = start ,
329
- end = end ,
330
- select_models = select_models ,
331
- restate_models = restate_models ,
332
- restate_selected = restate_selected ,
333
- skip_run = skip_run ,
334
- plan_options = plan_options ,
335
- run_options = run_options ,
336
- ):
337
- logger .debug (f"sqlmesh event: { event } " )
338
- event_handler .process_events (event )
339
-
357
+ try :
358
+ for event in mesh .plan_and_run (
359
+ start = start ,
360
+ end = end ,
361
+ select_models = select_models ,
362
+ restate_models = restate_models ,
363
+ restate_selected = restate_selected ,
364
+ skip_run = skip_run ,
365
+ plan_options = plan_options ,
366
+ run_options = run_options ,
367
+ ):
368
+ logger .debug (f"sqlmesh event: { event } " )
369
+ event_handler .process_events (event )
370
+ except SQLMeshError as e :
371
+ logger .error (f"sqlmesh error: { e } " )
372
+ errors = event_handler .errors
373
+ for error in errors :
374
+ logger .error (f"sqlmesh encountered the following error during sqlmesh { event_handler .stage } : { error } " )
375
+ raise PlanOrRunFailedError (
376
+ event_handler .stage ,
377
+ f"sqlmesh failed during { event_handler .stage } with { len (event_handler .errors ) + 1 } errors" ,
378
+ [e , * event_handler .errors ],
379
+ )
340
380
yield from event_handler .notify_success (mesh .context )
341
381
382
+ def _get_selected_models_from_context (
383
+ self , context : AssetExecutionContext , models : MappingProxyType [str , Model ]
384
+ ) -> tuple [set [str ], dict [str , Model ], list [str ] | None ]:
385
+ models_map = models .copy ()
386
+ try :
387
+ selected_output_names = set (context .selected_output_names )
388
+ except (DagsterInvalidPropertyError , AttributeError ) as e :
389
+ # Special case for direct execution context when testing. This is related to:
390
+ # https://github.com/dagster-io/dagster/issues/23633
391
+ if "DirectOpExecutionContext" in str (e ):
392
+ context .log .warning ("Caught an error that is likely a direct execution" )
393
+ return (set (models_map .keys ()), models_map , None )
394
+ else :
395
+ raise e
396
+
397
+ select_models : list [str ] = []
398
+ models_map = {}
399
+ for key , model in models .items ():
400
+ if sqlmesh_model_name_to_key (model .name ) in selected_output_names :
401
+ models_map [key ] = model
402
+ select_models .append (model .name )
403
+ return (
404
+ set (models_map .keys ()),
405
+ models_map ,
406
+ select_models ,
407
+ )
408
+
342
409
def get_controller (
343
410
self ,
344
411
context_factory : ContextFactory [ContextCls ],
0 commit comments