From e66fdb83d06d875491cd37fce77c6ad90bcc267a Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sat, 6 May 2023 22:32:40 -0400 Subject: [PATCH] Respect insertion location when importing symbols (#4258) --- .../test/fixtures/pylint/sys_exit_alias_7.py | 5 ++ .../test/fixtures/pylint/sys_exit_alias_8.py | 5 ++ .../test/fixtures/pylint/sys_exit_alias_9.py | 5 ++ crates/ruff/src/autofix/actions.rs | 6 +- crates/ruff/src/importer.rs | 69 ++++++++++--------- .../rules/abstract_base_class.rs | 10 ++- .../rules/suppressible_exception.rs | 1 + .../rules/isort/rules/add_required_imports.rs | 7 +- crates/ruff/src/rules/pylint/mod.rs | 3 + .../src/rules/pylint/rules/sys_exit_alias.rs | 1 + ...t__tests__PLR1722_sys_exit_alias_7.py.snap | 21 ++++++ ...t__tests__PLR1722_sys_exit_alias_8.py.snap | 19 +++++ ...t__tests__PLR1722_sys_exit_alias_9.py.snap | 21 ++++++ .../rules/lru_cache_with_maxsize_none.rs | 1 + 14 files changed, 134 insertions(+), 40 deletions(-) create mode 100644 crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py create mode 100644 crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py create mode 100644 crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py create mode 100644 crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_7.py.snap create mode 100644 crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_8.py.snap create mode 100644 crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_9.py.snap diff --git a/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py new file mode 100644 index 0000000000..2771bfa603 --- /dev/null +++ b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py @@ -0,0 +1,5 @@ +def main(): + exit(0) + + +import functools diff --git a/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py new file mode 100644 index 0000000000..822bec52c7 --- /dev/null +++ b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py @@ -0,0 +1,5 @@ +from sys import argv + + +def main(): + exit(0) diff --git a/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py new file mode 100644 index 0000000000..326901d183 --- /dev/null +++ b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py @@ -0,0 +1,5 @@ +def main(): + exit(0) + + +from sys import argv diff --git a/crates/ruff/src/autofix/actions.rs b/crates/ruff/src/autofix/actions.rs index 283c9298a9..8b26783edf 100644 --- a/crates/ruff/src/autofix/actions.rs +++ b/crates/ruff/src/autofix/actions.rs @@ -435,6 +435,7 @@ pub fn remove_argument( pub fn get_or_import_symbol( module: &str, member: &str, + at: TextSize, context: &Context, importer: &Importer, locator: &Locator, @@ -462,7 +463,7 @@ pub fn get_or_import_symbol( Edit::range_replacement(locator.slice(source.range()).to_string(), source.range()); Ok((import_edit, binding)) } else { - if let Some(stmt) = importer.get_import_from(module) { + if let Some(stmt) = importer.find_import_from(module, at) { // Case 1: `from functools import lru_cache` is in scope, and we're trying to reference // `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the // bound name. @@ -485,7 +486,8 @@ pub fn get_or_import_symbol( .find_binding(module) .map_or(true, |binding| binding.kind.is_builtin()) { - let import_edit = importer.add_import(&AnyImport::Import(Import::module(module))); + let import_edit = + importer.add_import(&AnyImport::Import(Import::module(module)), at); Ok((import_edit, format!("{module}.{member}"))) } else { bail!( diff --git a/crates/ruff/src/importer.rs b/crates/ruff/src/importer.rs index fe003d1819..eb87bd749e 100644 --- a/crates/ruff/src/importer.rs +++ b/crates/ruff/src/importer.rs @@ -3,7 +3,6 @@ use anyhow::Result; use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute}; use ruff_text_size::TextSize; -use rustc_hash::FxHashMap; use rustpython_parser::ast::{Stmt, StmtKind, Suite}; use rustpython_parser::{lexer, Mode, Tok}; @@ -18,10 +17,7 @@ pub struct Importer<'a> { python_ast: &'a Suite, locator: &'a Locator<'a>, stylist: &'a Stylist<'a>, - /// A map from module name to top-level `StmtKind::ImportFrom` statements. - import_from_map: FxHashMap<&'a str, &'a Stmt>, - /// The last top-level import statement. - trailing_import: Option<&'a Stmt>, + ordered_imports: Vec<&'a Stmt>, } impl<'a> Importer<'a> { @@ -30,34 +26,21 @@ impl<'a> Importer<'a> { python_ast, locator, stylist, - import_from_map: FxHashMap::default(), - trailing_import: None, + ordered_imports: Vec::default(), } } /// Visit a top-level import statement. pub fn visit_import(&mut self, import: &'a Stmt) { - // Store a reference to the import statement in the appropriate map. - match &import.node { - StmtKind::Import { .. } => { - // Nothing to do here, we don't extend top-level `import` statements at all, so - // no need to track them. - } - StmtKind::ImportFrom { module, level, .. } => { - // Store a reverse-map from module name to `import ... from` statement. - if level.map_or(true, |level| level == 0) { - if let Some(module) = module { - self.import_from_map.insert(module.as_str(), import); - } - } - } - _ => { - panic!("Expected StmtKind::Import | StmtKind::ImportFrom"); - } - } + self.ordered_imports.push(import); + } - // Store a reference to the last top-level import statement. - self.trailing_import = Some(import); + /// Return the import statement that precedes the given position, if any. + fn preceding_import(&self, at: TextSize) -> Option<&Stmt> { + self.ordered_imports + .partition_point(|stmt| stmt.start() < at) + .checked_sub(1) + .map(|idx| self.ordered_imports[idx]) } /// Add an import statement to import the given module. @@ -65,9 +48,9 @@ impl<'a> Importer<'a> { /// If there are no existing imports, the new import will be added at the top /// of the file. Otherwise, it will be added after the most recent top-level /// import statement. - pub fn add_import(&self, import: &AnyImport) -> Edit { + pub fn add_import(&self, import: &AnyImport, at: TextSize) -> Edit { let required_import = import.to_string(); - if let Some(stmt) = self.trailing_import { + if let Some(stmt) = self.preceding_import(at) { // Insert after the last top-level import. let Insertion { prefix, @@ -88,10 +71,28 @@ impl<'a> Importer<'a> { } } - /// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`. - /// if it exists. - pub fn get_import_from(&self, module: &str) -> Option<&Stmt> { - self.import_from_map.get(module).copied() + /// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom` + /// preceding the given position, if any. + pub fn find_import_from(&self, module: &str, at: TextSize) -> Option<&Stmt> { + let mut import_from = None; + for stmt in &self.ordered_imports { + if stmt.start() >= at { + break; + } + if let StmtKind::ImportFrom { + module: name, + level, + .. + } = &stmt.node + { + if level.map_or(true, |level| level == 0) + && name.as_ref().map_or(false, |name| name == module) + { + import_from = Some(*stmt); + } + } + } + import_from } /// Add the given member to an existing `StmtKind::ImportFrom` statement. @@ -240,11 +241,11 @@ fn top_of_file_insertion(body: &[Stmt], locator: &Locator, stylist: &Stylist) -> #[cfg(test)] mod tests { use anyhow::Result; - use ruff_python_ast::newlines::LineEnding; use ruff_text_size::TextSize; use rustpython_parser as parser; use rustpython_parser::lexer::LexResult; + use ruff_python_ast::newlines::LineEnding; use ruff_python_ast::source_code::{Locator, Stylist}; use crate::importer::{top_of_file_insertion, Insertion}; diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs b/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs index f49db94cee..16e3c6f76e 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs @@ -85,8 +85,14 @@ fn fix_abstractmethod_missing( stmt: &Stmt, ) -> Result { let indent = indentation(locator, stmt).ok_or(anyhow!("Unable to detect indentation"))?; - let (import_edit, binding) = - get_or_import_symbol("abc", "abstractmethod", context, importer, locator)?; + let (import_edit, binding) = get_or_import_symbol( + "abc", + "abstractmethod", + stmt.start(), + context, + importer, + locator, + )?; let reference_edit = Edit::insertion( format!( "@{binding}{line_ending}{indent}", diff --git a/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs b/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs index faea06b192..5432560399 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs @@ -94,6 +94,7 @@ pub fn suppressible_exception( let (import_edit, binding) = get_or_import_symbol( "contextlib", "suppress", + stmt.start(), &checker.ctx, &checker.importer, checker.locator, diff --git a/crates/ruff/src/rules/isort/rules/add_required_imports.rs b/crates/ruff/src/rules/isort/rules/add_required_imports.rs index 2696e05f65..7eeaf991fd 100644 --- a/crates/ruff/src/rules/isort/rules/add_required_imports.rs +++ b/crates/ruff/src/rules/isort/rules/add_required_imports.rs @@ -1,5 +1,5 @@ use log::error; -use ruff_text_size::TextRange; +use ruff_text_size::{TextRange, TextSize}; use rustpython_parser as parser; use rustpython_parser::ast::{StmtKind, Suite}; @@ -120,7 +120,10 @@ fn add_required_import( TextRange::default(), ); if autofix.into() && settings.rules.should_fix(Rule::MissingRequiredImport) { - diagnostic.set_fix(Importer::new(python_ast, locator, stylist).add_import(required_import)); + diagnostic.set_fix( + Importer::new(python_ast, locator, stylist) + .add_import(required_import, TextSize::default()), + ); } Some(diagnostic) } diff --git a/crates/ruff/src/rules/pylint/mod.rs b/crates/ruff/src/rules/pylint/mod.rs index 2cac8fb3d2..2a06e513d2 100644 --- a/crates/ruff/src/rules/pylint/mod.rs +++ b/crates/ruff/src/rules/pylint/mod.rs @@ -37,6 +37,9 @@ mod tests { #[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_4.py"); "PLR1722_4")] #[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_5.py"); "PLR1722_5")] #[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_6.py"); "PLR1722_6")] + #[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_7.py"); "PLR1722_7")] + #[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_8.py"); "PLR1722_8")] + #[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_9.py"); "PLR1722_9")] #[test_case(Rule::ContinueInFinally, Path::new("continue_in_finally.py"); "PLE0116")] #[test_case(Rule::GlobalStatement, Path::new("global_statement.py"); "PLW0603")] #[test_case(Rule::GlobalVariableNotAssigned, Path::new("global_variable_not_assigned.py"); "PLW0602")] diff --git a/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs b/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs index cc563eba2b..1a452e6d6b 100644 --- a/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs +++ b/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs @@ -49,6 +49,7 @@ pub fn sys_exit_alias(checker: &mut Checker, func: &Expr) { let (import_edit, binding) = get_or_import_symbol( "sys", "exit", + func.start(), &checker.ctx, &checker.importer, checker.locator, diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_7.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_7.py.snap new file mode 100644 index 0000000000..f1e22c8894 --- /dev/null +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_7.py.snap @@ -0,0 +1,21 @@ +--- +source: crates/ruff/src/rules/pylint/mod.rs +--- +sys_exit_alias_7.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit` + | +2 | def main(): +3 | exit(0) + | ^^^^ PLR1722 + | + = help: Replace `exit` with `sys.exit()` + +ℹ Suggested fix + 1 |+import sys +1 2 | def main(): +2 |- exit(0) + 3 |+ sys.exit(0) +3 4 | +4 5 | +5 6 | import functools + + diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_8.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_8.py.snap new file mode 100644 index 0000000000..521226dc31 --- /dev/null +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_8.py.snap @@ -0,0 +1,19 @@ +--- +source: crates/ruff/src/rules/pylint/mod.rs +--- +sys_exit_alias_8.py:5:5: PLR1722 [*] Use `sys.exit()` instead of `exit` + | +5 | def main(): +6 | exit(0) + | ^^^^ PLR1722 + | + = help: Replace `exit` with `sys.exit()` + +ℹ Suggested fix +1 |-from sys import argv + 1 |+from sys import argv, exit +2 2 | +3 3 | +4 4 | def main(): + + diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_9.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_9.py.snap new file mode 100644 index 0000000000..4cf0610c85 --- /dev/null +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_9.py.snap @@ -0,0 +1,21 @@ +--- +source: crates/ruff/src/rules/pylint/mod.rs +--- +sys_exit_alias_9.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit` + | +2 | def main(): +3 | exit(0) + | ^^^^ PLR1722 + | + = help: Replace `exit` with `sys.exit()` + +ℹ Suggested fix + 1 |+import sys +1 2 | def main(): +2 |- exit(0) + 3 |+ sys.exit(0) +3 4 | +4 5 | +5 6 | from sys import argv + + diff --git a/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs b/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs index c0f266e968..710a80089e 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs @@ -62,6 +62,7 @@ pub fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list: &[Expr let (import_edit, binding) = get_or_import_symbol( "functools", "cache", + expr.start(), &checker.ctx, &checker.importer, checker.locator,