Check AIR001 from builtin or providers operators module (#14631)

## Summary

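This PR restricts the `AIR001` rule to calls that resolve to the
`airflow.operators` module or to a provider's operators module
(`airflow.providers.**.operators.*`), as per
https://github.com/astral-sh/ruff/pull/14627#discussion_r1860212307.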

Additionally,
* Push the `Diagnostic` onto the checker from within the rule logic instead
of returning it, for consistency
* Remove the test case for a different keyword position (I don't think it's
required here)

## Test Plan

Add test cases for multiple operators from various modules.
This commit is contained in:
Dhruv Manilawala 2024-12-04 13:30:47 +05:30 committed by GitHub
parent edce559431
commit 575deb5d4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 69 additions and 38 deletions

View file

@ -1,4 +1,6 @@
from airflow.operators import PythonOperator
from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator
from airflow.providers.amazon.aws.operators.appflow import AppflowFlowRunOperator
def my_callable():
@ -6,11 +8,15 @@ def my_callable():
my_task = PythonOperator(task_id="my_task", callable=my_callable)
my_task_2 = PythonOperator(callable=my_callable, task_id="my_task_2")
incorrect_name = PythonOperator(task_id="my_task") # AIR001
incorrect_name = PythonOperator(task_id="my_task")
incorrect_name_2 = PythonOperator(callable=my_callable, task_id="my_task_2")
my_task = AirbyteTriggerSyncOperator(task_id="my_task", callable=my_callable)
incorrect_name = AirbyteTriggerSyncOperator(task_id="my_task") # AIR001
from my_module import MyClass
my_task = AppflowFlowRunOperator(task_id="my_task", callable=my_callable)
incorrect_name = AppflowFlowRunOperator(task_id="my_task") # AIR001
incorrect_name = MyClass(task_id="my_task")
# Consider only from the `airflow.operators` (or providers operators) module
from airflow import MyOperator
incorrect_name = MyOperator(task_id="my_task")

View file

@ -1554,11 +1554,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
.rules
.enabled(Rule::AirflowVariableNameTaskIdMismatch)
{
if let Some(diagnostic) =
airflow::rules::variable_name_task_id(checker, targets, value)
{
checker.diagnostics.push(diagnostic);
}
airflow::rules::variable_name_task_id(checker, targets, value);
}
if checker.settings.rules.enabled(Rule::SelfAssigningVariable) {
pylint::rules::self_assignment(checker, assign);

View file

@ -45,21 +45,17 @@ impl Violation for AirflowVariableNameTaskIdMismatch {
}
/// AIR001
pub(crate) fn variable_name_task_id(
checker: &mut Checker,
targets: &[Expr],
value: &Expr,
) -> Option<Diagnostic> {
pub(crate) fn variable_name_task_id(checker: &mut Checker, targets: &[Expr], value: &Expr) {
if !checker.semantic().seen_module(Modules::AIRFLOW) {
return None;
return;
}
// If we have more than one target, we can't do anything.
let [target] = targets else {
return None;
return;
};
let Expr::Name(ast::ExprName { id, .. }) = target else {
return None;
return;
};
// If the value is not a call, we can't do anything.
@ -67,33 +63,58 @@ pub(crate) fn variable_name_task_id(
func, arguments, ..
}) = value
else {
return None;
return;
};
// If the function doesn't come from Airflow, we can't do anything.
// If the function doesn't come from Airflow's operators module (builtin or providers), we
// can't do anything.
if !checker
.semantic()
.resolve_qualified_name(func)
.is_some_and(|qualified_name| matches!(qualified_name.segments(), ["airflow", ..]))
.is_some_and(|qualified_name| {
match qualified_name.segments() {
// Match `airflow.operators.*`
["airflow", "operators", ..] => true,
// Match `airflow.providers.**.operators.*`
["airflow", "providers", rest @ ..] => {
// Ensure 'operators' exists somewhere in the middle
if let Some(pos) = rest.iter().position(|&s| s == "operators") {
pos + 1 < rest.len() // Check that 'operators' is not the last element
} else {
false
}
}
_ => false,
}
})
{
return None;
return;
}
// If the call doesn't have a `task_id` keyword argument, we can't do anything.
let keyword = arguments.find_keyword("task_id")?;
let Some(keyword) = arguments.find_keyword("task_id") else {
return;
};
// If the keyword argument is not a string, we can't do anything.
let ast::ExprStringLiteral { value: task_id, .. } = keyword.value.as_string_literal_expr()?;
let Some(ast::ExprStringLiteral { value: task_id, .. }) =
keyword.value.as_string_literal_expr()
else {
return;
};
// If the target name is the same as the task_id, no violation.
if task_id == id.as_str() {
return None;
return;
}
Some(Diagnostic::new(
let diagnostic = Diagnostic::new(
AirflowVariableNameTaskIdMismatch {
task_id: task_id.to_string(),
},
target.range(),
))
);
checker.diagnostics.push(diagnostic);
}

View file

@ -1,21 +1,29 @@
---
source: crates/ruff_linter/src/rules/airflow/mod.rs
snapshot_kind: text
---
AIR001.py:11:1: AIR001 Task variable name should match the `task_id`: "my_task"
|
9 | my_task_2 = PythonOperator(callable=my_callable, task_id="my_task_2")
10 |
11 | incorrect_name = PythonOperator(task_id="my_task")
10 | my_task = PythonOperator(task_id="my_task", callable=my_callable)
11 | incorrect_name = PythonOperator(task_id="my_task") # AIR001
| ^^^^^^^^^^^^^^ AIR001
12 | incorrect_name_2 = PythonOperator(callable=my_callable, task_id="my_task_2")
12 |
13 | my_task = AirbyteTriggerSyncOperator(task_id="my_task", callable=my_callable)
|
AIR001.py:12:1: AIR001 Task variable name should match the `task_id`: "my_task_2"
AIR001.py:14:1: AIR001 Task variable name should match the `task_id`: "my_task"
|
11 | incorrect_name = PythonOperator(task_id="my_task")
12 | incorrect_name_2 = PythonOperator(callable=my_callable, task_id="my_task_2")
| ^^^^^^^^^^^^^^^^ AIR001
13 |
14 | from my_module import MyClass
13 | my_task = AirbyteTriggerSyncOperator(task_id="my_task", callable=my_callable)
14 | incorrect_name = AirbyteTriggerSyncOperator(task_id="my_task") # AIR001
| ^^^^^^^^^^^^^^ AIR001
15 |
16 | my_task = AppflowFlowRunOperator(task_id="my_task", callable=my_callable)
|
AIR001.py:17:1: AIR001 Task variable name should match the `task_id`: "my_task"
|
16 | my_task = AppflowFlowRunOperator(task_id="my_task", callable=my_callable)
17 | incorrect_name = AppflowFlowRunOperator(task_id="my_task") # AIR001
| ^^^^^^^^^^^^^^ AIR001
18 |
19 | # Consider only from the `airflow.operators` (or providers operators) module
|