Skip to content

Commit 62de813

Browse files
committed
Check arguments and function decorated with @task
1 parent 72a9dd3 commit 62de813

File tree

3 files changed

+478
-288
lines changed

3 files changed

+478
-288
lines changed

crates/ruff_linter/resources/test/fixtures/airflow/AIR302_context.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,43 @@
1+
import pendulum
12
from airflow.models import DAG
23
from airflow.operators.dummy import DummyOperator
34
from datetime import datetime
45
from airflow.plugins_manager import AirflowPlugin
56
from airflow.decorators import task, get_current_context
67
from airflow.models.baseoperator import BaseOperator
8+
from airflow.decorators import dag, task
9+
from airflow.providers.standard.operators.python import PythonOperator
10+
11+
12+
def access_invalid_key_in_context(**context):
13+
print("access invalid key", context["conf"])
14+
15+
16+
@task
17+
def access_invalid_key_task_out_of_dag(**context):
18+
print("access invalid key", context.get("conf"))
19+
20+
21+
@dag(
22+
schedule=None,
23+
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
24+
catchup=False,
25+
tags=[""],
26+
)
27+
def invalid_dag():
28+
@task()
29+
def access_invalid_key_task(**context):
30+
print("access invalid key", context.get("conf"))
31+
32+
task1 = PythonOperator(
33+
task_id="task1",
34+
python_callable=access_invalid_key_in_context,
35+
)
36+
access_invalid_key_task() >> task1
37+
access_invalid_key_task_out_of_dag()
38+
39+
40+
invalid_dag()
741

842
@task
943
def print_config(**context):
@@ -74,3 +108,8 @@ def execute(self, context):
74108
tomorrow_ds = context["tomorrow_ds"]
75109
yesterday_ds = context["yesterday_ds"]
76110
yesterday_ds_nodash = context["yesterday_ds_nodash"]
111+
112+
@task
113+
def access_invalid_key_task_out_of_dag(execution_date, **context):
114+
print("execution date", execution_date)
115+
print("access invalid key", context.get("conf"))

crates/ruff_linter/src/rules/airflow/rules/removal_in_3.rs

+117-18
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1+
use crate::checkers::ast::Checker;
12
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
23
use ruff_macros::{derive_message_formats, ViolationMetadata};
4+
use ruff_python_ast::helpers::map_callable;
35
use ruff_python_ast::{
46
name::QualifiedName, Arguments, Expr, ExprAttribute, ExprCall, ExprContext, ExprName,
5-
ExprStringLiteral, ExprSubscript, StmtClassDef,
7+
ExprStringLiteral, ExprSubscript, Stmt, StmtClassDef, StmtFunctionDef,
68
};
79
use ruff_python_semantic::analyze::typing;
810
use ruff_python_semantic::Modules;
911
use ruff_python_semantic::ScopeKind;
1012
use ruff_text_size::Ranged;
1113
use ruff_text_size::TextRange;
1214

13-
use crate::checkers::ast::Checker;
14-
1515
/// ## What it does
1616
/// Checks for uses of deprecated Airflow functions and values.
1717
///
@@ -71,6 +71,21 @@ impl Violation for Airflow3Removal {
7171
}
7272
}
7373

74+
const REMOVED_CONTEXT_KEYS: [&str; 12] = [
75+
"conf",
76+
"execution_date",
77+
"next_ds",
78+
"next_ds_nodash",
79+
"next_execution_date",
80+
"prev_ds",
81+
"prev_ds_nodash",
82+
"prev_execution_date",
83+
"prev_execution_date_success",
84+
"tomorrow_ds",
85+
"yesterday_ds",
86+
"yesterday_ds_nodash",
87+
];
88+
7489
fn extract_name_from_slice(slice: &Expr) -> Option<String> {
7590
match slice {
7691
Expr::StringLiteral(ExprStringLiteral { value, .. }) => Some(value.to_string()),
@@ -79,21 +94,6 @@ fn extract_name_from_slice(slice: &Expr) -> Option<String> {
7994
}
8095

8196
pub(crate) fn removed_context_variable(checker: &mut Checker, expr: &Expr) {
82-
const REMOVED_CONTEXT_KEYS: [&str; 12] = [
83-
"conf",
84-
"execution_date",
85-
"next_ds",
86-
"next_ds_nodash",
87-
"next_execution_date",
88-
"prev_ds",
89-
"prev_ds_nodash",
90-
"prev_execution_date",
91-
"prev_execution_date_success",
92-
"tomorrow_ds",
93-
"yesterday_ds",
94-
"yesterday_ds_nodash",
95-
];
96-
9797
if let Expr::Subscript(ExprSubscript { value, slice, .. }) = expr {
9898
if let Expr::Name(ExprName { id, .. }) = &**value {
9999
if id.as_str() == "context" {
@@ -144,6 +144,7 @@ pub(crate) fn removed_in_3(checker: &mut Checker, expr: &Expr) {
144144
check_call_arguments(checker, &qualname, arguments);
145145
};
146146
check_method(checker, call_expr);
147+
check_context_get(checker, call_expr);
147148
}
148149
Expr::Attribute(attribute_expr @ ExprAttribute { attr, .. }) => {
149150
check_name(checker, expr, attr.range());
@@ -307,6 +308,52 @@ fn check_class_attribute(checker: &mut Checker, attribute_expr: &ExprAttribute)
307308
}
308309
}
309310

311+
/// Check whether a removed context key is access through context.get("key").
312+
///
313+
/// ```python
314+
/// from airflow.decorators import task
315+
///
316+
///
317+
/// @task
318+
/// def access_invalid_key_task_out_of_dag(**context):
319+
/// print("access invalid key", context.get("conf"))
320+
/// ```
321+
fn check_context_get(checker: &mut Checker, call_expr: &ExprCall) {
322+
if is_task_context_referenced(checker, &call_expr.func) {
323+
return;
324+
}
325+
326+
let Expr::Attribute(ExprAttribute { value, attr, .. }) = &*call_expr.func else {
327+
return;
328+
};
329+
330+
// Ensure the method called on `context`
331+
if !value
332+
.as_name_expr()
333+
.is_some_and(|name| matches!(name.id.as_str(), "context"))
334+
{
335+
return;
336+
}
337+
338+
// Ensure the method called on `get`
339+
if attr.as_str() != "get" {
340+
return;
341+
}
342+
343+
for removed_key in REMOVED_CONTEXT_KEYS {
344+
if let Some(argument) = call_expr.arguments.find_argument_value(removed_key, 0) {
345+
checker.diagnostics.push(Diagnostic::new(
346+
Airflow3Removal {
347+
deprecated: removed_key.to_string(),
348+
replacement: Replacement::None,
349+
},
350+
argument.range(),
351+
));
352+
return;
353+
}
354+
}
355+
}
356+
310357
/// Check whether a removed Airflow class method is called.
311358
///
312359
/// For example:
@@ -909,3 +956,55 @@ fn is_airflow_builtin_or_provider(segments: &[&str], module: &str, symbol_suffix
909956
_ => false,
910957
}
911958
}
959+
960+
fn is_task_context_referenced(checker: &mut Checker, expr: &Expr) -> bool {
961+
let parents: Vec<_> = checker.semantic().current_statements().collect();
962+
963+
for stmt in parents {
964+
if let Stmt::FunctionDef(function_def) = stmt {
965+
if is_task_decorated_function(checker, function_def) {
966+
let arguments = extract_task_function_arguments(function_def);
967+
968+
for deprecated_arg in REMOVED_CONTEXT_KEYS {
969+
if arguments.contains(&deprecated_arg.to_string()) {
970+
checker.diagnostics.push(Diagnostic::new(
971+
Airflow3Removal {
972+
deprecated: deprecated_arg.to_string(),
973+
replacement: Replacement::None,
974+
},
975+
expr.range(),
976+
));
977+
return true;
978+
}
979+
}
980+
}
981+
}
982+
}
983+
984+
false
985+
}
986+
987+
fn extract_task_function_arguments(stmt: &StmtFunctionDef) -> Vec<String> {
988+
let mut arguments = Vec::new();
989+
990+
for param in &stmt.parameters.args {
991+
arguments.push(param.parameter.name.to_string());
992+
}
993+
994+
if let Some(vararg) = &stmt.parameters.kwarg {
995+
arguments.push(format!("**{}", vararg.name));
996+
}
997+
998+
arguments
999+
}
1000+
1001+
fn is_task_decorated_function(checker: &mut Checker, stmt: &StmtFunctionDef) -> bool {
1002+
stmt.decorator_list.iter().any(|decorator| {
1003+
checker
1004+
.semantic()
1005+
.resolve_qualified_name(map_callable(&decorator.expression))
1006+
.is_some_and(|qualified_name| {
1007+
matches!(qualified_name.segments(), ["airflow", "decorators", "task"])
1008+
})
1009+
})
1010+
}

0 commit comments

Comments
 (0)