7 changes: 7 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
@@ -2074,6 +2074,13 @@ scheduler:
       type: integer
       example: ~
       default: "60"
+    asset_active_batch_size:
+      description: |
+        Batch size used when activating or orphaning assets to avoid oversized SQL ``IN`` clauses.
+      version_added: 3.3.0
+      type: integer
+      example: ~
+      default: "500"
     pool_metrics_interval:
       description: |
         How often (in seconds) should pool usage stats be sent to StatsD (if statsd_on is enabled)
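Note on usage: the option lives under the ``[scheduler]`` section, so it can be set in ``airflow.cfg`` or through the standard ``AIRFLOW__SCHEDULER__ASSET_ACTIVE_BATCH_SIZE`` environment variable (Airflow's usual ``AIRFLOW__{SECTION}__{KEY}`` convention). A minimal sketch of how the value resolves, mirroring the ``conf.getint`` call added in the scheduler below; the fallback of 500 matches the declared default:

from airflow.configuration import conf

# Resolves from [scheduler] in airflow.cfg or from the
# AIRFLOW__SCHEDULER__ASSET_ACTIVE_BATCH_SIZE environment variable,
# falling back to 500 when the option is absent.
batch_size = conf.getint("scheduler", "asset_active_batch_size", fallback=500)
print(batch_size)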
55 changes: 41 additions & 14 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -30,7 +30,7 @@
 from datetime import date, datetime, timedelta
 from functools import lru_cache, partial
 from itertools import groupby
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeVar
 
 from sqlalchemy import and_, delete, desc, exists, func, inspect, or_, select, text, tuple_, update
 from sqlalchemy.exc import OperationalError
@@ -120,6 +120,11 @@
 TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule"
 """:meta private:"""
 
+ASSET_ACTIVE_BATCH_SIZE = conf.getint("scheduler", "asset_active_batch_size", fallback=500)
+""":meta private:"""
+
+T = TypeVar("T")
+
 
 def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]:
     """
@@ -2802,27 +2807,49 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
         self._activate_referenced_assets(asset_orphanation.get(False, ()), session=session)
 
     @staticmethod
-    def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
-        if assets:
+    def _batched(iterable: Iterable[T], size: int) -> Iterator[list[T]]:
+        iterator = iter(iterable)
+        while batch := list(itertools.islice(iterator, size)):
+            yield batch
+
+    @staticmethod
+    def _delete_assets_in_batches(assets: list[AssetModel], *, session: Session) -> None:
+        for batch in SchedulerJobRunner._batched(assets, ASSET_ACTIVE_BATCH_SIZE):
             session.execute(
                 delete(AssetActive).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
+                    tuple_(AssetActive.name, AssetActive.uri).in_([(a.name, a.uri) for a in batch])
                 )
             )
+
+    @staticmethod
+    def _select_active_assets_in_batches(
+        assets: list[AssetModel], *, session: Session
+    ) -> set[tuple[str, str]]:
+        active_assets: set[tuple[str, str]] = set()
+        for batch in SchedulerJobRunner._batched(assets, ASSET_ACTIVE_BATCH_SIZE):
+            active_assets.update(
+                session.execute(
+                    select(AssetActive.name, AssetActive.uri).where(
+                        tuple_(AssetActive.name, AssetActive.uri).in_([(a.name, a.uri) for a in batch])
+                    )
+                )
+            )
-            Stats.gauge("asset.orphaned", len(assets))
+        return active_assets
 
     @staticmethod
+    def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
+        assets_list = list(assets)
+        if assets_list:
+            SchedulerJobRunner._delete_assets_in_batches(assets_list, session=session)
+            Stats.gauge("asset.orphaned", len(assets_list))
+
+    @staticmethod
     def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
-        if not assets:
+        assets_list = list(assets)
+        if not assets_list:
             return
 
-        active_assets = set(
-            session.execute(
-                select(AssetActive.name, AssetActive.uri).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
-                )
-            )
-        )
+        active_assets = SchedulerJobRunner._select_active_assets_in_batches(assets_list, session=session)
 
         active_name_to_uri: dict[str, str] = {name: uri for name, uri in active_assets}
         active_uri_to_name: dict[str, str] = {uri: name for name, uri in active_assets}
@@ -2848,7 +2875,7 @@ def _generate_warning_message(
         def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
             incoming_name_to_uri: dict[str, str] = {}
             incoming_uri_to_name: dict[str, str] = {}
-            for asset in assets:
+            for asset in assets_list:
                 if (asset.name, asset.uri) in active_assets:
                     continue
                 existing_uri = active_name_to_uri.get(asset.name) or incoming_name_to_uri.get(asset.name)
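For readers skimming the diff, the whole change reduces to one pattern: slice the key pairs into fixed-size chunks and issue one tuple-``IN`` statement per chunk, so no single statement carries an oversized ``IN`` clause. A self-contained sketch of that pattern follows; the table and names are hypothetical stand-ins, not the scheduler's real models, and the helper is akin to ``itertools.batched`` (Python 3.12+) except that it yields lists:

import itertools
from collections.abc import Iterable, Iterator
from typing import TypeVar

from sqlalchemy import Column, MetaData, String, Table, delete, tuple_

T = TypeVar("T")


def batched(iterable: Iterable[T], size: int) -> Iterator[list[T]]:
    # Yield successive chunks of at most `size` items.
    iterator = iter(iterable)
    while batch := list(itertools.islice(iterator, size)):
        yield batch


metadata = MetaData()
# Hypothetical stand-in for the asset_active table, keyed by (name, uri).
asset_active = Table(
    "asset_active",
    metadata,
    Column("name", String, primary_key=True),
    Column("uri", String, primary_key=True),
)

pairs = [(f"asset_{i}", f"s3://bucket/{i}") for i in range(5)]
for batch in batched(pairs, 2):
    # One DELETE per chunk keeps each IN clause small: 2 + 2 + 1 pairs here.
    stmt = delete(asset_active).where(
        tuple_(asset_active.c.name, asset_active.c.uri).in_(batch)
    )
    print(stmt)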
87 changes: 87 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -19,6 +19,7 @@

 import contextlib
 import datetime
 import logging
+import math
 import os
 from collections import Counter, deque
@@ -36,6 +37,8 @@
 from pytest import param
 from sqlalchemy import func, select, update
 from sqlalchemy.orm import joinedload
+from sqlalchemy.sql.dml import Delete
+from sqlalchemy.sql.selectable import Select
 
 from airflow import settings
 from airflow._shared.timezones import timezone
@@ -50,6 +53,7 @@
 from airflow.executors.executor_constants import MOCK_EXECUTOR
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.executors.executor_utils import ExecutorName
+from airflow.jobs import scheduler_job_runner
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
 from airflow.models.asset import (
@@ -7088,6 +7092,89 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(
         for i in range(100):
             assert f"it's duplicate {i}" in dag_warning.message
 
+    def test_orphan_unreferenced_assets_batches_deletes(self, session, monkeypatch):
+        asset_count = 5
+        assets = [
+            Asset(
+                name=f"asset_batch_orphan_{i}",
+                uri=f"s3://bucket/key/orphan/{i}",
+                extra={"foo": "bar"},
+            )
+            for i in range(asset_count)
+        ]
+        dag = DAG(dag_id="test_asset_batch_orphan", start_date=DEFAULT_DATE, schedule=assets)
+        sync_dag_to_db(dag, session=session)
+
+        asset_models = session.scalars(
+            select(AssetModel).where(AssetModel.name.like("asset_batch_orphan_%"))
+        ).all()
+
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
+        statements: list[object] = []
+        original_execute = session.execute
+
+        def tracking_execute(statement, *args, **kwargs):
+            statements.append(statement)
+            return original_execute(statement, *args, **kwargs)
+
+        batch_size = 2
+        monkeypatch.setattr(scheduler_job_runner, "ASSET_ACTIVE_BATCH_SIZE", batch_size)
+        monkeypatch.setattr(session, "execute", tracking_execute)
+
+        SchedulerJobRunner._orphan_unreferenced_assets(asset_models, session=session)
+        session.flush()
+
+        monkeypatch.setattr(session, "execute", original_execute)
+        captured_statements = list(statements)
+
+        delete_statements = [stmt for stmt in captured_statements if isinstance(stmt, Delete)]
+        assert len(delete_statements) == math.ceil(len(asset_models) / batch_size)
+        assert session.scalars(select(AssetActive)).all() == []
+
+    def test_activate_referenced_assets_batches_active_lookup(self, session, monkeypatch):
+        asset_count = 5
+        assets = [
+            Asset(
+                name=f"asset_batch_activate_{i}",
+                uri=f"s3://bucket/key/activate/{i}",
+                extra={"foo": "bar"},
+            )
+            for i in range(asset_count)
+        ]
+        dag = DAG(dag_id="test_asset_batch_activate", start_date=DEFAULT_DATE, schedule=assets)
+        sync_dag_to_db(dag, session=session)
+
+        asset_models = session.scalars(
+            select(AssetModel).where(AssetModel.name.like("asset_batch_activate_%"))
+        ).all()
+
+        statements: list[object] = []
+        original_execute = session.execute
+
+        def tracking_execute(statement, *args, **kwargs):
+            statements.append(statement)
+            return original_execute(statement, *args, **kwargs)
+
+        batch_size = 2
+        monkeypatch.setattr(scheduler_job_runner, "ASSET_ACTIVE_BATCH_SIZE", batch_size)
+        monkeypatch.setattr(session, "execute", tracking_execute)
+
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
+        monkeypatch.setattr(session, "execute", original_execute)
+        captured_statements = list(statements)
+
+        asset_active_selects = [
+            stmt
+            for stmt in captured_statements
+            if isinstance(stmt, Select) and AssetActive.__table__ in stmt.get_final_froms()
+        ]
+        assert len(asset_active_selects) == math.ceil(len(asset_models) / batch_size)
+        assert len(session.scalars(select(AssetActive)).all()) == len(asset_models)
+
     def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, dag_maker, session):
         """Test that scheduler passes context_from_server when handling heartbeat timeouts."""
         with dag_maker(dag_id="test_dag", session=session):
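Both new tests lean on the same instrumentation, which is worth calling out: ``ASSET_ACTIVE_BATCH_SIZE`` is bound at import time via ``conf.getint``, so the tests patch the module attribute rather than the ``[scheduler]`` config section, and they wrap ``session.execute`` to count the statements each helper emits. A stripped-down sketch of that wrapper, distilled from the tests above (the helper name and signature are hypothetical):

import math

from sqlalchemy.sql.dml import Delete


def count_batched_deletes(session, run_under_test, expected_rows, batch_size):
    # Record every statement the session executes while `run_under_test` runs.
    statements = []
    original_execute = session.execute

    def tracking_execute(statement, *args, **kwargs):
        statements.append(statement)
        return original_execute(statement, *args, **kwargs)

    session.execute = tracking_execute
    try:
        run_under_test()
    finally:
        session.execute = original_execute

    deletes = [stmt for stmt in statements if isinstance(stmt, Delete)]
    # With batching applied, ceil(rows / batch_size) DELETEs are expected.
    assert len(deletes) == math.ceil(expected_rows / batch_size)
    return deletes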