Skip to content

Commit 8de0269

Browse files
authored
fix: ensures that run/plan errors are properly reported (#44)
1 parent 8bd0770 commit 8de0269

File tree

11 files changed: 288 additions & 78 deletions

dagster_sqlmesh/conftest.py

Lines changed: 14 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -6,15 +6,12 @@
66
import typing as t
77

88
import pytest
9-
from sqlmesh.core.config import (
10-
Config as SQLMeshConfig,
11-
DuckDBConnectionConfig,
12-
GatewayConfig,
13-
ModelDefaultsConfig,
14-
)
159

1610
from dagster_sqlmesh.config import SQLMeshContextConfig
17-
from dagster_sqlmesh.testing import SQLMeshTestContext
11+
from dagster_sqlmesh.testing import (
12+
SQLMeshTestContext,
13+
setup_testing_sqlmesh_context_config,
14+
)
1815

1916
logger = logging.getLogger(__name__)
2017

@@ -41,23 +38,19 @@ def sample_sqlmesh_project() -> t.Iterator[str]:
4138
# Initialize the "source" data
4239
yield str(project_dir)
4340

41+
@pytest.fixture
42+
def sample_sqlmesh_db_path(sample_sqlmesh_project: str) -> t.Iterator[str]:
43+
db_path = os.path.join(sample_sqlmesh_project, "db.db")
44+
yield db_path
45+
46+
@pytest.fixture
47+
def sample_sqlmesh_test_context_config(sample_sqlmesh_project: str, sample_sqlmesh_db_path: str) -> t.Iterator[SQLMeshContextConfig]:
48+
yield setup_testing_sqlmesh_context_config(db_path=sample_sqlmesh_db_path, project_path=sample_sqlmesh_project)
4449

4550
@pytest.fixture
4651
def sample_sqlmesh_test_context(
47-
sample_sqlmesh_project: str,
52+
sample_sqlmesh_project: str, sample_sqlmesh_test_context_config: SQLMeshContextConfig, sample_sqlmesh_db_path: str
4853
) -> t.Iterator[SQLMeshTestContext]:
49-
db_path = os.path.join(sample_sqlmesh_project, "db.db")
50-
config = SQLMeshConfig(
51-
gateways={
52-
"local": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path)),
53-
},
54-
default_gateway="local",
55-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
56-
)
57-
config_as_dict = config.dict()
58-
context_config = SQLMeshContextConfig(
59-
path=sample_sqlmesh_project, gateway="local", config_override=config_as_dict
60-
)
61-
test_context = SQLMeshTestContext(db_path=db_path, context_config=context_config)
54+
test_context = SQLMeshTestContext(db_path=sample_sqlmesh_db_path, context_config=sample_sqlmesh_test_context_config)
6255
test_context.initialize_test_source()
6356
yield test_context

dagster_sqlmesh/console.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -405,14 +405,14 @@ def publish_known_event(self, event_name: str, **kwargs: t.Any) -> None:
405405

406406
def publish(self, event: ConsoleEvent) -> None:
407407
self.logger.debug(
408-
f"EventConsole[{self.id}]: sending event to {len(self._handlers)}"
408+
f"EventConsole[{self.id}]: sending event {event.__class__.__name__} to {len(self._handlers)}"
409409
)
410410
for handler in self._handlers.values():
411411
handler(event)
412412

413413
def publish_unknown_event(self, event_name: str, **kwargs: t.Any) -> None:
414414
self.logger.debug(
415-
f"EventConsole[{self.id}]: sending unknown event to {len(self._handlers)}"
415+
f"EventConsole[{self.id}]: sending unknown '{event_name}' event to {len(self._handlers)} handlers"
416416
)
417417
self.logger.debug(f"EventConsole[{self.id}]: unknown event {event_name} {kwargs}")
418418

dagster_sqlmesh/controller/base.py

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,17 @@ def run_sqlmesh_thread(
178178
default_catalog: str,
179179
) -> None:
180180
logger.debug("dagster-sqlmesh: thread started")
181+
182+
def auto_execute_plan(event: ConsoleEvent):
183+
if isinstance(event, Plan):
184+
try:
185+
event.plan_builder.apply()
186+
except Exception as e:
187+
controller.console.exception(e)
188+
return None
189+
181190
try:
191+
controller.console.add_handler(auto_execute_plan)
182192
builder = t.cast(
183193
PlanBuilder,
184194
context.plan_builder(
@@ -191,7 +201,7 @@ def run_sqlmesh_thread(
191201
builder,
192202
auto_apply=True,
193203
default_catalog=default_catalog,
194-
)
204+
)
195205
except Exception as e:
196206
controller.console.exception(e)
197207
except: # noqa: E722
@@ -218,16 +228,18 @@ def run_sqlmesh_thread(
218228
thread.start()
219229

220230
self.logger.debug("waiting for events")
221-
for event in generator.events(thread):
222-
match event:
223-
case ConsoleException(exception=e):
224-
raise e
225-
case Plan(plan_builder=plan_builder, auto_apply=auto_apply):
226-
if auto_apply:
227-
plan_builder.apply()
228-
yield event
229-
case _:
230-
yield event
231+
try:
232+
for event in generator.events(thread):
233+
match event:
234+
case ConsoleException(exception=e):
235+
raise e
236+
case _:
237+
yield event
238+
except Exception as e:
239+
import traceback
240+
print("An exception occurred:")
241+
print(traceback.format_exc())
242+
raise
231243

232244
thread.join()
233245

dagster_sqlmesh/resource.py

Lines changed: 109 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,19 @@
11
import logging
22
import typing as t
3+
from types import MappingProxyType
34

45
from dagster import (
56
AssetExecutionContext,
67
ConfigurableResource,
78
MaterializeResult,
89
)
10+
from dagster._core.errors import DagsterInvalidPropertyError
911
from sqlmesh import Model
1012
from sqlmesh.core.context import Context as SQLMeshContext
1113
from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo
1214
from sqlmesh.utils.dag import DAG
1315
from sqlmesh.utils.date import TimeLike
16+
from sqlmesh.utils.errors import SQLMeshError
1417

1518
from dagster_sqlmesh.controller.base import (
1619
DEFAULT_CONTEXT_FACTORY,
@@ -113,20 +116,41 @@ def event_name(self):
113116
return self._event.__class__.__name__
114117

115118

119+
class GenericSQLMeshError(Exception):
120+
pass
121+
122+
123+
class FailedModelError(Exception):
124+
def __init__(self, model_name: str, message: str | None) -> None:
125+
super().__init__(message)
126+
self.model_name = model_name
127+
self.message = message
128+
129+
130+
class PlanOrRunFailedError(Exception):
131+
def __init__(self, stage: str, message: str, errors: list[Exception]) -> None:
132+
super().__init__(message)
133+
self.stage = stage
134+
self.errors = errors
135+
136+
116137
class DagsterSQLMeshEventHandler:
117138
def __init__(
118139
self,
119140
context: AssetExecutionContext,
120141
models_map: dict[str, Model],
121142
dag: DAG[t.Any],
122143
prefix: str,
144+
is_testing: bool = False,
123145
) -> None:
124146
self._models_map = models_map
125147
self._prefix = prefix
126148
self._context = context
127149
self._logger = context.log
128150
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
129151
self._stage = "plan"
152+
self._errors: list[Exception] = []
153+
self._is_testing = is_testing
130154

131155
def process_events(self, event: console.ConsoleEvent) -> None:
132156
self.report_event(event)
@@ -150,14 +174,17 @@ def notify_success(
150174
# If the model is not in models_map, we can skip any notification
151175
if model:
152176
output_key = sqlmesh_model_name_to_key(model.name)
153-
asset_key = self._context.asset_key_for_output(output_key)
154-
yield MaterializeResult(
155-
asset_key=asset_key,
156-
metadata={
157-
"updated": update_status,
158-
"duration_ms": 0,
159-
},
160-
)
177+
if not self._is_testing:
178+
# Stupidly dagster when testing cannot use the following
179+
# method so we must specifically skip this when testing
180+
asset_key = self._context.asset_key_for_output(output_key)
181+
yield MaterializeResult(
182+
asset_key=asset_key,
183+
metadata={
184+
"updated": update_status,
185+
"duration_ms": 0,
186+
},
187+
)
161188
notify = self._tracker.notify_queue_next()
162189

163190
def report_event(self, event: console.ConsoleEvent) -> None:
@@ -210,19 +237,22 @@ def report_event(self, event: console.ConsoleEvent) -> None:
210237
if success:
211238
log_context.info("sqlmesh ran successfully")
212239
else:
213-
log_context.error("sqlmesh failed")
214-
raise Exception("sqlmesh failed during run")
240+
log_context.error("sqlmesh failed. check collected errors")
215241
case console.LogError(message=message):
216242
log_context.error(
217243
f"sqlmesh reported an error: {message}",
218244
)
219-
case console.LogFailedModels(models=models):
220-
if len(models) != 0:
245+
self._errors.append(GenericSQLMeshError(message))
246+
case console.LogFailedModels(errors=errors):
247+
if len(errors) != 0:
221248
failed_models = "\n".join(
222-
[f"{model!s}\n{model.__cause__!s}" for model in models]
249+
[f"{error.node!s}\n{error.__cause__!s}" for error in errors]
223250
)
224251
log_context.error(f"sqlmesh failed models: {failed_models}")
225-
raise Exception("sqlmesh has failed models")
252+
for error in errors:
253+
self._errors.append(
254+
FailedModelError(error.node, str(error.__cause__))
255+
)
226256
case console.UpdatePromotionProgress(snapshot=snapshot, promoted=promoted):
227257
log_context.info(
228258
"Promotion progress update",
@@ -263,9 +293,18 @@ def log(
263293
def update_stage(self, stage: str):
264294
self._stage = stage
265295

296+
@property
297+
def stage(self) -> str:
298+
return self._stage
299+
300+
@property
301+
def errors(self) -> list[Exception]:
302+
return self._errors[:]
303+
266304

267305
class SQLMeshResource(ConfigurableResource):
268306
config: SQLMeshContextConfig
307+
is_testing: bool = False
269308

270309
def run(
271310
self,
@@ -293,25 +332,16 @@ def run(
293332
with controller.instance(environment) as mesh:
294333
dag = mesh.models_dag()
295334

296-
select_models = []
297-
298335
models = mesh.models()
299336
models_map = models.copy()
300337
all_available_models = set(
301338
[model.fqn for model, _ in mesh.non_external_models_dag()]
302339
)
303-
if context.selected_output_names:
304-
models_map = {}
305-
for key, model in models.items():
306-
if (
307-
sqlmesh_model_name_to_key(model.name)
308-
in context.selected_output_names
309-
):
310-
models_map[key] = model
311-
select_models.append(model.name)
312-
selected_models_set = set(models_map.keys())
313-
314-
if all_available_models == selected_models_set:
340+
selected_models_set, models_map, select_models = (
341+
self._get_selected_models_from_context(context, models)
342+
)
343+
344+
if all_available_models == selected_models_set or select_models is None:
315345
logger.info("all models selected")
316346

317347
# Setting this to none to allow sqlmesh to select all models and
@@ -321,24 +351,61 @@ def run(
321351
logger.info(f"selected models: {select_models}")
322352

323353
event_handler = DagsterSQLMeshEventHandler(
324-
context, models_map, dag, "sqlmesh: "
354+
context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing
325355
)
326356

327-
for event in mesh.plan_and_run(
328-
start=start,
329-
end=end,
330-
select_models=select_models,
331-
restate_models=restate_models,
332-
restate_selected=restate_selected,
333-
skip_run=skip_run,
334-
plan_options=plan_options,
335-
run_options=run_options,
336-
):
337-
logger.debug(f"sqlmesh event: {event}")
338-
event_handler.process_events(event)
339-
357+
try:
358+
for event in mesh.plan_and_run(
359+
start=start,
360+
end=end,
361+
select_models=select_models,
362+
restate_models=restate_models,
363+
restate_selected=restate_selected,
364+
skip_run=skip_run,
365+
plan_options=plan_options,
366+
run_options=run_options,
367+
):
368+
logger.debug(f"sqlmesh event: {event}")
369+
event_handler.process_events(event)
370+
except SQLMeshError as e:
371+
logger.error(f"sqlmesh error: {e}")
372+
errors = event_handler.errors
373+
for error in errors:
374+
logger.error(f"sqlmesh encountered the following error during sqlmesh {event_handler.stage}: {error}")
375+
raise PlanOrRunFailedError(
376+
event_handler.stage,
377+
f"sqlmesh failed during {event_handler.stage} with {len(event_handler.errors) + 1} errors",
378+
[e, *event_handler.errors],
379+
)
340380
yield from event_handler.notify_success(mesh.context)
341381

382+
def _get_selected_models_from_context(
383+
self, context: AssetExecutionContext, models: MappingProxyType[str, Model]
384+
) -> tuple[set[str], dict[str, Model], list[str] | None]:
385+
models_map = models.copy()
386+
try:
387+
selected_output_names = set(context.selected_output_names)
388+
except (DagsterInvalidPropertyError, AttributeError) as e:
389+
# Special case for direct execution context when testing. This is related to:
390+
# https://github.com/dagster-io/dagster/issues/23633
391+
if "DirectOpExecutionContext" in str(e):
392+
context.log.warning("Caught an error that is likely a direct execution")
393+
return (set(models_map.keys()), models_map, None)
394+
else:
395+
raise e
396+
397+
select_models: list[str] = []
398+
models_map = {}
399+
for key, model in models.items():
400+
if sqlmesh_model_name_to_key(model.name) in selected_output_names:
401+
models_map[key] = model
402+
select_models.append(model.name)
403+
return (
404+
set(models_map.keys()),
405+
models_map,
406+
select_models,
407+
)
408+
342409
def get_controller(
343410
self,
344411
context_factory: ContextFactory[ContextCls],

dagster_sqlmesh/test_asset.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,4 +7,4 @@ def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestC
77
translator = SQLMeshDagsterTranslator()
88
outs = controller.to_asset_outs("dev", translator)
99
assert len(list(outs.deps)) == 1
10-
assert len(outs.outs) == 9
10+
assert len(outs.outs) == 10

0 commit comments

Comments
 (0)