
Commit 2f44ea5

Refactor Sqlalchemy queries to 2.0 style (Part 5) (#32474)

1 parent 4671516 commit 2f44ea5

8 files changed: +103 -51 lines changed

airflow/cli/commands/dag_command.py
+12 -10

@@ -28,7 +28,7 @@
 import warnings
 
 from graphviz.dot import Dot
-from sqlalchemy import delete
+from sqlalchemy import delete, select
 from sqlalchemy.orm import Session
 
 from airflow import settings
@@ -287,7 +287,7 @@ def dag_state(args, session: Session = NEW_SESSION) -> None:
 
     if not dag:
         raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table")
-    dr = session.query(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date).one_or_none()
+    dr = session.scalar(select(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date))
     out = dr.state if dr else None
     conf_out = ""
     if out and dr.conf:
@@ -309,7 +309,9 @@ def dag_next_execution(args) -> None:
         print("[INFO] Please be reminded this DAG is PAUSED now.", file=sys.stderr)
 
     with create_session() as session:
-        last_parsed_dag: DagModel = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
+        last_parsed_dag: DagModel = session.scalars(
+            select(DagModel).where(DagModel.dag_id == dag.dag_id)
+        ).one()
 
     def print_execution_interval(interval: DataInterval | None):
         if interval is None:
@@ -428,8 +430,10 @@ def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION)
         queries.append(Job.state == args.state)
 
     fields = ["dag_id", "state", "job_type", "start_date", "end_date"]
-    all_jobs = session.query(Job).filter(*queries).order_by(Job.start_date.desc()).limit(args.limit).all()
-    all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in all_jobs]
+    all_jobs_iter = session.scalars(
+        select(Job).where(*queries).order_by(Job.start_date.desc()).limit(args.limit)
+    )
+    all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in all_jobs_iter]
 
     AirflowConsole().print_as(
         data=all_jobs,
@@ -492,14 +496,12 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
     imgcat = args.imgcat_dagrun
     filename = args.save_dagrun
     if show_dagrun or imgcat or filename:
-        tis = (
-            session.query(TaskInstance)
-            .filter(
+        tis = session.scalars(
+            select(TaskInstance).where(
                 TaskInstance.dag_id == args.dag_id,
                 TaskInstance.execution_date == execution_date,
             )
-            .all()
-        )
+        ).all()
 
         dot_graph = render_dag(dag, tis=tis)
         print()
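The pattern above repeats throughout the commit: a legacy `session.query(Model).filter(...)` chain becomes an explicit `select()` statement handed to `session.scalar()` (first row or None) or `session.scalars()` (ORM objects row by row). A minimal runnable sketch of the same migration, using a toy `Run` model rather than Airflow's, assuming SQLAlchemy 1.4+:

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Run(Base):  # toy stand-in for DagRun
    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    state = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Run(state="success"), Run(state="failed")])
    session.commit()

    # 1.x: session.query(Run).filter_by(state="success").one_or_none()
    # 2.0: build the statement, let the session execute it.
    run = session.scalar(select(Run).filter_by(state="success"))
    print(run.state if run else None)  # -> success

    # scalars() yields ORM objects; .all() materializes them into a list.
    states = [r.state for r in session.scalars(select(Run).order_by(Run.id))]
    print(states)  # -> ['success', 'failed']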

airflow/cli/commands/jobs_command.py
+6 -5

@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from airflow.jobs.job import Job
@@ -32,17 +33,17 @@ def check(args, session: Session = NEW_SESSION) -> None:
     if args.hostname and args.local:
         raise SystemExit("You can't use --hostname and --local at the same time")
 
-    query = session.query(Job).filter(Job.state == State.RUNNING).order_by(Job.latest_heartbeat.desc())
+    query = select(Job).where(Job.state == State.RUNNING).order_by(Job.latest_heartbeat.desc())
     if args.job_type:
-        query = query.filter(Job.job_type == args.job_type)
+        query = query.where(Job.job_type == args.job_type)
     if args.hostname:
-        query = query.filter(Job.hostname == args.hostname)
+        query = query.where(Job.hostname == args.hostname)
     if args.local:
-        query = query.filter(Job.hostname == get_hostname())
+        query = query.where(Job.hostname == get_hostname())
     if args.limit > 0:
         query = query.limit(args.limit)
 
-    alive_jobs: list[Job] = [job for job in query.all() if job.is_alive()]
+    alive_jobs: list[Job] = [job for job in session.scalars(query) if job.is_alive()]
 
     count_alive_jobs = len(alive_jobs)
     if count_alive_jobs == 0:
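One practical gain in `check()` is that the statement is now plain data until executed: each `.where()` returns a new `Select`, so filters stack conditionally and the session is only involved at the very end. A small sketch with lightweight Core constructs (no database or Airflow models needed; the `job` table here is illustrative):

from sqlalchemy import column, select, table

# A lightweight stand-in for the Job model; select() statements are inert
# until a session executes them, so one can build and inspect them offline.
job = table("job", column("id"), column("state"), column("hostname"))

stmt = select(job).where(job.c.state == "running")
hostname = "worker-1"  # pretend this came from args
if hostname:
    stmt = stmt.where(job.c.hostname == hostname)  # returns a new Select
stmt = stmt.limit(10)

print(stmt)  # renders the SQL text, with bound parameters as placeholders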

airflow/cli/commands/rotate_fernet_key_command.py
+5 -2

@@ -17,6 +17,8 @@
 """Rotate Fernet key command."""
 from __future__ import annotations
 
+from sqlalchemy import select
+
 from airflow.models import Connection, Variable
 from airflow.utils import cli as cli_utils
 from airflow.utils.session import create_session
@@ -26,7 +28,8 @@
 def rotate_fernet_key(args):
     """Rotates all encrypted connection credentials and variables."""
     with create_session() as session:
-        for conn in session.query(Connection).filter(Connection.is_encrypted | Connection.is_extra_encrypted):
+        conns_query = select(Connection).where(Connection.is_encrypted | Connection.is_extra_encrypted)
+        for conn in session.scalars(conns_query):
             conn.rotate_fernet_key()
-        for var in session.query(Variable).filter(Variable.is_encrypted):
+        for var in session.scalars(select(Variable).where(Variable.is_encrypted)):
            var.rotate_fernet_key()

airflow/cli/commands/task_command.py
+8 -13

@@ -29,6 +29,7 @@
 
 import pendulum
 from pendulum.parsing.exceptions import ParserError
+from sqlalchemy import select
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.orm.session import Session
 
@@ -111,11 +112,9 @@ def _get_dag_run(
     with suppress(ParserError, TypeError):
         execution_date = timezone.parse(exec_date_or_run_id)
     try:
-        dag_run = (
-            session.query(DagRun)
-            .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
-            .one()
-        )
+        dag_run = session.scalars(
+            select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
+        ).one()
     except NoResultFound:
         if not create_if_necessary:
             raise DagRunNotFound(
@@ -534,18 +533,14 @@ def _guess_debugger() -> _SupportedDebugger:
 @provide_session
 def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None:
     """Get the status of all task instances in a DagRun."""
-    dag_run = (
-        session.query(DagRun)
-        .filter(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id)
-        .one_or_none()
+    dag_run = session.scalar(
+        select(DagRun).where(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id)
     )
     if not dag_run:
         try:
             execution_date = timezone.parse(args.execution_date_or_run_id)
-            dag_run = (
-                session.query(DagRun)
-                .filter(DagRun.execution_date == execution_date, DagRun.dag_id == args.dag_id)
-                .one_or_none()
+            dag_run = session.scalar(
+                select(DagRun).where(DagRun.execution_date == execution_date, DagRun.dag_id == args.dag_id)
             )
         except (ParserError, TypeError) as err:
             raise AirflowException(f"Error parsing the supplied execution_date. Error: {str(err)}")
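Worth flagging when reviewing these `one_or_none()` to `session.scalar()` swaps: `scalar()` returns the first result or `None` and does not raise if the statement matches several rows, whereas `one_or_none()` raised `MultipleResultsFound`. For run_id and execution_date lookups the match is unique in practice, but the contract is slightly looser. A runnable illustration of the difference, with a toy model:

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Thing(Base):
    __tablename__ = "thing"
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Thing(name="a"), Thing(name="a")])  # duplicates on purpose
    session.commit()

    stmt = select(Thing).where(Thing.name == "a")
    print(session.scalar(stmt).id)  # -> 1, quietly takes the first row

    try:
        session.scalars(stmt).one_or_none()
    except MultipleResultsFound as exc:
        print(type(exc).__name__)  # -> MultipleResultsFound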

airflow/cli/commands/variable_command.py
+4 -2

@@ -22,6 +22,8 @@
 import os
 from json import JSONDecodeError
 
+from sqlalchemy import select
+
 from airflow.cli.simple_table import AirflowConsole
 from airflow.models import Variable
 from airflow.utils import cli as cli_utils
@@ -33,7 +35,7 @@
 def variables_list(args):
     """Displays all the variables."""
     with create_session() as session:
-        variables = session.query(Variable)
+        variables = session.scalars(select(Variable)).all()
     AirflowConsole().print_as(data=variables, output=args.output, mapper=lambda x: {"key": x.key})
 
 
@@ -107,7 +109,7 @@ def _variable_export_helper(filepath):
     """Helps export all the variables to the file."""
     var_dict = {}
     with create_session() as session:
-        qry = session.query(Variable).all()
+        qry = session.scalars(select(Variable))
 
         data = json.JSONDecoder()
         for var in qry:

airflow/models/renderedtifields.py
+6 -9

@@ -134,15 +134,13 @@ def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION)
         :param session: SqlAlchemy Session
         :return: Rendered Templated TI field
         """
-        result = (
-            session.query(cls.rendered_fields)
-            .filter(
+        result = session.scalar(
+            select(cls).where(
                 cls.dag_id == ti.dag_id,
                 cls.task_id == ti.task_id,
                 cls.run_id == ti.run_id,
                 cls.map_index == ti.map_index,
             )
-            .one_or_none()
         )
 
         if result:
@@ -162,15 +160,13 @@ def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> d
         :param session: SqlAlchemy Session
         :return: Kubernetes Pod Yaml
         """
-        result = (
-            session.query(cls.k8s_pod_yaml)
-            .filter(
+        result = session.scalar(
+            select(cls).where(
                 cls.dag_id == ti.dag_id,
                 cls.task_id == ti.task_id,
                 cls.run_id == ti.run_id,
                 cls.map_index == ti.map_index,
             )
-            .one_or_none()
         )
         return result.k8s_pod_yaml if result else None
 
@@ -243,7 +239,8 @@ def _do_delete_old_records(
                 cls.task_id == task_id,
                 tuple_not_in_condition(
                     (cls.dag_id, cls.task_id, cls.run_id),
-                    session.query(ti_clause.c.dag_id, ti_clause.c.task_id, ti_clause.c.run_id),
+                    select(ti_clause.c.dag_id, ti_clause.c.task_id, ti_clause.c.run_id),
+                    session=session,
                 ),
             )
             .execution_options(synchronize_session=False)
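A subtle behavior shift in this file: the old queries selected a single column (`cls.rendered_fields` or `cls.k8s_pod_yaml`) and got back a `Row`, while the new ones select the whole entity and read the attribute off the mapped object, which typically loads every mapped column of the matching row. Both shapes are valid 2.0 style; a toy sketch of the two:

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class RTIF(Base):  # toy stand-in for RenderedTaskInstanceFields
    __tablename__ = "rtif"
    id = Column(Integer, primary_key=True)
    k8s_pod_yaml = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(RTIF(k8s_pod_yaml="kind: Pod"))
    session.commit()

    # Entity select, as the commit does: the scalar is the mapped object.
    obj = session.scalar(select(RTIF).where(RTIF.id == 1))
    print(obj.k8s_pod_yaml if obj else None)  # -> kind: Pod

    # A column select is equally valid 2.0 style and fetches only that column.
    print(session.scalar(select(RTIF.k8s_pod_yaml).where(RTIF.id == 1)))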

airflow/utils/sqlalchemy.py
+58 -6

@@ -22,14 +22,14 @@
 import datetime
 import json
 import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable
+from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
 
 import pendulum
 from dateutil import relativedelta
 from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
 from sqlalchemy.dialects import mssql, mysql
 from sqlalchemy.exc import OperationalError
-from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnElement, Select
 from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
 
@@ -515,11 +515,31 @@ def is_lock_not_available_error(error: OperationalError):
     return False
 
 
+@overload
 def tuple_in_condition(
     columns: tuple[ColumnElement, ...],
     collection: Iterable[Any],
 ) -> ColumnOperators:
-    """Generates a tuple-in-collection operator to use in ``.filter()``.
+    ...
+
+
+@overload
+def tuple_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Select,
+    *,
+    session: Session,
+) -> ColumnOperators:
+    ...
+
+
+def tuple_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Iterable[Any] | Select,
+    *,
+    session: Session | None = None,
+) -> ColumnOperators:
+    """Generates a tuple-in-collection operator to use in ``.where()``.
 
     For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
     clause. This however does not work with MSSQL, where we need to expand to
@@ -529,25 +549,57 @@
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).in_(collection)
-    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
+    if not isinstance(collection, Select):
+        rows = collection
+    elif session is None:
+        raise TypeError("session is required when passing in a subquery")
+    else:
+        rows = session.execute(collection)
+    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in rows]
     if not clauses:
         return false()
     return or_(*clauses)
 
 
+@overload
 def tuple_not_in_condition(
     columns: tuple[ColumnElement, ...],
     collection: Iterable[Any],
 ) -> ColumnOperators:
-    """Generates a tuple-not-in-collection operator to use in ``.filter()``.
+    ...
+
+
+@overload
+def tuple_not_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Select,
+    *,
+    session: Session,
+) -> ColumnOperators:
+    ...
+
+
+def tuple_not_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Iterable[Any] | Select,
+    *,
+    session: Session | None = None,
+) -> ColumnOperators:
+    """Generates a tuple-not-in-collection operator to use in ``.where()``.
 
     This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
 
     :meta private:
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).not_in(collection)
-    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in collection]
+    if not isinstance(collection, Select):
+        rows = collection
+    elif session is None:
+        raise TypeError("session is required when passing in a subquery")
+    else:
+        rows = session.execute(collection)
+    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in rows]
     if not clauses:
         return true()
     return and_(*clauses)
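The `@overload` pair exists so type checkers enforce what the runtime also checks: `session` is required exactly when `collection` is a `Select` (the MSSQL fallback must execute the subquery eagerly) and is not accepted otherwise. A stripped-down, runnable sketch of the same pattern, with a hypothetical `row_count` helper in place of the real functions:

from __future__ import annotations

from typing import Any, Iterable, overload

from sqlalchemy.orm import Session
from sqlalchemy.sql import Select


@overload
def row_count(collection: Iterable[Any]) -> int:
    ...


@overload
def row_count(collection: Select, *, session: Session) -> int:
    ...


def row_count(collection: Iterable[Any] | Select, *, session: Session | None = None) -> int:
    # The single runtime implementation enforces what the overloads promise.
    if isinstance(collection, Select):
        if session is None:
            raise TypeError("session is required when passing in a subquery")
        return len(session.execute(collection).all())
    return len(list(collection))


print(row_count([1, 2, 3]))  # -> 3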

tests/cli/commands/test_task_command.py
+4 -4

@@ -500,15 +500,15 @@ def test_mapped_task_render_with_template(self, dag_maker):
         assert 'echo "2022-01-01"' in output
         assert 'echo "2022-01-08"' in output
 
-    @mock.patch("sqlalchemy.orm.session.Session.query")
+    @mock.patch("airflow.cli.commands.task_command.select")
+    @mock.patch("airflow.cli.commands.task_command.Session.scalars")
     @mock.patch("airflow.cli.commands.task_command.DagRun")
-    def test_task_render_with_custom_timetable(self, mock_dagrun, mock_query):
+    def test_task_render_with_custom_timetable(self, mock_dagrun, mock_scalars, mock_select):
         """
         when calling `tasks render` on dag with custom timetable, the DagRun object should be created with
         data_intervals.
         """
-        mock_query.side_effect = sqlalchemy.exc.NoResultFound
-
+        mock_scalars.side_effect = sqlalchemy.exc.NoResultFound
         task_command.task_render(
             self.parser.parse_args(["tasks", "render", "example_workday_timetable", "run_this", "2022-01-01"])
         )
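The patch targets move because `mock.patch` must replace a name where it is looked up, not where it is defined: once `task_command` imports `select` at module level and calls `Session.scalars`, those are the names the test has to intercept. The shape of the updated test, reduced to its mocking skeleton (the function name is illustrative and the body is elided):

from unittest import mock

import sqlalchemy.exc


@mock.patch("airflow.cli.commands.task_command.select")
@mock.patch("airflow.cli.commands.task_command.Session.scalars")
def test_render_skeleton(mock_scalars, mock_select):
    # Decorators apply bottom-up, so the innermost patch is the first argument.
    mock_scalars.side_effect = sqlalchemy.exc.NoResultFound
    # ... invoke task_command.task_render(...) as in the real test ...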
