Skip to content
Closed
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
1395bf4
add empty task
johnhoran Mar 4, 2026
95b06cd
task name
johnhoran Mar 4, 2026
ff3a1ba
skip equivalent to failure
johnhoran Mar 4, 2026
4f59350
force error
johnhoran Mar 4, 2026
f0d097b
bad image
johnhoran Mar 4, 2026
0802ef8
timeout after 2 seconds
johnhoran Mar 4, 2026
b328a7c
20 seconds
johnhoran Mar 4, 2026
ef55c48
drop empty task
johnhoran Mar 4, 2026
ad3d307
different error
johnhoran Mar 4, 2026
75b33bd
log exception
johnhoran Mar 4, 2026
01e310a
restore empty
johnhoran Mar 4, 2026
9267fa5
disable init logging
johnhoran Mar 4, 2026
b546146
remove
johnhoran Mar 4, 2026
454b071
api exception
johnhoran Mar 4, 2026
d471db4
ruff
johnhoran Mar 4, 2026
b696707
raise airflow skip on retry
johnhoran Mar 4, 2026
273464b
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
4fcb25b
fix c901
johnhoran Mar 4, 2026
e4be216
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
fc3e923
catch baseexception
johnhoran Mar 5, 2026
405c15e
Merge branch 'main' into empty-gate
johnhoran Mar 5, 2026
c5c8682
raise skip
johnhoran Mar 5, 2026
ee0f516
move empty operator
johnhoran Mar 5, 2026
6ec11a2
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
e85b0f0
make gate downstream of task group for wait_for_downstream
johnhoran Mar 5, 2026
c3a6bed
move gate upstream
johnhoran Mar 5, 2026
cb77108
mypy
johnhoran Mar 5, 2026
06ce6b2
update trigger rule
johnhoran Mar 5, 2026
c783cb3
resolve test failures
johnhoran Mar 5, 2026
40d7dfc
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
6fbf25c
Merge branch 'main' into empty-gate
johnhoran Mar 5, 2026
66161f9
missed another test
johnhoran Mar 5, 2026
001d69e
resolve tests
johnhoran Mar 5, 2026
6765bfa
force wait_for_downstream
johnhoran Mar 6, 2026
41e6d80
add some assertions
johnhoran Mar 6, 2026
01cfee3
Update tests/airflow/test_graph.py
johnhoran Mar 6, 2026
170505b
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
106acb8
copilot changes
johnhoran Mar 6, 2026
43de64a
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
6574f65
add test
johnhoran Mar 6, 2026
f2e773d
Merge branch 'main' into empty-gate
johnhoran Mar 6, 2026
ea7dc72
fix dbtfusion test
johnhoran Mar 6, 2026
5a52371
fix integration test
johnhoran Mar 6, 2026
31a7a9d
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
d0d375f
handle wait_for_downstream in tests
johnhoran Mar 9, 2026
975e2a1
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2026
255b08c
fix test
johnhoran Mar 9, 2026
92e9e3d
Merge branch 'main' into empty-gate
johnhoran Mar 9, 2026
bccbb33
gate after all and add test for wiring
johnhoran Mar 10, 2026
2f90883
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
2134a6d
Merge branch 'main' into empty-gate
johnhoran Mar 10, 2026
466d472
test wiring on after_each
johnhoran Mar 10, 2026
6d70d11
Merge branch 'main' into empty-gate
johnhoran Mar 11, 2026
543a615
coverage
johnhoran Mar 11, 2026
f1c4711
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
30856a0
copilot
johnhoran Mar 11, 2026
7fa2b29
add test for coverage
johnhoran Mar 11, 2026
2306178
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
e5540c9
copilot
johnhoran Mar 11, 2026
fe3404d
Merge branch 'main' into empty-gate
johnhoran Mar 12, 2026
93c7ab4
Merge branch 'main' into empty-gate
johnhoran Mar 19, 2026
758d53f
fix test
johnhoran Mar 19, 2026
df54719
Merge branch 'main' into empty-gate
johnhoran Mar 20, 2026
a099ea0
fix test
johnhoran Mar 20, 2026
f84cae6
Merge branch 'main' into empty-gate
johnhoran Mar 26, 2026
9914d9c
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from typing import Any

try: # Airflow 3
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk.bases.operator import BaseOperator
except ImportError: # Airflow 2
from airflow.models import BaseOperator
from airflow.operators.empty import EmptyOperator # type: ignore[no-redef]

from airflow.models.base import ID_LEN as AIRFLOW_MAX_ID_LENGTH
from airflow.models.dag import DAG
from airflow.utils.trigger_rule import TriggerRule

try:
# Airflow 3.1 onwards
Expand Down Expand Up @@ -679,7 +682,7 @@ def _add_watcher_producer_task(
render_config: RenderConfig | None = None,
execution_mode: ExecutionMode = ExecutionMode.WATCHER,
tests_per_model: dict[str, list[str]] | None = None,
) -> BaseOperator:
) -> tuple[BaseOperator, EmptyOperator]:
"""
Create the producer task for the watcher execution mode and add it to the tasks_map.
The producer task is the task that will be used to produce the events for the watcher execution mode.
Expand Down Expand Up @@ -711,12 +714,22 @@ def _add_watcher_producer_task(
)
producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group)
tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task
return producer_airflow_task

producer_task_gate = EmptyOperator( # type: ignore[no-untyped-call]
task_id=f"{PRODUCER_WATCHER_TASK_ID}_gate",
dag=dag,
task_group=task_group,
trigger_rule=TriggerRule.NONE_FAILED,
depends_on_past=producer_airflow_task.depends_on_past,
)
producer_airflow_task >> producer_task_gate
return producer_airflow_task, producer_task_gate
Comment thread
johnhoran marked this conversation as resolved.


def _add_watcher_dependencies(
dag: DAG,
producer_airflow_task: BaseOperator,
producer_gate: BaseOperator,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
nodes: dict[str, DbtNode] | None = None,
Expand All @@ -728,7 +741,7 @@ def _add_watcher_dependencies(
"""
for node_id, task_or_taskgroup in tasks_map.items():
# We do not want to set a dependency between the producer task and itself
if node_id == PRODUCER_WATCHER_TASK_ID:
if node_id == PRODUCER_WATCHER_TASK_ID or node_id == f"{PRODUCER_WATCHER_TASK_ID}_gate":
continue

node_tasks = (
Expand Down Expand Up @@ -758,6 +771,10 @@ def _add_watcher_dependencies(
for task in always_run_tasks:
task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[attr-defined]
Comment thread
johnhoran marked this conversation as resolved.

# If depends_on_past isn't true then gating all the tasks isn't really needed.
if producer_airflow_task.wait_for_downstream and not task_or_taskgroup.downstream_task_ids:
task_or_taskgroup >> producer_gate
Comment on lines +774 to +776
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gate wiring only attaches leaf tasks to producer_gate when not task_or_taskgroup.downstream_task_ids. This means tasks that already have downstream edges (e.g., the AFTER_ALL test task added earlier, or any other downstream added during graph build) will not be gated, so dbt_producer_watcher_gate won’t reliably represent “all watcher work completed” when depends_on_past/wait_for_downstream is being used. Consider determining dbt leaves from the dbt graph (e.g., calculate_leaves(...)) and always wiring those leaves (or the AFTER_ALL test task, if present) into the gate when producer_airflow_task.wait_for_downstream is enabled.

Copilot uses AI. Check for mistakes.

Comment thread
johnhoran marked this conversation as resolved.

def should_create_detached_nodes(render_config: RenderConfig) -> bool:
"""
Expand Down Expand Up @@ -908,7 +925,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
task_groups: dict[str, TaskGroup] = {}
task_or_group: TaskGroup | BaseOperator | None
parent_task_group = task_group
producer_task: BaseOperator | None = None
producer_tasks: tuple[BaseOperator, EmptyOperator] | None = None

# Identify test nodes that should be run detached from the associated dbt resource nodes because they
# have multiple parents
Expand All @@ -926,7 +943,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
# We are intentionally creating the producer task ahead of the consumer tasks
# Airflow priority weight is not being respected in multiple versions of the library, including 3.1
# To instantiate the producer before helps having it before on the DAG topological order and scheduling this task before the consumer tasks
producer_task = _add_watcher_producer_task(
producer_tasks = _add_watcher_producer_task(
dag=dag,
task_args={**task_args, **setup_operator_args},
tasks_map=tasks_map,
Expand Down Expand Up @@ -988,6 +1005,9 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task
if producer_tasks and producer_tasks[0].depends_on_past:
test_task >> producer_tasks[1]
test_task.wait_for_downstream = True
elif render_config.test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH):
# Handle detached test nodes
for node_id, node in detached_nodes.items():
Expand All @@ -1012,10 +1032,11 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro

create_airflow_task_dependencies(nodes, tasks_map)

if producer_task:
if producer_tasks:
_add_watcher_dependencies(
dag=dag,
producer_airflow_task=producer_task,
producer_airflow_task=producer_tasks[0],
producer_gate=producer_tasks[1],
task_args=task_args,
tasks_map=tasks_map,
nodes=nodes,
Expand Down
8 changes: 5 additions & 3 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ def __init__(
self.deferrable = deferrable
self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id")

if self.depends_on_past:
self.wait_for_downstream = True

@staticmethod
def _filter_flags(flags: list[str]) -> list[str]:
"""Filters out dbt flags that are incompatible with retry (e.g., --select, --exclude)."""
Expand Down Expand Up @@ -502,11 +505,10 @@ def poke(self, context: Context) -> bool:
_log_dbt_event(dbt_events)

if status is None:

if producer_task_state == "failed":
if producer_task_state == "failed" or producer_task_state == "skipped":
if self.poke_retry_number > 0:
raise AirflowException(
f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details."
f"The dbt build command {producer_task_state} in the producer task. Please check the log of task {self.producer_task_id} for details."
Comment thread
johnhoran marked this conversation as resolved.
Comment thread
johnhoran marked this conversation as resolved.
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new exception message is ungrammatical for the skipped case: "The dbt build command skipped in the producer task...". Consider adding a verb (e.g., "was {state}" / "{state} before completing") so the message reads correctly for both "failed" and "skipped".

Suggested change
f"The dbt build command {producer_task_state} in the producer task. Please check the log of task {self.producer_task_id} for details."
f"The dbt build command was {producer_task_state} in the producer task. Please check the log of task {self.producer_task_id} for details."

Copilot uses AI. Check for mistakes.
)
Comment on lines 563 to 568
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic now treats a skipped producer as a failure signal. There should be test coverage for the producer_task_state == "skipped" branch (including the two paths for poke_retry_number > 0 vs == 0) to ensure sensors reliably fall back to non-watcher execution rather than looping until timeout.

Copilot uses AI. Check for mistakes.
Comment thread
johnhoran marked this conversation as resolved.
Comment thread
johnhoran marked this conversation as resolved.
else:
# This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED`
Expand Down
6 changes: 3 additions & 3 deletions cosmos/operators/_watcher/triggerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@


class WatcherTrigger(BaseTrigger):

def __init__(
self,
model_unique_id: str,
Expand Down Expand Up @@ -213,10 +212,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
event_data["compiled_sql"] = compiled_sql
yield TriggerEvent(event_data) # type: ignore[no-untyped-call]
return
elif producer_task_state == "failed":
elif producer_task_state == "failed" or producer_task_state == "skipped":
logger.error(
"Watcher producer task '%s' failed before delivering results for node '%s'",
"Watcher producer task '%s' %s before delivering results for node '%s'",
self.producer_task_id,
producer_task_state,
self.model_unique_id,
)
Comment thread
johnhoran marked this conversation as resolved.
Comment on lines 240 to 245
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error log message now reads "Watcher producer task '%s' %s before delivering results...", which is missing a verb and produces awkward output (e.g., "... skipped before ..."). Consider changing it to something like "was %s" or expanding the message so it remains grammatically correct for both states.

Copilot uses AI. Check for mistakes.
yield TriggerEvent({"status": EventStatus.FAILED, "reason": "producer_failed"}) # type: ignore[no-untyped-call]
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change adds a new terminal path when producer_task_state == "skipped", but there isn’t corresponding coverage verifying that a skipped producer yields reason="producer_failed" (similar to the existing failed case). Add/adjust tests for the skipped producer state to ensure triggers behave consistently across Airflow 2/3.

Copilot uses AI. Check for mistakes.
Expand Down
22 changes: 13 additions & 9 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException

from cosmos.config import ProfileConfig
from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push
Expand Down Expand Up @@ -120,6 +120,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.log_format = "json"

if self.depends_on_past:
self.wait_for_downstream = True

Comment on lines +109 to +111
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This forces wait_for_downstream=True whenever depends_on_past is enabled, overriding any explicit user configuration passed via operator_args/kwargs. If this is intentional, it should be documented as a behavioral constraint; otherwise, consider only enabling wait_for_downstream when the user didn’t explicitly set it (e.g., by checking for the kwarg before super().__init__).

Copilot uses AI. Check for mistakes.
@staticmethod
def _serialize_event(event_message: EventMsg) -> dict[str, Any]:
"""Convert structured dbt EventMsg to plain dict."""
Expand Down Expand Up @@ -187,12 +190,10 @@ def execute(self, context: Context, **kwargs: Any) -> Any:
try_number = getattr(task_instance, "try_number", 1)

if try_number > 1:
self.log.info(
"Dbt WATCHER producer task does not support Airflow retries. "
"Detected attempt #%s; skipping execution to avoid running a second dbt build.",
try_number,
raise AirflowSkipException(
"DbtProducerWatcherOperator does not support Airflow retries. "
f"Detected attempt #{try_number}; skipping execution to avoid running a second dbt build."
)
Comment thread
johnhoran marked this conversation as resolved.
return None

self.log.info(
"Dbt WATCHER producer task forces Airflow retries to 0 so the dbt build only runs once; "
Expand Down Expand Up @@ -238,9 +239,10 @@ def _callback(event_message: EventMsg) -> None:
safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed")
return return_value

except Exception:
except Exception as e:
safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed")
raise
self.log.exception("DbtProducerWatcherOperator execution failed")
raise AirflowSkipException("Skipping execution due to task failure") from e
Comment thread
johnhoran marked this conversation as resolved.
Comment thread
johnhoran marked this conversation as resolved.
Comment on lines +223 to +226
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The broad except Exception converts all exceptions from super().execute() into AirflowSkipException, including an AirflowSkipException intentionally raised by lower layers (e.g., skip_exit_code in DbtLocalBaseOperator). That will incorrectly log it as a failure and replace the original skip reason/message. Consider explicitly re-raising AirflowSkipException (and any other control-flow exceptions you want preserved) before the generic handler, similar to the Kubernetes variant.

Copilot uses AI. Check for mistakes.


class DbtConsumerWatcherSensor(BaseConsumerSensor, DbtRunLocalOperator): # type: ignore[misc]
Expand Down Expand Up @@ -352,6 +354,8 @@ class DbtTestWatcherOperator(EmptyOperator):
"""

def __init__(self, *args: Any, **kwargs: Any):
default_args = kwargs.get("default_args", {})
desired_keys = ("dag", "task_group", "task_id")
new_kwargs = {key: value for key, value in kwargs.items() if key in desired_keys}
super().__init__(**new_kwargs) # type: ignore[no-untyped-call]
depends_on_past = kwargs.get("depends_on_past", False) or default_args.get("depends_on_past", False)
super().__init__(depends_on_past=depends_on_past, wait_for_downstream=depends_on_past, **new_kwargs) # type: ignore[no-untyped-call]
Comment thread
johnhoran marked this conversation as resolved.
Outdated
30 changes: 22 additions & 8 deletions cosmos/operators/watcher_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from airflow.utils.context import Context # type: ignore[attr-defined]

import kubernetes.client as k8s
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback, client_type

try:
Expand Down Expand Up @@ -44,7 +44,6 @@


class WatcherKubernetesCallback(KubernetesPodOperatorCallback): # type: ignore[misc]

@staticmethod
def progress_callback(
*,
Expand Down Expand Up @@ -74,7 +73,6 @@ def progress_callback(


class DbtProducerWatcherKubernetesOperator(DbtBuildKubernetesOperator):

template_fields: tuple[str, ...] = tuple(DbtBuildKubernetesOperator.template_fields) + ("deferrable",)
_process_log_line_callable: Callable[[str, dict[str, Any]], None] | None = store_dbt_resource_status_from_log

Expand All @@ -99,6 +97,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(task_id=task_id, *args, **kwargs)
self.dbt_cmd_flags += ["--log-format", "json"]

if self.depends_on_past:
self.wait_for_downstream = True

@cached_property
def pod_manager(self) -> CosmosKubernetesPodManager:
return CosmosKubernetesPodManager(kube_client=self.client, callbacks=self.callbacks)
Expand All @@ -113,18 +114,31 @@ def execute(self, context: Context, **kwargs: Any) -> Any:
try_number = getattr(task_instance, "try_number", 1)

if try_number > 1:
self.log.info(
raise AirflowSkipException(
"DbtProducerWatcherKubernetesOperator does not support Airflow retries. "
"Detected attempt #%s; skipping execution to avoid running a second dbt build.",
try_number,
f"Detected attempt #{try_number}; skipping execution to avoid running a second dbt build."
)
return None

# This global variable is used to make the task context available to the K8s callback.
# While the callback is set during the operator initialization, the context is only created during the operator's execution.
global producer_task_context
producer_task_context = context
return super().execute(context, **kwargs)
try:
return super().execute(context, **kwargs)
except (AirflowSkipException, TaskDeferred):
raise
except Exception as e:
self.log.exception("Dbt execution failed")
raise AirflowSkipException("Skipping execution due to task failure") from e
Comment thread
johnhoran marked this conversation as resolved.
Comment thread
johnhoran marked this conversation as resolved.

def trigger_reentry(self, *args: Any, **kwargs: Any) -> Any:
try:
return super().trigger_reentry(*args, **kwargs)
except (AirflowSkipException, TaskDeferred):
raise
except Exception as e:
self.log.exception("Dbt execution failed")
raise AirflowSkipException("Skipping execution due to task failure") from e
Comment thread
johnhoran marked this conversation as resolved.


class DbtConsumerWatcherKubernetesSensor(BaseConsumerSensor, DbtRunKubernetesOperator):
Expand Down
83 changes: 78 additions & 5 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@
tags=["nightly"],
config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}},
)

sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, child2_node]
sample_nodes = {node.unique_id: node for node in sample_nodes_list}

Expand Down Expand Up @@ -1178,13 +1177,87 @@ def test_test_behavior_for_watcher_mode(test_behavior):
tasks = dag.tasks
if test_behavior == TestBehavior.NONE:
for task in tasks:
assert not isinstance(task, DbtTestWatcherOperator or DbtTestLocalOperator)
assert len(tasks) == 5
if test_behavior == TestBehavior.AFTER_EACH:
assert not isinstance(task, (DbtTestWatcherOperator, DbtTestLocalOperator))
assert len(tasks) == 6
if test_behavior == TestBehavior.AFTER_EACH:
assert len(tasks) == 7
if test_behavior == TestBehavior.AFTER_ALL:
assert any(isinstance(task, DbtTestLocalOperator) for task in tasks)
assert len(tasks) == 6
assert len(tasks) == 7


@pytest.mark.parametrize("depends_on_past", [False, True])
@pytest.mark.parametrize("test_behavior", [TestBehavior.NONE, TestBehavior.AFTER_EACH, TestBehavior.AFTER_ALL])
def test_watcher_dependency_wiring(test_behavior, depends_on_past):
with DAG("test-id", start_date=datetime(2022, 1, 1), default_args={"depends_on_past": depends_on_past}) as dag:
task_args = {
"project_dir": SAMPLE_PROJ_PATH,
"conn_id": "fake_conn",
"profile_config": ProfileConfig(
profile_name="default",
target_name="default",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="fake_conn",
profile_args={"schema": "public"},
),
),
}

child_2b = DbtNode(
unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.child2.v2_b",
resource_type=DbtResourceType.MODEL,
depends_on=[parent_node.unique_id],
path_base=SAMPLE_PROJ_PATH,
original_file_path=Path("gen3/models/child2_v2.sql"),
tags=["nightly"],
config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}},
has_test=True,
has_non_detached_test=True,
)
child_2b_test = DbtNode(
unique_id=f"{DbtResourceType.TEST.value}.{SAMPLE_PROJ_PATH.stem}.child2.test_v2_b",
resource_type=DbtResourceType.TEST,
depends_on=[child_2b.unique_id],
path_base=Path("."),
original_file_path=Path("."),
)

build_airflow_graph(
nodes={child_2b.unique_id: child_2b, child_2b_test.unique_id: child_2b_test, **sample_nodes},
dag=dag,
execution_mode=ExecutionMode.WATCHER,
test_indirect_selection=TestIndirectSelection.EAGER,
task_args=task_args,
render_config=RenderConfig(
test_behavior=test_behavior,
),
dbt_project_name="astro_shop",
)
if not depends_on_past:
assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == {"dbt_producer_watcher"}
assert all(task.wait_for_downstream is False for task in dag.tasks)
return

assert all(task.wait_for_downstream is True for task in dag.tasks if task.task_id != "dbt_producer_watcher_gate")
if test_behavior == TestBehavior.NONE:
assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == {
"child_run",
"dbt_producer_watcher",
"child2_v2_run",
"child2_v2_b_run",
}
if test_behavior == TestBehavior.AFTER_EACH:
assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == {
"child_run",
"dbt_producer_watcher",
"child2_v2_run",
"child2_v2_b.test",
}
if test_behavior == TestBehavior.AFTER_ALL:
assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == {
"dbt_producer_watcher",
"astro_shop_test",
}


def test_custom_meta():
Expand Down
Loading
Loading