Skip to content

Commit d632e3d

Browse files
committed
AIP-72: Support DAG parsing context in Task SDK
closes #45693
1 parent e164f3c commit d632e3d

File tree

13 files changed

+126
-40
lines changed

airflow/cli/commands/remote_commands/dag_command.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
from airflow.jobs.job import Job
3838
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
3939
from airflow.models.serialized_dag import SerializedDagModel
40+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
4041
from airflow.utils import cli as cli_utils, timezone
4142
from airflow.utils.cli import get_dag, process_subdir, suppress_logs_and_warning
42-
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
4343
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
4444
from airflow.utils.helpers import ask_yesno
4545
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

airflow/task/standard_task_runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@
3333
from airflow.configuration import conf
3434
from airflow.exceptions import AirflowConfigException
3535
from airflow.models.taskinstance import TaskReturnCode
36+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
3637
from airflow.settings import CAN_FORK
3738
from airflow.stats import Stats
3839
from airflow.utils.configuration import tmp_configuration_copy
39-
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
4040
from airflow.utils.log.logging_mixin import LoggingMixin
4141
from airflow.utils.net import get_hostname
4242
from airflow.utils.platform import IS_WINDOWS, getuser

docs/apache-airflow/howto/dynamic-dag-generation.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ of the context are set to ``None``.
207207
:emphasize-lines: 4,8,9
208208
209209
from airflow.models.dag import DAG
210-
from airflow.utils.dag_parsing_context import get_parsing_context
210+
from airflow.sdk import get_parsing_context
211211
212212
current_dag_id = get_parsing_context().dag_id
213213

providers/src/airflow/providers/celery/executors/celery_executor_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
from airflow.configuration import conf
4545
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
4646
from airflow.executors.base_executor import BaseExecutor
47+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
4748
from airflow.stats import Stats
48-
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
4949
from airflow.utils.log.logging_mixin import LoggingMixin
5050
from airflow.utils.net import get_hostname
5151
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

scripts/cov/core_coverage.py

-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
"airflow/utils/code_utils.py",
105105
"airflow/utils/context.py",
106106
"airflow/utils/dag_cycle_tester.py",
107-
"airflow/utils/dag_parsing_context.py",
108107
"airflow/utils/dates.py",
109108
"airflow/utils/db.py",
110109
"airflow/utils/db_cleanup.py",

task_sdk/src/airflow/sdk/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"dag",
2828
"Connection",
2929
"get_current_context",
30+
"get_parsing_context",
3031
"__version__",
3132
]
3233

@@ -35,7 +36,7 @@
3536
if TYPE_CHECKING:
3637
from airflow.sdk.definitions.baseoperator import BaseOperator
3738
from airflow.sdk.definitions.connection import Connection
38-
from airflow.sdk.definitions.context import get_current_context
39+
from airflow.sdk.definitions.context import get_current_context, get_parsing_context
3940
from airflow.sdk.definitions.dag import DAG, dag
4041
from airflow.sdk.definitions.edges import EdgeModifier, Label
4142
from airflow.sdk.definitions.taskgroup import TaskGroup
@@ -50,6 +51,7 @@
5051
"Connection": ".definitions.connection",
5152
"Variable": ".definitions.variable",
5253
"get_current_context": ".definitions.context",
54+
"get_parsing_context": ".definitions.context",
5355
}
5456

5557

airflow/utils/dag_parsing_context.py → task_sdk/src/airflow/sdk/definitions/_internal/dag_parsing_context.py

+1-24
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,8 @@
1818

1919
import os
2020
from contextlib import contextmanager
21-
from typing import NamedTuple
2221

23-
24-
class AirflowParsingContext(NamedTuple):
25-
"""
26-
Context of parsing for the DAG.
27-
28-
If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to
29-
execute. You can use these for optimizing dynamically generated DAG files.
30-
"""
31-
32-
dag_id: str | None
33-
task_id: str | None
34-
35-
36-
_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
37-
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
22+
from airflow.sdk.definitions.context import _AIRFLOW_PARSING_CONTEXT_DAG_ID, _AIRFLOW_PARSING_CONTEXT_TASK_ID
3823

3924

4025
@contextmanager
@@ -50,11 +35,3 @@ def _airflow_parsing_context_manager(dag_id: str | None = None, task_id: str | N
5035
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = old_task_id
5136
if old_dag_id is not None:
5237
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = old_dag_id
53-
54-
55-
def get_parsing_context() -> AirflowParsingContext:
56-
"""Return the current (DAG) parsing context info."""
57-
return AirflowParsingContext(
58-
dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID),
59-
task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID),
60-
)

task_sdk/src/airflow/sdk/definitions/context.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20-
from typing import TYPE_CHECKING, Any, TypedDict
20+
import os
21+
from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict
2122

2223
if TYPE_CHECKING:
2324
# TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x?
@@ -105,3 +106,27 @@ def my_task():
105106
from airflow.sdk.definitions._internal.contextmanager import _get_current_context
106107

107108
return _get_current_context()
109+
110+
111+
class AirflowParsingContext(NamedTuple):
112+
"""
113+
Context of parsing for the DAG.
114+
115+
If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to
116+
execute. You can use these for optimizing dynamically generated DAG files.
117+
"""
118+
119+
dag_id: str | None
120+
task_id: str | None
121+
122+
123+
_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
124+
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
125+
126+
127+
def get_parsing_context() -> AirflowParsingContext:
128+
"""Return the current (DAG) parsing context info."""
129+
return AirflowParsingContext(
130+
dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID),
131+
task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID),
132+
)

task_sdk/src/airflow/sdk/execution_time/task_runner.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from airflow.dag_processing.bundles.manager import DagBundlesManager
3535
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext
36+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
3637
from airflow.sdk.definitions.baseoperator import BaseOperator
3738
from airflow.sdk.execution_time.comms import (
3839
DeferTask,
@@ -406,8 +407,11 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
406407
setproctitle(f"airflow worker -- {msg.ti.id}")
407408

408409
log = structlog.get_logger(logger_name="task")
409-
# TODO: set the "magic loop" context vars for parsing
410-
ti = parse(msg)
410+
with _airflow_parsing_context_manager(
411+
dag_id=msg.ti.dag_id,
412+
task_id=msg.ti.task_id,
413+
):
414+
ti = parse(msg)
411415
log.debug("DAG file parsed", file=msg.dag_rel_path)
412416
else:
413417
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

task_sdk/tests/dags/dag_parsing_context.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from datetime import datetime
20+
21+
from airflow.sdk import DAG, BaseOperator, get_parsing_context
22+
23+
DAG_ID = "dag_parsing_context_test"
24+
25+
current_dag_id = get_parsing_context().dag_id
26+
27+
with DAG(
28+
DAG_ID,
29+
start_date=datetime(2024, 2, 21),
30+
schedule=None,
31+
) as the_dag:
32+
BaseOperator(task_id="visible_task")
33+
34+
if current_dag_id == DAG_ID:
35+
# this task will be invisible if the DAG ID is not properly set in the parsing context.
36+
BaseOperator(task_id="conditional_task")

task_sdk/tests/execution_time/test_task_runner.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
AirflowSkipException,
3737
AirflowTaskTerminated,
3838
)
39-
from airflow.sdk import DAG, BaseOperator, Connection, get_current_context
39+
from airflow.sdk import DAG, BaseOperator, Connection, get_current_context, get_parsing_context
4040
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
4141
from airflow.sdk.definitions.variable import Variable
4242
from airflow.sdk.execution_time.comms import (
@@ -542,6 +542,52 @@ def execute(self, context):
542542
)
543543

544544

545+
def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch, test_dags_dir):
546+
"""
547+
Test that the DAG parsing context is correctly set during the startup process.
548+
549+
This test verifies that the DAG and task IDs are correctly set in the parsing context
550+
when a DAG is started up.
551+
"""
552+
dag_id = "dag_parsing_context_test"
553+
task_id = "conditional_task"
554+
555+
what = StartupDetails(
556+
ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
557+
dag_rel_path="dag_parsing_context.py",
558+
bundle_info=BundleInfo(name="my-bundle", version=None),
559+
requests_fd=0,
560+
ti_context=make_ti_context(dag_id=dag_id, run_id="c"),
561+
)
562+
563+
mock_supervisor_comms.get_message.return_value = what
564+
565+
# Ensure the parsing context is initially empty
566+
assert get_parsing_context().dag_id is None
567+
assert get_parsing_context().task_id is None
568+
569+
# Set the environment variable for DAG bundles
570+
# We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test!
571+
dag_bundle_val = json.dumps(
572+
[
573+
{
574+
"name": "my-bundle",
575+
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
576+
"kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1},
577+
}
578+
]
579+
)
580+
581+
monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__BACKENDS", dag_bundle_val)
582+
ti, _ = startup()
583+
584+
# Presence of `conditional_task` below means DAG ID is properly set in the parsing context!
585+
# Check the dag file for the actual logic!
586+
assert ti.task.dag.task_dict.keys() == {"visible_task", "conditional_task"}
587+
assert get_parsing_context().dag_id == dag_id
588+
assert get_parsing_context().task_id == task_id
589+
590+
545591
class TestRuntimeTaskInstance:
546592
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
547593
"""Test get_template_context without ti_context_from_server."""

tests/dags/test_dag_parsing_context.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from airflow.models.dag import DAG
2222
from airflow.operators.empty import EmptyOperator
23-
from airflow.utils.dag_parsing_context import get_parsing_context
23+
from airflow.sdk.definitions.context import get_parsing_context
2424

2525
DAG_ID = "test_dag_parsing_context"
2626

tests/dags/test_parsing_context.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,16 @@
1818
from __future__ import annotations
1919

2020
from pathlib import Path
21-
from typing import TYPE_CHECKING
2221

2322
from airflow.models.dag import DAG
2423
from airflow.operators.empty import EmptyOperator
25-
from airflow.utils.dag_parsing_context import (
24+
from airflow.sdk.definitions.context import (
2625
_AIRFLOW_PARSING_CONTEXT_DAG_ID,
2726
_AIRFLOW_PARSING_CONTEXT_TASK_ID,
27+
Context,
2828
)
2929
from airflow.utils.timezone import datetime
3030

31-
if TYPE_CHECKING:
32-
from airflow.sdk.definitions.context import Context
33-
3431

3532
class DagWithParsingContext(EmptyOperator):
3633
def execute(self, context: Context):

0 commit comments

Comments
 (0)