Refactor RET504 to only enforce assignment-then-return pattern (#4997)

## Summary

The `RET504` rule, which looks for unnecessary assignments before return
statements, is a frequent source of issues (#4173, #4236, #4242, #1606,
#2950). Over time, we've tried to refine the logic to handle more cases.
For example, we now avoid analyzing any functions that contain any
function calls or attribute assignments, since those operations can
contain side effects (and so we mark them as a "read" on all variables
in the function -- we could do a better job with code graph analysis to
handle this limitation, but that'd be a more involved change.) We also
avoid flagging any variables that are the target of multiple
assignments. Ultimately, though, I'm not happy with the implementation
-- we just can't do sufficiently reliable analysis of arbitrary code
flow given the limited logic herein, and the existing logic is very hard
to reason about and maintain.

This PR refocuses the rule to only catch cases of the form:

```py
def f():
    x = 1
    return x
```

That is, we now only flag returns that are immediately preceded by an
assignment to the returned variable. While this is more limiting, in
some ways, it lets us flag more cases vis-a-vis the previous
implementation, since we no longer "fully eject" when functions contain
function calls and other effect-ful operations.

Closes #4173.

Closes #4236.

Closes #4242.
This commit is contained in:
Charlie Marsh 2023-06-10 00:05:01 -04:00 committed by GitHub
parent 5abb8ec0dc
commit 02b8ce82af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 156 additions and 318 deletions

View file

@ -1,4 +1,5 @@
use ruff_text_size::TextSize; use ruff_text_size::TextSize;
use rustpython_parser::ast;
use rustpython_parser::ast::{Expr, Ranged, Stmt}; use rustpython_parser::ast::{Expr, Ranged, Stmt};
use ruff_python_ast::source_code::Locator; use ruff_python_ast::source_code::Locator;
@ -6,15 +7,14 @@ use ruff_python_whitespace::UniversalNewlines;
/// Return `true` if a function's return statement include at least one /// Return `true` if a function's return statement include at least one
/// non-`None` value. /// non-`None` value.
pub(super) fn result_exists(returns: &[(&Stmt, Option<&Expr>)]) -> bool { pub(super) fn result_exists(returns: &[&ast::StmtReturn]) -> bool {
returns.iter().any(|(_, expr)| { returns.iter().any(|stmt| {
expr.map(|expr| { stmt.value.as_deref().map_or(false, |value| {
!matches!( !matches!(
expr, value,
Expr::Constant(ref constant) if constant.value.is_none() Expr::Constant(constant) if constant.value.is_none()
) )
}) })
.unwrap_or(false)
}) })
} }
@ -26,12 +26,11 @@ pub(super) fn result_exists(returns: &[(&Stmt, Option<&Expr>)]) -> bool {
/// This method assumes that the statement is the last statement in its body; specifically, that /// This method assumes that the statement is the last statement in its body; specifically, that
/// the statement isn't followed by a semicolon, followed by a multi-line statement. /// the statement isn't followed by a semicolon, followed by a multi-line statement.
pub(super) fn end_of_last_statement(stmt: &Stmt, locator: &Locator) -> TextSize { pub(super) fn end_of_last_statement(stmt: &Stmt, locator: &Locator) -> TextSize {
// End-of-file, so just return the end of the statement.
if stmt.end() == locator.text_len() { if stmt.end() == locator.text_len() {
// End-of-file, so just return the end of the statement.
stmt.end() stmt.end()
} } else {
// Otherwise, find the end of the last line that's "part of" the statement. // Otherwise, find the end of the last line that's "part of" the statement.
else {
let contents = locator.after(stmt.end()); let contents = locator.after(stmt.end());
for line in contents.universal_newlines() { for line in contents.universal_newlines() {

View file

@ -1,5 +1,3 @@
use itertools::Itertools;
use ruff_text_size::{TextRange, TextSize};
use rustpython_parser::ast::{self, Constant, Expr, Ranged, Stmt}; use rustpython_parser::ast::{self, Constant, Expr, Ranged, Stmt};
use ruff_diagnostics::{AlwaysAutofixableViolation, Violation}; use ruff_diagnostics::{AlwaysAutofixableViolation, Violation};
@ -139,8 +137,8 @@ impl AlwaysAutofixableViolation for ImplicitReturn {
} }
/// ## What it does /// ## What it does
/// Checks for variable assignments that are unused between the assignment and /// Checks for variable assignments that immediately precede a `return` of the
/// a `return` of the variable. /// assigned variable.
/// ///
/// ## Why is this bad? /// ## Why is this bad?
/// The variable assignment is not necessary as the value can be returned /// The variable assignment is not necessary as the value can be returned
@ -159,12 +157,15 @@ impl AlwaysAutofixableViolation for ImplicitReturn {
/// return 1 /// return 1
/// ``` /// ```
#[violation] #[violation]
pub struct UnnecessaryAssign; pub struct UnnecessaryAssign {
name: String,
}
impl Violation for UnnecessaryAssign { impl Violation for UnnecessaryAssign {
#[derive_message_formats] #[derive_message_formats]
fn message(&self) -> String { fn message(&self) -> String {
format!("Unnecessary variable assignment before `return` statement") let UnnecessaryAssign { name } = self;
format!("Unnecessary assignment to `{name}` before `return` statement")
} }
} }
@ -326,8 +327,8 @@ impl Violation for SuperfluousElseBreak {
/// RET501 /// RET501
fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) { fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) {
for (stmt, expr) in &stack.returns { for stmt in &stack.returns {
let Some(expr) = expr else { let Some(expr) = stmt.value.as_deref() else {
continue; continue;
}; };
if !matches!( if !matches!(
@ -339,10 +340,9 @@ fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) {
) { ) {
continue; continue;
} }
let mut diagnostic = Diagnostic::new(UnnecessaryReturnNone, stmt.range()); let mut diagnostic = Diagnostic::new(UnnecessaryReturnNone, stmt.range);
if checker.patch(diagnostic.kind.rule()) { if checker.patch(diagnostic.kind.rule()) {
#[allow(deprecated)] diagnostic.set_fix(Fix::automatic(Edit::range_replacement(
diagnostic.set_fix(Fix::unspecified(Edit::range_replacement(
"return".to_string(), "return".to_string(),
stmt.range(), stmt.range(),
))); )));
@ -353,16 +353,15 @@ fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) {
/// RET502 /// RET502
fn implicit_return_value(checker: &mut Checker, stack: &Stack) { fn implicit_return_value(checker: &mut Checker, stack: &Stack) {
for (stmt, expr) in &stack.returns { for stmt in &stack.returns {
if expr.is_some() { if stmt.value.is_some() {
continue; continue;
} }
let mut diagnostic = Diagnostic::new(ImplicitReturnValue, stmt.range()); let mut diagnostic = Diagnostic::new(ImplicitReturnValue, stmt.range);
if checker.patch(diagnostic.kind.rule()) { if checker.patch(diagnostic.kind.rule()) {
#[allow(deprecated)] diagnostic.set_fix(Fix::automatic(Edit::range_replacement(
diagnostic.set_fix(Fix::unspecified(Edit::range_replacement(
"return None".to_string(), "return None".to_string(),
stmt.range(), stmt.range,
))); )));
} }
checker.diagnostics.push(diagnostic); checker.diagnostics.push(diagnostic);
@ -417,8 +416,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) {
content.push_str(checker.stylist.line_ending().as_str()); content.push_str(checker.stylist.line_ending().as_str());
content.push_str(indent); content.push_str(indent);
content.push_str("return None"); content.push_str("return None");
#[allow(deprecated)] diagnostic.set_fix(Fix::suggested(Edit::insertion(
diagnostic.set_fix(Fix::unspecified(Edit::insertion(
content, content,
end_of_last_statement(stmt, checker.locator), end_of_last_statement(stmt, checker.locator),
))); )));
@ -456,8 +454,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) {
content.push_str(checker.stylist.line_ending().as_str()); content.push_str(checker.stylist.line_ending().as_str());
content.push_str(indent); content.push_str(indent);
content.push_str("return None"); content.push_str("return None");
#[allow(deprecated)] diagnostic.set_fix(Fix::suggested(Edit::insertion(
diagnostic.set_fix(Fix::unspecified(Edit::insertion(
content, content,
end_of_last_statement(stmt, checker.locator), end_of_last_statement(stmt, checker.locator),
))); )));
@ -494,8 +491,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) {
content.push_str(checker.stylist.line_ending().as_str()); content.push_str(checker.stylist.line_ending().as_str());
content.push_str(indent); content.push_str(indent);
content.push_str("return None"); content.push_str("return None");
#[allow(deprecated)] diagnostic.set_fix(Fix::suggested(Edit::insertion(
diagnostic.set_fix(Fix::unspecified(Edit::insertion(
content, content,
end_of_last_statement(stmt, checker.locator), end_of_last_statement(stmt, checker.locator),
))); )));
@ -506,129 +502,51 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) {
} }
} }
/// Return `true` if the `id` has multiple declarations within the function.
fn has_multiple_declarations(id: &str, stack: &Stack) -> bool {
stack
.declarations
.get(&id)
.map_or(false, |declarations| declarations.len() > 1)
}
/// Return `true` if the `id` has a (read) reference between the `return_location` and its
/// preceding declaration.
fn has_references_before_next_declaration(
id: &str,
return_range: TextRange,
stack: &Stack,
) -> bool {
let mut declaration_before_return: Option<TextSize> = None;
let mut declaration_after_return: Option<TextSize> = None;
if let Some(assignments) = stack.declarations.get(&id) {
for location in assignments.iter().sorted() {
if *location > return_range.start() {
declaration_after_return = Some(*location);
break;
}
declaration_before_return = Some(*location);
}
}
// If there is no declaration before the return, then the variable must be declared in
// some other way (e.g., a function argument). No need to check for references.
let Some(declaration_before_return) = declaration_before_return else {
return true;
};
if let Some(references) = stack.references.get(&id) {
for location in references {
if return_range.contains(*location) {
continue;
}
if declaration_before_return < *location {
if let Some(declaration_after_return) = declaration_after_return {
if *location <= declaration_after_return {
return true;
}
} else {
return true;
}
}
}
}
false
}
/// Return `true` if the `id` has a read or write reference within a `try` or loop body.
fn has_references_or_declarations_within_try_or_loop(id: &str, stack: &Stack) -> bool {
if let Some(references) = stack.references.get(&id) {
for location in references {
for try_range in &stack.tries {
if try_range.contains(*location) {
return true;
}
}
for loop_range in &stack.loops {
if loop_range.contains(*location) {
return true;
}
}
}
}
if let Some(references) = stack.declarations.get(&id) {
for location in references {
for try_range in &stack.tries {
if try_range.contains(*location) {
return true;
}
}
for loop_range in &stack.loops {
if loop_range.contains(*location) {
return true;
}
}
}
}
false
}
/// RET504 /// RET504
fn unnecessary_assign(checker: &mut Checker, stack: &Stack, expr: &Expr) { fn unnecessary_assign(checker: &mut Checker, stack: &Stack) {
if let Expr::Name(ast::ExprName { id, .. }) = expr { for (stmt_assign, stmt_return) in &stack.assignments {
if !stack.assigned_names.contains(id.as_str()) { // Identify, e.g., `return x`.
return; let Some(value) = stmt_return.value.as_ref() else {
continue;
};
let Expr::Name(ast::ExprName { id: returned_id, .. }) = value.as_ref() else {
continue;
};
// Identify, e.g., `x = 1`.
if stmt_assign.targets.len() > 1 {
continue;
} }
if !stack.references.contains_key(id.as_str()) { let Some(target) = stmt_assign.targets.first() else {
checker continue;
.diagnostics };
.push(Diagnostic::new(UnnecessaryAssign, expr.range()));
return; let Expr::Name(ast::ExprName { id: assigned_id, .. }) = target else {
continue;
};
if returned_id != assigned_id {
continue;
} }
if has_multiple_declarations(id, stack) if stack.non_locals.contains(assigned_id.as_str()) {
|| has_references_before_next_declaration(id, expr.range(), stack) continue;
|| has_references_or_declarations_within_try_or_loop(id, stack)
{
return;
} }
if stack.non_locals.contains(id.as_str()) { checker.diagnostics.push(Diagnostic::new(
return; UnnecessaryAssign {
} name: assigned_id.to_string(),
},
checker value.range(),
.diagnostics ));
.push(Diagnostic::new(UnnecessaryAssign, expr.range()));
} }
} }
/// RET505, RET506, RET507, RET508 /// RET505, RET506, RET507, RET508
fn superfluous_else_node(checker: &mut Checker, stmt: &Stmt, branch: Branch) -> bool { fn superfluous_else_node(checker: &mut Checker, stmt: &ast::StmtIf, branch: Branch) -> bool {
let Stmt::If(ast::StmtIf { body, .. }) = stmt else { let ast::StmtIf { body, .. } = stmt;
return false;
};
for child in body { for child in body {
if child.is_return_stmt() { if child.is_return_stmt() {
let diagnostic = Diagnostic::new( let diagnostic = Diagnostic::new(
@ -708,7 +626,7 @@ pub(crate) fn function(checker: &mut Checker, body: &[Stmt], returns: Option<&Ex
}; };
// Avoid false positives for generators. // Avoid false positives for generators.
if !stack.yields.is_empty() { if stack.is_generator {
return; return;
} }
@ -737,11 +655,7 @@ pub(crate) fn function(checker: &mut Checker, body: &[Stmt], returns: Option<&Ex
} }
if checker.enabled(Rule::UnnecessaryAssign) { if checker.enabled(Rule::UnnecessaryAssign) {
for (_, expr) in &stack.returns { unnecessary_assign(checker, &stack);
if let Some(expr) = expr {
unnecessary_assign(checker, &stack, expr);
}
}
} }
} else { } else {
if checker.enabled(Rule::UnnecessaryReturnNone) { if checker.enabled(Rule::UnnecessaryReturnNone) {

View file

@ -10,7 +10,7 @@ RET501.py:4:5: RET501 [*] Do not explicitly `return None` in function if it is t
| |
= help: Remove explicit `return None` = help: Remove explicit `return None`
Suggested fix Fix
1 1 | def x(y): 1 1 | def x(y):
2 2 | if not y: 2 2 | if not y:
3 3 | return 3 3 | return
@ -29,7 +29,7 @@ RET501.py:14:9: RET501 [*] Do not explicitly `return None` in function if it is
| |
= help: Remove explicit `return None` = help: Remove explicit `return None`
Suggested fix Fix
11 11 | 11 11 |
12 12 | def get(self, key: str) -> None: 12 12 | def get(self, key: str) -> None:
13 13 | print(f"{key} not found") 13 13 | print(f"{key} not found")

View file

@ -11,7 +11,7 @@ RET502.py:3:9: RET502 [*] Do not implicitly `return None` in function able to re
| |
= help: Add explicit `None` return value = help: Add explicit `None` return value
Suggested fix Fix
1 1 | def x(y): 1 1 | def x(y):
2 2 | if not y: 2 2 | if not y:
3 |- return # error 3 |- return # error

View file

@ -1,7 +1,7 @@
--- ---
source: crates/ruff/src/rules/flake8_return/mod.rs source: crates/ruff/src/rules/flake8_return/mod.rs
--- ---
RET504.py:6:12: RET504 Unnecessary variable assignment before `return` statement RET504.py:6:12: RET504 Unnecessary assignment to `a` before `return` statement
| |
4 | def x(): 4 | def x():
5 | a = 1 5 | a = 1
@ -9,7 +9,23 @@ RET504.py:6:12: RET504 Unnecessary variable assignment before `return` statement
| ^ RET504 | ^ RET504
| |
RET504.py:250:12: RET504 Unnecessary variable assignment before `return` statement RET504.py:23:12: RET504 Unnecessary assignment to `formatted` before `return` statement
|
21 | # clean up after any blank components
22 | formatted = formatted.replace("()", "").replace(" ", " ").strip()
23 | return formatted
| ^^^^^^^^^ RET504
|
RET504.py:245:12: RET504 Unnecessary assignment to `queryset` before `return` statement
|
243 | queryset = Model.filter(a=1)
244 | queryset = queryset.filter(c=3)
245 | return queryset
| ^^^^^^^^ RET504
|
RET504.py:250:12: RET504 Unnecessary assignment to `queryset` before `return` statement
| |
248 | def get_queryset(): 248 | def get_queryset():
249 | queryset = Model.filter(a=1) 249 | queryset = Model.filter(a=1)
@ -17,7 +33,7 @@ RET504.py:250:12: RET504 Unnecessary variable assignment before `return` stateme
| ^^^^^^^^ RET504 | ^^^^^^^^ RET504
| |
RET504.py:268:12: RET504 Unnecessary variable assignment before `return` statement RET504.py:268:12: RET504 Unnecessary assignment to `val` before `return` statement
| |
266 | return val 266 | return val
267 | val = 1 267 | val = 1

View file

@ -1,118 +1,65 @@
use ruff_text_size::{TextRange, TextSize}; use rustc_hash::FxHashSet;
use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_parser::ast::{self, Expr, Identifier, Stmt};
use rustpython_parser::ast::{self, Expr, Identifier, Ranged, Stmt};
use ruff_python_ast::visitor; use ruff_python_ast::visitor;
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::visitor::Visitor;
#[derive(Default)] #[derive(Default)]
pub(crate) struct Stack<'a> { pub(crate) struct Stack<'a> {
pub(crate) returns: Vec<(&'a Stmt, Option<&'a Expr>)>, /// The `return` statements in the current function.
pub(crate) yields: Vec<&'a Expr>, pub(crate) returns: Vec<&'a ast::StmtReturn>,
pub(crate) elses: Vec<&'a Stmt>, /// The `else` statements in the current function.
pub(crate) elifs: Vec<&'a Stmt>, pub(crate) elses: Vec<&'a ast::StmtIf>,
/// The names that are assigned to in the current scope (e.g., anything on the left-hand side of /// The `elif` statements in the current function.
/// an assignment). pub(crate) elifs: Vec<&'a ast::StmtIf>,
pub(crate) assigned_names: FxHashSet<&'a str>, /// The non-local variables in the current function.
/// The names that are declared in the current scope, and the ranges of those declarations
/// (e.g., assignments, but also function and class definitions).
pub(crate) declarations: FxHashMap<&'a str, Vec<TextSize>>,
pub(crate) references: FxHashMap<&'a str, Vec<TextSize>>,
pub(crate) non_locals: FxHashSet<&'a str>, pub(crate) non_locals: FxHashSet<&'a str>,
pub(crate) loops: Vec<TextRange>, /// Whether the current function is a generator.
pub(crate) tries: Vec<TextRange>, pub(crate) is_generator: bool,
/// The `assignment`-to-`return` statement pairs in the current function.
pub(crate) assignments: Vec<(&'a ast::StmtAssign, &'a ast::StmtReturn)>,
} }
#[derive(Default)] #[derive(Default)]
pub(crate) struct ReturnVisitor<'a> { pub(crate) struct ReturnVisitor<'a> {
/// The current stack of nodes.
pub(crate) stack: Stack<'a>, pub(crate) stack: Stack<'a>,
/// The preceding sibling of the current node.
sibling: Option<&'a Stmt>,
/// The parent nodes of the current node.
parents: Vec<&'a Stmt>, parents: Vec<&'a Stmt>,
} }
impl<'a> ReturnVisitor<'a> {
fn visit_assign_target(&mut self, expr: &'a Expr) {
match expr {
Expr::Tuple(ast::ExprTuple { elts, .. }) => {
for elt in elts {
self.visit_assign_target(elt);
}
return;
}
Expr::Name(ast::ExprName { id, .. }) => {
self.stack.assigned_names.insert(id.as_str());
self.stack
.declarations
.entry(id)
.or_insert_with(Vec::new)
.push(expr.start());
return;
}
Expr::Attribute(_) => {
// Attribute assignments are often side-effects (e.g., `self.property = value`),
// so we conservatively treat them as references to every known
// variable.
for name in self.stack.declarations.keys() {
self.stack
.references
.entry(name)
.or_insert_with(Vec::new)
.push(expr.start());
}
}
_ => {}
}
visitor::walk_expr(self, expr);
}
}
impl<'a> Visitor<'a> for ReturnVisitor<'a> { impl<'a> Visitor<'a> for ReturnVisitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) { fn visit_stmt(&mut self, stmt: &'a Stmt) {
match stmt { match stmt {
Stmt::Global(ast::StmtGlobal { names, range: _ }) Stmt::ClassDef(ast::StmtClassDef { decorator_list, .. }) => {
| Stmt::Nonlocal(ast::StmtNonlocal { names, range: _ }) => { // Visit the decorators, etc.
self.stack self.sibling = Some(stmt);
.non_locals self.parents.push(stmt);
.extend(names.iter().map(Identifier::as_str));
}
Stmt::ClassDef(ast::StmtClassDef {
decorator_list,
name,
..
}) => {
// Mark a declaration.
self.stack
.declarations
.entry(name.as_str())
.or_insert_with(Vec::new)
.push(stmt.start());
// Don't recurse into the body, but visit the decorators, etc.
for decorator in decorator_list { for decorator in decorator_list {
visitor::walk_decorator(self, decorator); visitor::walk_decorator(self, decorator);
} }
self.parents.pop();
// But don't recurse into the body.
return;
} }
Stmt::FunctionDef(ast::StmtFunctionDef { Stmt::FunctionDef(ast::StmtFunctionDef {
name,
args, args,
decorator_list, decorator_list,
returns, returns,
.. ..
}) })
| Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef {
name,
args, args,
decorator_list, decorator_list,
returns, returns,
.. ..
}) => { }) => {
// Mark a declaration. // Visit the decorators, etc.
self.stack self.sibling = Some(stmt);
.declarations self.parents.push(stmt);
.entry(name.as_str())
.or_insert_with(Vec::new)
.push(stmt.start());
// Don't recurse into the body, but visit the decorators, etc.
for decorator in decorator_list { for decorator in decorator_list {
visitor::walk_decorator(self, decorator); visitor::walk_decorator(self, decorator);
} }
@ -120,17 +67,27 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> {
visitor::walk_expr(self, returns); visitor::walk_expr(self, returns);
} }
visitor::walk_arguments(self, args); visitor::walk_arguments(self, args);
}
Stmt::Return(ast::StmtReturn { value, range: _ }) => {
self.stack
.returns
.push((stmt, value.as_ref().map(|expr| &**expr)));
self.parents.push(stmt);
visitor::walk_stmt(self, stmt);
self.parents.pop(); self.parents.pop();
// But don't recurse into the body.
return;
} }
Stmt::If(ast::StmtIf { orelse, .. }) => { Stmt::Global(ast::StmtGlobal { names, range: _ })
| Stmt::Nonlocal(ast::StmtNonlocal { names, range: _ }) => {
self.stack
.non_locals
.extend(names.iter().map(Identifier::as_str));
}
Stmt::Return(stmt_return) => {
// If the `return` statement is preceded by an `assignment` statement, then the
// `assignment` statement may be redundant.
if let Some(stmt_assign) = self.sibling.and_then(Stmt::as_assign_stmt) {
self.stack.assignments.push((stmt_assign, stmt_return));
}
self.stack.returns.push(stmt_return);
}
Stmt::If(stmt_if) => {
let is_elif_arm = self.parents.iter().any(|parent| { let is_elif_arm = self.parents.iter().any(|parent| {
if let Stmt::If(ast::StmtIf { orelse, .. }) = parent { if let Stmt::If(ast::StmtIf { orelse, .. }) = parent {
orelse.len() == 1 && &orelse[0] == stmt orelse.len() == 1 && &orelse[0] == stmt
@ -141,88 +98,40 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> {
if !is_elif_arm { if !is_elif_arm {
let has_elif = let has_elif =
orelse.len() == 1 && matches!(orelse.first().unwrap(), Stmt::If(_)); stmt_if.orelse.len() == 1 && stmt_if.orelse.first().unwrap().is_if_stmt();
let has_else = !orelse.is_empty(); let has_else = !stmt_if.orelse.is_empty();
if has_elif { if has_elif {
// `stmt` is an `if` block followed by an `elif` clause. // `stmt` is an `if` block followed by an `elif` clause.
self.stack.elifs.push(stmt); self.stack.elifs.push(stmt_if);
} else if has_else { } else if has_else {
// `stmt` is an `if` block followed by an `else` clause. // `stmt` is an `if` block followed by an `else` clause.
self.stack.elses.push(stmt); self.stack.elses.push(stmt_if);
} }
} }
self.parents.push(stmt);
visitor::walk_stmt(self, stmt);
self.parents.pop();
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
if let Expr::Name(ast::ExprName { id, .. }) = value.as_ref() {
self.stack
.references
.entry(id)
.or_insert_with(Vec::new)
.push(value.start());
}
visitor::walk_expr(self, value);
if let Some(target) = targets.first() {
// Skip unpacking assignments, like `x, y = my_object`.
if target.is_tuple_expr() && !value.is_tuple_expr() {
return;
}
self.visit_assign_target(target);
}
}
Stmt::For(_) | Stmt::AsyncFor(_) | Stmt::While(_) => {
self.stack.loops.push(stmt.range());
self.parents.push(stmt);
visitor::walk_stmt(self, stmt);
self.parents.pop();
}
Stmt::Try(_) | Stmt::TryStar(_) => {
self.stack.tries.push(stmt.range());
self.parents.push(stmt);
visitor::walk_stmt(self, stmt);
self.parents.pop();
}
_ => {
self.parents.push(stmt);
visitor::walk_stmt(self, stmt);
self.parents.pop();
} }
_ => {}
} }
self.sibling = Some(stmt);
self.parents.push(stmt);
visitor::walk_stmt(self, stmt);
self.parents.pop();
} }
fn visit_expr(&mut self, expr: &'a Expr) { fn visit_expr(&mut self, expr: &'a Expr) {
match expr { match expr {
Expr::Call(_) => {
// Arbitrary function calls can have side effects, so we conservatively treat
// every function call as a reference to every known variable.
for name in self.stack.declarations.keys() {
self.stack
.references
.entry(name)
.or_insert_with(Vec::new)
.push(expr.start());
}
}
Expr::Name(ast::ExprName { id, .. }) => {
self.stack
.references
.entry(id)
.or_insert_with(Vec::new)
.push(expr.start());
}
Expr::YieldFrom(_) | Expr::Yield(_) => { Expr::YieldFrom(_) | Expr::Yield(_) => {
self.stack.yields.push(expr); self.stack.is_generator = true;
} }
_ => visitor::walk_expr(self, expr), _ => visitor::walk_expr(self, expr),
} }
} }
fn visit_body(&mut self, body: &'a [Stmt]) {
let sibling = self.sibling;
self.sibling = None;
visitor::walk_body(self, body);
self.sibling = sibling;
}
} }

View file

@ -1202,10 +1202,8 @@ pub fn first_colon_range(range: TextRange, locator: &Locator) -> Option<TextRang
} }
/// Return the `Range` of the first `Elif` or `Else` token in an `If` statement. /// Return the `Range` of the first `Elif` or `Else` token in an `If` statement.
pub fn elif_else_range(stmt: &Stmt, locator: &Locator) -> Option<TextRange> { pub fn elif_else_range(stmt: &ast::StmtIf, locator: &Locator) -> Option<TextRange> {
let Stmt::If(ast::StmtIf { body, orelse, .. } )= stmt else { let ast::StmtIf { body, orelse, .. } = stmt;
return None;
};
let start = body.last().expect("Expected body to be non-empty").end(); let start = body.last().expect("Expected body to be non-empty").end();
@ -1619,7 +1617,7 @@ mod tests {
use anyhow::Result; use anyhow::Result;
use ruff_text_size::{TextLen, TextRange, TextSize}; use ruff_text_size::{TextLen, TextRange, TextSize};
use rustpython_ast::Suite; use rustpython_ast::{Stmt, Suite};
use rustpython_parser::ast::Cmpop; use rustpython_parser::ast::Cmpop;
use rustpython_parser::Parse; use rustpython_parser::Parse;
@ -1819,6 +1817,7 @@ elif b:
.trim_start(); .trim_start();
let program = Suite::parse(contents, "<filename>")?; let program = Suite::parse(contents, "<filename>")?;
let stmt = program.first().unwrap(); let stmt = program.first().unwrap();
let stmt = Stmt::as_if_stmt(stmt).unwrap();
let locator = Locator::new(contents); let locator = Locator::new(contents);
let range = elif_else_range(stmt, &locator).unwrap(); let range = elif_else_range(stmt, &locator).unwrap();
assert_eq!(range.start(), TextSize::from(14)); assert_eq!(range.start(), TextSize::from(14));
@ -1833,6 +1832,7 @@ else:
.trim_start(); .trim_start();
let program = Suite::parse(contents, "<filename>")?; let program = Suite::parse(contents, "<filename>")?;
let stmt = program.first().unwrap(); let stmt = program.first().unwrap();
let stmt = Stmt::as_if_stmt(stmt).unwrap();
let locator = Locator::new(contents); let locator = Locator::new(contents);
let range = elif_else_range(stmt, &locator).unwrap(); let range = elif_else_range(stmt, &locator).unwrap();
assert_eq!(range.start(), TextSize::from(14)); assert_eq!(range.start(), TextSize::from(14));