|
36 | 36 | AirflowSkipException,
|
37 | 37 | AirflowTaskTerminated,
|
38 | 38 | )
|
39 |
| -from airflow.sdk import DAG, BaseOperator, Connection, get_current_context |
| 39 | +from airflow.sdk import DAG, BaseOperator, Connection, get_current_context, get_parsing_context |
40 | 40 | from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
|
41 | 41 | from airflow.sdk.definitions.variable import Variable
|
42 | 42 | from airflow.sdk.execution_time.comms import (
|
@@ -542,6 +542,52 @@ def execute(self, context):
|
542 | 542 | )
|
543 | 543 |
|
544 | 544 |
|
| 545 | +def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch, test_dags_dir): |
| 546 | + """ |
| 547 | + Test that the DAG parsing context is correctly set during the startup process. |
| 548 | +
|
| 549 | + This test verifies that the DAG and task IDs are correctly set in the parsing context |
| 550 | + when a DAG is started up. |
| 551 | + """ |
| 552 | + dag_id = "dag_parsing_context_test" |
| 553 | + task_id = "conditional_task" |
| 554 | + |
| 555 | + what = StartupDetails( |
| 556 | + ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), |
| 557 | + dag_rel_path="dag_parsing_context.py", |
| 558 | + bundle_info=BundleInfo(name="my-bundle", version=None), |
| 559 | + requests_fd=0, |
| 560 | + ti_context=make_ti_context(dag_id=dag_id, run_id="c"), |
| 561 | + ) |
| 562 | + |
| 563 | + mock_supervisor_comms.get_message.return_value = what |
| 564 | + |
| 565 | + # Ensure the parsing context is initially empty |
| 566 | + assert get_parsing_context().dag_id is None |
| 567 | + assert get_parsing_context().task_id is None |
| 568 | + |
| 569 | + # Set the environment variable for DAG bundles |
| 570 | + # We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test! |
| 571 | + dag_bundle_val = json.dumps( |
| 572 | + [ |
| 573 | + { |
| 574 | + "name": "my-bundle", |
| 575 | + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", |
| 576 | + "kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1}, |
| 577 | + } |
| 578 | + ] |
| 579 | + ) |
| 580 | + |
| 581 | + monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__BACKENDS", dag_bundle_val) |
| 582 | + ti, _ = startup() |
| 583 | + |
| 584 | + # Presence of `conditional_task` below means DAG ID is properly set in the parsing context! |
| 585 | + # Check the dag file for the actual logic! |
| 586 | + assert ti.task.dag.task_dict.keys() == {"visible_task", "conditional_task"} |
| 587 | + assert get_parsing_context().dag_id == dag_id |
| 588 | + assert get_parsing_context().task_id == task_id |
| 589 | + |
| 590 | + |
545 | 591 | class TestRuntimeTaskInstance:
|
546 | 592 | def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
|
547 | 593 | """Test get_template_context without ti_context_from_server."""
|
|
0 commit comments