Skip to content

Commit e155d30

Browse files
kaxilHariGS-DB
authored andcommitted
AIP-72: Support DAG parsing context in Task SDK (apache#45694)
1 parent 8d6f2c4 commit e155d30

File tree

15 files changed

+214
-53
lines changed

15 files changed

+214
-53
lines changed

Diff for: airflow/cli/commands/remote_commands/dag_command.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
from airflow.jobs.job import Job
3838
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
3939
from airflow.models.serialized_dag import SerializedDagModel
40+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
4041
from airflow.utils import cli as cli_utils, timezone
4142
from airflow.utils.cli import get_dag, process_subdir, suppress_logs_and_warning
42-
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
4343
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
4444
from airflow.utils.helpers import ask_yesno
4545
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

Diff for: airflow/task/standard_task_runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@
3333
from airflow.configuration import conf
3434
from airflow.exceptions import AirflowConfigException
3535
from airflow.models.taskinstance import TaskReturnCode
36+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
3637
from airflow.settings import CAN_FORK
3738
from airflow.stats import Stats
3839
from airflow.utils.configuration import tmp_configuration_copy
39-
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
4040
from airflow.utils.log.logging_mixin import LoggingMixin
4141
from airflow.utils.net import get_hostname
4242
from airflow.utils.platform import IS_WINDOWS, getuser

Diff for: airflow/utils/dag_parsing_context.py

+11-38
Original file line numberDiff line numberDiff line change
@@ -14,47 +14,20 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
from __future__ import annotations
18-
19-
import os
20-
from contextlib import contextmanager
21-
from typing import NamedTuple
22-
23-
24-
class AirflowParsingContext(NamedTuple):
25-
"""
26-
Context of parsing for the DAG.
27-
28-
If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to
29-
execute. You can use these for optimizing dynamically generated DAG files.
30-
"""
31-
32-
dag_id: str | None
33-
task_id: str | None
3417

18+
from __future__ import annotations
3519

36-
_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
37-
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
20+
import warnings
3821

22+
from airflow.sdk.definitions.context import get_parsing_context
3923

40-
@contextmanager
41-
def _airflow_parsing_context_manager(dag_id: str | None = None, task_id: str | None = None):
42-
old_dag_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID)
43-
old_task_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID)
44-
if dag_id is not None:
45-
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = dag_id
46-
if task_id is not None:
47-
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = task_id
48-
yield
49-
if old_task_id is not None:
50-
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = old_task_id
51-
if old_dag_id is not None:
52-
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = old_dag_id
24+
# TODO: Remove this module in Airflow 3.2
5325

26+
warnings.warn(
27+
"Import from the airflow.utils.dag_parsing_context module is deprecated and "
28+
"will be removed in Airflow 3.2. Please import it from 'airflow.sdk'.",
29+
DeprecationWarning,
30+
stacklevel=2,
31+
)
5432

55-
def get_parsing_context() -> AirflowParsingContext:
56-
"""Return the current (DAG) parsing context info."""
57-
return AirflowParsingContext(
58-
dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID),
59-
task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID),
60-
)
33+
__all__ = ["get_parsing_context"]

Diff for: docs/apache-airflow/howto/dynamic-dag-generation.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ of the context are set to ``None``.
207207
:emphasize-lines: 4,8,9
208208
209209
from airflow.models.dag import DAG
210-
from airflow.utils.dag_parsing_context import get_parsing_context
210+
from airflow.sdk import get_parsing_context
211211
212212
current_dag_id = get_parsing_context().dag_id
213213

Diff for: newsfragments/45694.significant.rst

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
``get_parsing_context`` have been moved to Task SDK
2+
3+
As part of AIP-72: Task SDK, the function ``get_parsing_context`` has been moved to ``airflow.sdk`` module.
4+
Previously, it was located in ``airflow.utils.dag_parsing_context`` module.
5+
6+
This function is used to optimize DAG parsing during execution when DAGs are generated dynamically.
7+
8+
Before:
9+
10+
.. code-block:: python
11+
12+
from airflow.models.dag import DAG
13+
from airflow.utils.dag_parsing_context import get_parsing_context
14+
15+
current_dag_id = get_parsing_context().dag_id
16+
17+
for thing in list_of_things:
18+
dag_id = f"generated_dag_{thing}"
19+
if current_dag_id is not None and current_dag_id != dag_id:
20+
continue # skip generation of non-selected DAG
21+
22+
with DAG(dag_id=dag_id, ...):
23+
...
24+
25+
After:
26+
27+
.. code-block:: python
28+
29+
from airflow.sdk import get_parsing_context
30+
31+
current_dag_id = get_parsing_context().dag_id
32+
33+
# The rest of the code remains the same
34+
35+
* Types of change
36+
37+
* [x] DAG changes
38+
* [ ] Config changes
39+
* [ ] API changes
40+
* [ ] CLI changes
41+
* [ ] Behaviour changes
42+
* [ ] Plugin changes
43+
* [ ] Dependency change
44+
45+
* Migration rules needed
46+
47+
* ruff
48+
49+
* AIR302
50+
51+
* [ ] ``airflow.utils.dag_parsing_context.get_parsing_context`` -> ``airflow.sdk.get_parsing_context``

Diff for: providers/src/airflow/providers/celery/executors/celery_executor_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
from airflow.configuration import conf
4545
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
4646
from airflow.executors.base_executor import BaseExecutor
47+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
4748
from airflow.stats import Stats
48-
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
4949
from airflow.utils.log.logging_mixin import LoggingMixin
5050
from airflow.utils.net import get_hostname
5151
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

Diff for: scripts/cov/core_coverage.py

-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
"airflow/utils/code_utils.py",
105105
"airflow/utils/context.py",
106106
"airflow/utils/dag_cycle_tester.py",
107-
"airflow/utils/dag_parsing_context.py",
108107
"airflow/utils/dates.py",
109108
"airflow/utils/db.py",
110109
"airflow/utils/db_cleanup.py",

Diff for: task_sdk/src/airflow/sdk/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"dag",
2828
"Connection",
2929
"get_current_context",
30+
"get_parsing_context",
3031
"__version__",
3132
]
3233

@@ -35,7 +36,7 @@
3536
if TYPE_CHECKING:
3637
from airflow.sdk.definitions.baseoperator import BaseOperator
3738
from airflow.sdk.definitions.connection import Connection
38-
from airflow.sdk.definitions.context import get_current_context
39+
from airflow.sdk.definitions.context import get_current_context, get_parsing_context
3940
from airflow.sdk.definitions.dag import DAG, dag
4041
from airflow.sdk.definitions.edges import EdgeModifier, Label
4142
from airflow.sdk.definitions.taskgroup import TaskGroup
@@ -50,6 +51,7 @@
5051
"Connection": ".definitions.connection",
5152
"Variable": ".definitions.variable",
5253
"get_current_context": ".definitions.context",
54+
"get_parsing_context": ".definitions.context",
5355
}
5456

5557

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
import os
20+
from contextlib import contextmanager
21+
22+
from airflow.sdk.definitions.context import _AIRFLOW_PARSING_CONTEXT_DAG_ID, _AIRFLOW_PARSING_CONTEXT_TASK_ID
23+
24+
25+
@contextmanager
26+
def _airflow_parsing_context_manager(dag_id: str | None = None, task_id: str | None = None):
27+
old_dag_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID)
28+
old_task_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID)
29+
if dag_id is not None:
30+
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = dag_id
31+
if task_id is not None:
32+
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = task_id
33+
yield
34+
if old_task_id is not None:
35+
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = old_task_id
36+
if old_dag_id is not None:
37+
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = old_dag_id

Diff for: task_sdk/src/airflow/sdk/definitions/context.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20-
from typing import TYPE_CHECKING, Any, TypedDict
20+
import os
21+
from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict
2122

2223
if TYPE_CHECKING:
2324
# TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x?
@@ -105,3 +106,27 @@ def my_task():
105106
from airflow.sdk.definitions._internal.contextmanager import _get_current_context
106107

107108
return _get_current_context()
109+
110+
111+
class AirflowParsingContext(NamedTuple):
112+
"""
113+
Context of parsing for the DAG.
114+
115+
If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to
116+
execute. You can use these for optimizing dynamically generated DAG files.
117+
"""
118+
119+
dag_id: str | None
120+
task_id: str | None
121+
122+
123+
_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
124+
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
125+
126+
127+
def get_parsing_context() -> AirflowParsingContext:
128+
"""Return the current (DAG) parsing context info."""
129+
return AirflowParsingContext(
130+
dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID),
131+
task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID),
132+
)

Diff for: task_sdk/src/airflow/sdk/execution_time/task_runner.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from airflow.dag_processing.bundles.manager import DagBundlesManager
3535
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext
36+
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
3637
from airflow.sdk.definitions.baseoperator import BaseOperator
3738
from airflow.sdk.execution_time.comms import (
3839
DeferTask,
@@ -406,8 +407,8 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
406407
setproctitle(f"airflow worker -- {msg.ti.id}")
407408

408409
log = structlog.get_logger(logger_name="task")
409-
# TODO: set the "magic loop" context vars for parsing
410-
ti = parse(msg)
410+
with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
411+
ti = parse(msg)
411412
log.debug("DAG file parsed", file=msg.dag_rel_path)
412413
else:
413414
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

Diff for: task_sdk/tests/dags/dag_parsing_context.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from datetime import datetime
20+
21+
from airflow.sdk import DAG, BaseOperator, get_parsing_context
22+
23+
DAG_ID = "dag_parsing_context_test"
24+
25+
current_dag_id = get_parsing_context().dag_id
26+
27+
with DAG(
28+
DAG_ID,
29+
start_date=datetime(2024, 2, 21),
30+
schedule=None,
31+
) as the_dag:
32+
BaseOperator(task_id="visible_task")
33+
34+
if current_dag_id == DAG_ID:
35+
# this task will be invisible if the DAG ID is not properly set in the parsing context.
36+
BaseOperator(task_id="conditional_task")

Diff for: task_sdk/tests/execution_time/test_task_runner.py

+40
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,46 @@ def execute(self, context):
542542
)
543543

544544

545+
def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch, test_dags_dir):
546+
"""
547+
Test that the DAG parsing context is correctly set during the startup process.
548+
549+
This test verifies that the DAG and task IDs are correctly set in the parsing context
550+
when a DAG is started up.
551+
"""
552+
dag_id = "dag_parsing_context_test"
553+
task_id = "conditional_task"
554+
555+
what = StartupDetails(
556+
ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
557+
dag_rel_path="dag_parsing_context.py",
558+
bundle_info=BundleInfo(name="my-bundle", version=None),
559+
requests_fd=0,
560+
ti_context=make_ti_context(dag_id=dag_id, run_id="c"),
561+
)
562+
563+
mock_supervisor_comms.get_message.return_value = what
564+
565+
# Set the environment variable for DAG bundles
566+
# We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test!
567+
dag_bundle_val = json.dumps(
568+
[
569+
{
570+
"name": "my-bundle",
571+
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
572+
"kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1},
573+
}
574+
]
575+
)
576+
577+
monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__BACKENDS", dag_bundle_val)
578+
ti, _ = startup()
579+
580+
# Presence of `conditional_task` below means DAG ID is properly set in the parsing context!
581+
# Check the dag file for the actual logic!
582+
assert ti.task.dag.task_dict.keys() == {"visible_task", "conditional_task"}
583+
584+
545585
class TestRuntimeTaskInstance:
546586
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
547587
"""Test get_template_context without ti_context_from_server."""

Diff for: tests/dags/test_dag_parsing_context.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from airflow.models.dag import DAG
2222
from airflow.operators.empty import EmptyOperator
23-
from airflow.utils.dag_parsing_context import get_parsing_context
23+
from airflow.sdk.definitions.context import get_parsing_context
2424

2525
DAG_ID = "test_dag_parsing_context"
2626

Diff for: tests/dags/test_parsing_context.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,16 @@
1818
from __future__ import annotations
1919

2020
from pathlib import Path
21-
from typing import TYPE_CHECKING
2221

2322
from airflow.models.dag import DAG
2423
from airflow.operators.empty import EmptyOperator
25-
from airflow.utils.dag_parsing_context import (
24+
from airflow.sdk.definitions.context import (
2625
_AIRFLOW_PARSING_CONTEXT_DAG_ID,
2726
_AIRFLOW_PARSING_CONTEXT_TASK_ID,
27+
Context,
2828
)
2929
from airflow.utils.timezone import datetime
3030

31-
if TYPE_CHECKING:
32-
from airflow.sdk.definitions.context import Context
33-
3431

3532
class DagWithParsingContext(EmptyOperator):
3633
def execute(self, context: Context):

0 commit comments

Comments
 (0)