Skip to content

Commit b328d77

Browse files
authored
feat: allow restating models (#25)
* feat: allow restating models
* chore: bump version
1 parent 4978a6f commit b328d77

File tree

10 files changed

+208
-40
lines changed

10 files changed

+208
-40
lines changed

dagster_sqlmesh/conftest.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def plan_and_run(
106106
enable_debug_console: bool = False,
107107
start: t.Optional[TimeLike] = None,
108108
end: t.Optional[TimeLike] = None,
109-
restate_models: t.Optional[t.List[str]] = None,
109+
select_models: t.Optional[t.List[str]] = None,
110+
restate_selected: bool = False,
111+
skip_run: bool = False,
110112
):
111113
"""Runs plan and run on SQLMesh with the given configuration and record all of the generated events.
112114
@@ -135,19 +137,16 @@ def plan_and_run(
135137
if execution_time:
136138
plan_options["execution_time"] = execution_time
137139
run_options["execution_time"] = execution_time
138-
if restate_models:
139-
plan_options["restate_models"] = restate_models
140-
if start:
141-
plan_options["start"] = start
142-
run_options["start"] = start
143-
if end:
144-
plan_options["end"] = end
145-
run_options["end"] = end
146140

147141
for event in controller.plan_and_run(
148142
environment,
143+
start=start,
144+
end=end,
145+
select_models=select_models,
146+
restate_selected=restate_selected,
149147
plan_options=plan_options,
150148
run_options=run_options,
149+
skip_run=skip_run,
151150
):
152151
recorder(event)
153152

dagster_sqlmesh/console.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from sqlglot.expressions import Alter
99
from sqlmesh.core.console import Console
10+
from sqlmesh.core.model import Model
1011
from sqlmesh.core.context_diff import ContextDiff
1112
from sqlmesh.core.environment import EnvironmentNamingInfo
1213
from sqlmesh.core.linter.rule import RuleViolation
@@ -563,8 +564,7 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None:
563564
self.publish(ShowTableDiffSummary(table_diff))
564565

565566
def show_linter_violations(
566-
self,
567-
violations: list[RuleViolation],
567+
self, violations: list[RuleViolation], model: Model, is_error: bool = False
568568
) -> None:
569569
self.publish(LogWarning("Linting violations found", str(violations)))
570570

dagster_sqlmesh/controller/base.py

+53-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class PlanOptions(t.TypedDict):
3030
execution_time: t.NotRequired[TimeLike]
3131
create_from: t.NotRequired[str]
3232
skip_tests: t.NotRequired[bool]
33-
restate_models: t.NotRequired[t.Iterable[str]]
33+
restate_models: t.NotRequired[t.Collection[str]]
3434
no_gaps: t.NotRequired[bool]
3535
skip_backfill: t.NotRequired[bool]
3636
forward_only: t.NotRequired[bool]
@@ -276,19 +276,57 @@ def run_sqlmesh_thread(
276276

277277
def plan_and_run(
278278
self,
279+
*,
280+
select_models: list[str] | None = None,
281+
restate_selected: bool = False,
282+
start: TimeLike | None = None,
283+
end: TimeLike | None = None,
279284
categorizer: t.Optional[SnapshotCategorizer] = None,
280285
default_catalog: t.Optional[str] = None,
281286
plan_options: t.Optional[PlanOptions] = None,
282287
run_options: t.Optional[RunOptions] = None,
288+
skip_run: bool = False,
283289
):
284-
run_options = run_options or {}
285-
plan_options = plan_options or {}
290+
"""Executes a plan and run operation
291+
292+
This is an opinionated interface for running a plan and run operation in
293+
a single thread. It is recommended to use this method for most use cases.
294+
"""
295+
run_options = run_options or RunOptions()
296+
plan_options = plan_options or PlanOptions()
297+
298+
if plan_options.get("select_models") or run_options.get("select_models"):
299+
raise ValueError(
300+
"select_models should not be set in plan_options or run_options use the `select_models` or `select_models_func` arguments instead"
301+
)
302+
if plan_options.get("restate_models"):
303+
raise ValueError(
304+
"restate_models should not be set in plan_options use the `restate_selected` argument with `select_models` or `select_models_func` instead"
305+
)
306+
select_models = select_models or []
307+
308+
if start:
309+
plan_options["start"] = start
310+
run_options["start"] = start
311+
if end:
312+
plan_options["end"] = end
313+
run_options["end"] = end
314+
315+
if select_models:
316+
if restate_selected:
317+
plan_options["restate_models"] = select_models
318+
plan_options["select_models"] = select_models
319+
else:
320+
plan_options["select_models"] = select_models
321+
run_options["select_models"] = select_models
286322

287323
try:
288324
self.logger.debug("starting sqlmesh plan")
325+
self.logger.debug(f"selected models: {select_models}")
289326
yield from self.plan(categorizer, default_catalog, **plan_options)
290327
self.logger.debug("starting sqlmesh run")
291-
yield from self.run(**run_options)
328+
if not skip_run:
329+
yield from self.run(**run_options)
292330
except Exception as e:
293331
self.logger.error(f"Error during sqlmesh plan and run: {e}")
294332
raise e
@@ -442,15 +480,26 @@ def plan(
442480
def plan_and_run(
443481
self,
444482
environment: str,
483+
*,
445484
categorizer: t.Optional[SnapshotCategorizer] = None,
485+
select_models: list[str] | None = None,
486+
restate_selected: bool = False,
487+
start: TimeLike | None = None,
488+
end: TimeLike | None = None,
446489
default_catalog: t.Optional[str] = None,
447490
plan_options: t.Optional[PlanOptions] = None,
448491
run_options: t.Optional[RunOptions] = None,
492+
skip_run: bool = False,
449493
):
450494
with self.instance(environment, "plan_and_run") as mesh:
451495
yield from mesh.plan_and_run(
496+
start=start,
497+
end=end,
498+
select_models=select_models,
499+
restate_selected=restate_selected,
452500
categorizer=categorizer,
453501
default_catalog=default_catalog,
454502
plan_options=plan_options,
455503
run_options=run_options,
504+
skip_run=skip_run,
456505
)

dagster_sqlmesh/resource.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from sqlmesh import Model
1010
from sqlmesh.utils.dag import DAG
11+
from sqlmesh.utils.date import TimeLike
1112
from sqlmesh.core.snapshot import Snapshot
1213
from sqlmesh.core.context import Context as SQLMeshContext
1314

@@ -234,7 +235,11 @@ class SQLMeshResource(ConfigurableResource):
234235
def run(
235236
self,
236237
context: AssetExecutionContext,
238+
*,
237239
environment: str = "dev",
240+
start: TimeLike | None = None,
241+
end: TimeLike | None = None,
242+
restate_selected: bool = False,
238243
plan_options: t.Optional[PlanOptions] = None,
239244
run_options: t.Optional[RunOptions] = None,
240245
) -> t.Iterable[MaterializeResult]:
@@ -248,9 +253,7 @@ def run(
248253
with controller.instance(environment) as mesh:
249254
dag = mesh.models_dag()
250255

251-
plan_options["select_models"] = []
252-
plan_options["backfill_models"] = []
253-
run_options["select_models"] = []
256+
select_models = []
254257

255258
models = mesh.models()
256259
models_map = models.copy()
@@ -264,15 +267,17 @@ def run(
264267
logger.info(f"selected model: {model.name}")
265268

266269
models_map[key] = model
267-
plan_options["select_models"].append(model.name)
268-
plan_options["backfill_models"].append(model.name)
269-
run_options["select_models"].append(model.name)
270+
select_models.append(model.name)
270271

271272
event_handler = DagsterSQLMeshEventHandler(
272273
context, models_map, dag, "sqlmesh: "
273274
)
274275

275276
for event in mesh.plan_and_run(
277+
start=start,
278+
end=end,
279+
select_models=select_models,
280+
restate_selected=restate_selected,
276281
plan_options=plan_options,
277282
run_options=run_options,
278283
):

dagster_sqlmesh/test_asset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestC
99
translator = SQLMeshDagsterTranslator()
1010
outs = controller.to_asset_outs("dev", translator)
1111
assert len(list(outs.deps)) == 1
12-
assert len(outs.outs) == 7
12+
assert len(outs.outs) == 9

dagster_sqlmesh/test_sqlmesh_context.py

+70
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,73 @@ def test_sqlmesh_context(sample_sqlmesh_test_context: SQLMeshTestContext):
122122
"""
123123
)
124124
assert test_source_model_count[0][0] == 6
125+
126+
127+
def test_restating_models(sample_sqlmesh_test_context: SQLMeshTestContext):
128+
sample_sqlmesh_test_context.plan_and_run(
129+
environment="dev",
130+
start="2023-01-01",
131+
end="2024-01-01",
132+
execution_time="2024-01-02",
133+
)
134+
135+
count_query = sample_sqlmesh_test_context.query(
136+
"""
137+
SELECT COUNT(*) FROM sqlmesh_example__dev.staging_model_4
138+
"""
139+
)
140+
assert count_query[0][0] == 366
141+
142+
feb_sum_query = sample_sqlmesh_test_context.query(
143+
"""
144+
SELECT SUM(value) FROM sqlmesh_example__dev.staging_model_4 WHERE time >= '2023-02-01' AND time < '2023-02-28'
145+
"""
146+
)
147+
march_sum_query = sample_sqlmesh_test_context.query(
148+
"""
149+
SELECT SUM(value) FROM sqlmesh_example__dev.staging_model_4 WHERE time >= '2023-03-01' AND time < '2023-03-31'
150+
"""
151+
)
152+
intermediate_2_query = sample_sqlmesh_test_context.query(
153+
"""
154+
SELECT * FROM sqlmesh_example__dev.intermediate_model_2
155+
"""
156+
)
157+
158+
# Restate the model for the month of March
159+
sample_sqlmesh_test_context.plan_and_run(
160+
environment="dev",
161+
start="2023-03-01",
162+
end="2023-03-31",
163+
execution_time="2024-01-02",
164+
select_models=["sqlmesh_example.staging_model_4"],
165+
restate_selected=True,
166+
skip_run=True,
167+
)
168+
169+
# Check that the February sum is unchanged while the March sum differs after the restate
170+
feb_sum_query_restate = sample_sqlmesh_test_context.query(
171+
"""
172+
SELECT SUM(value) FROM sqlmesh_example__dev.staging_model_4 WHERE time >= '2023-02-01' AND time < '2023-02-28'
173+
"""
174+
)
175+
march_sum_query_restate = sample_sqlmesh_test_context.query(
176+
"""
177+
SELECT SUM(value) FROM sqlmesh_example__dev.staging_model_4 WHERE time >= '2023-03-01' AND time < '2023-03-31'
178+
"""
179+
)
180+
intermediate_2_query_restate = sample_sqlmesh_test_context.query(
181+
"""
182+
SELECT * FROM sqlmesh_example__dev.intermediate_model_2
183+
"""
184+
)
185+
186+
assert (
187+
feb_sum_query_restate[0][0] == feb_sum_query[0][0]
188+
), "February sum should not change"
189+
assert (
190+
march_sum_query_restate[0][0] != march_sum_query[0][0]
191+
), "March sum should change"
192+
assert (
193+
intermediate_2_query_restate[0][0] == intermediate_2_query[0][0]
194+
), "Intermediate model should not change during restate"

pyproject.toml

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
[project]
22
name = "dagster-sqlmesh"
3-
version = "0.8.0"
3+
version = "0.9.0"
44
description = ""
5-
authors = [
6-
{name = "Reuven Gonzales", email = "[email protected]"}
7-
]
8-
license = {text = "Apache-2.0"}
5+
authors = [{ name = "Reuven Gonzales", email = "[email protected]" }]
6+
license = { text = "Apache-2.0" }
97
readme = "README.md"
108
requires-python = ">=3.11,<3.13"
119
dependencies = [
1210
"dagster>=1.7.8",
13-
"sqlmesh<1.0",
11+
"sqlmesh==0.164.0",
1412
"pytest>=8.3.2",
1513
"pyarrow>=18.0.0",
1614
]
@@ -38,6 +36,5 @@ exclude = [
3836
"**/.github",
3937
"**/.vscode",
4038
"**/.idea",
41-
"**/.pytest_cache",
39+
"**/.pytest_cache",
4240
]
43-
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
MODEL (
2+
name sqlmesh_example.intermediate_model_2,
3+
kind FULL
4+
);
5+
6+
SELECT
7+
SUM(parent.value) as value
8+
FROM sqlmesh_example.staging_model_4 AS parent
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import typing as t
2+
from datetime import datetime
3+
4+
import pandas as pd
5+
from sqlmesh import ExecutionContext, model
6+
from sqlmesh.core.model import ModelKindName
7+
import numpy as np
8+
9+
10+
@model(
11+
name="sqlmesh_example.staging_model_4",
12+
is_sql=False,
13+
columns={
14+
"time": "TIMESTAMP",
15+
"value": "DOUBLE",
16+
},
17+
kind={"name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, "time_column": "time"},
18+
start="2023-01-01",
19+
)
20+
def staging_model_4(
21+
context: ExecutionContext,
22+
start: datetime,
23+
end: datetime,
24+
**kwargs,
25+
) -> t.Generator[pd.DataFrame, None, None]:
26+
# Generates a set of random rows for the model based on the start and end dates
27+
date_range = pd.date_range(start=start, end=end, freq="D")
28+
num_days = len(date_range)
29+
30+
data = {
31+
"time": date_range,
32+
"value": np.random.rand(num_days)
33+
* 100, # Random double values between 0 and 100
34+
}
35+
36+
df = pd.DataFrame(data)
37+
yield df

0 commit comments

Comments
 (0)