Respect insertion location when importing symbols (#4258)

This commit is contained in:
Charlie Marsh 2023-05-06 22:32:40 -04:00 committed by GitHub
parent a95bafefb0
commit e66fdb83d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 134 additions and 40 deletions

View file

@ -0,0 +1,5 @@
def main():
exit(0)
import functools

View file

@ -0,0 +1,5 @@
from sys import argv
def main():
exit(0)

View file

@ -0,0 +1,5 @@
def main():
exit(0)
from sys import argv

View file

@ -435,6 +435,7 @@ pub fn remove_argument(
pub fn get_or_import_symbol(
module: &str,
member: &str,
at: TextSize,
context: &Context,
importer: &Importer,
locator: &Locator,
@ -462,7 +463,7 @@ pub fn get_or_import_symbol(
Edit::range_replacement(locator.slice(source.range()).to_string(), source.range());
Ok((import_edit, binding))
} else {
if let Some(stmt) = importer.get_import_from(module) {
if let Some(stmt) = importer.find_import_from(module, at) {
// Case 1: `from functools import lru_cache` is in scope, and we're trying to reference
// `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the
// bound name.
@ -485,7 +486,8 @@ pub fn get_or_import_symbol(
.find_binding(module)
.map_or(true, |binding| binding.kind.is_builtin())
{
let import_edit = importer.add_import(&AnyImport::Import(Import::module(module)));
let import_edit =
importer.add_import(&AnyImport::Import(Import::module(module)), at);
Ok((import_edit, format!("{module}.{member}")))
} else {
bail!(

View file

@ -3,7 +3,6 @@
use anyhow::Result;
use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute};
use ruff_text_size::TextSize;
use rustc_hash::FxHashMap;
use rustpython_parser::ast::{Stmt, StmtKind, Suite};
use rustpython_parser::{lexer, Mode, Tok};
@ -18,10 +17,7 @@ pub struct Importer<'a> {
python_ast: &'a Suite,
locator: &'a Locator<'a>,
stylist: &'a Stylist<'a>,
/// A map from module name to top-level `StmtKind::ImportFrom` statements.
import_from_map: FxHashMap<&'a str, &'a Stmt>,
/// The last top-level import statement.
trailing_import: Option<&'a Stmt>,
ordered_imports: Vec<&'a Stmt>,
}
impl<'a> Importer<'a> {
@ -30,34 +26,21 @@ impl<'a> Importer<'a> {
python_ast,
locator,
stylist,
import_from_map: FxHashMap::default(),
trailing_import: None,
ordered_imports: Vec::default(),
}
}
/// Visit a top-level import statement.
pub fn visit_import(&mut self, import: &'a Stmt) {
// Store a reference to the import statement in the appropriate map.
match &import.node {
StmtKind::Import { .. } => {
// Nothing to do here, we don't extend top-level `import` statements at all, so
// no need to track them.
}
StmtKind::ImportFrom { module, level, .. } => {
// Store a reverse-map from module name to `import ... from` statement.
if level.map_or(true, |level| level == 0) {
if let Some(module) = module {
self.import_from_map.insert(module.as_str(), import);
}
}
}
_ => {
panic!("Expected StmtKind::Import | StmtKind::ImportFrom");
}
}
self.ordered_imports.push(import);
}
// Store a reference to the last top-level import statement.
self.trailing_import = Some(import);
/// Return the import statement that precedes the given position, if any.
fn preceding_import(&self, at: TextSize) -> Option<&Stmt> {
self.ordered_imports
.partition_point(|stmt| stmt.start() < at)
.checked_sub(1)
.map(|idx| self.ordered_imports[idx])
}
/// Add an import statement to import the given module.
@ -65,9 +48,9 @@ impl<'a> Importer<'a> {
/// If there are no existing imports, the new import will be added at the top
/// of the file. Otherwise, it will be added after the most recent top-level
/// import statement.
pub fn add_import(&self, import: &AnyImport) -> Edit {
pub fn add_import(&self, import: &AnyImport, at: TextSize) -> Edit {
let required_import = import.to_string();
if let Some(stmt) = self.trailing_import {
if let Some(stmt) = self.preceding_import(at) {
// Insert after the last top-level import.
let Insertion {
prefix,
@ -88,10 +71,28 @@ impl<'a> Importer<'a> {
}
}
/// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`.
/// if it exists.
pub fn get_import_from(&self, module: &str) -> Option<&Stmt> {
self.import_from_map.get(module).copied()
/// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`
/// preceding the given position, if any.
pub fn find_import_from(&self, module: &str, at: TextSize) -> Option<&Stmt> {
let mut import_from = None;
for stmt in &self.ordered_imports {
if stmt.start() >= at {
break;
}
if let StmtKind::ImportFrom {
module: name,
level,
..
} = &stmt.node
{
if level.map_or(true, |level| level == 0)
&& name.as_ref().map_or(false, |name| name == module)
{
import_from = Some(*stmt);
}
}
}
import_from
}
/// Add the given member to an existing `StmtKind::ImportFrom` statement.
@ -240,11 +241,11 @@ fn top_of_file_insertion(body: &[Stmt], locator: &Locator, stylist: &Stylist) ->
#[cfg(test)]
mod tests {
use anyhow::Result;
use ruff_python_ast::newlines::LineEnding;
use ruff_text_size::TextSize;
use rustpython_parser as parser;
use rustpython_parser::lexer::LexResult;
use ruff_python_ast::newlines::LineEnding;
use ruff_python_ast::source_code::{Locator, Stylist};
use crate::importer::{top_of_file_insertion, Insertion};

View file

@ -85,8 +85,14 @@ fn fix_abstractmethod_missing(
stmt: &Stmt,
) -> Result<Fix> {
let indent = indentation(locator, stmt).ok_or(anyhow!("Unable to detect indentation"))?;
let (import_edit, binding) =
get_or_import_symbol("abc", "abstractmethod", context, importer, locator)?;
let (import_edit, binding) = get_or_import_symbol(
"abc",
"abstractmethod",
stmt.start(),
context,
importer,
locator,
)?;
let reference_edit = Edit::insertion(
format!(
"@{binding}{line_ending}{indent}",

View file

@ -94,6 +94,7 @@ pub fn suppressible_exception(
let (import_edit, binding) = get_or_import_symbol(
"contextlib",
"suppress",
stmt.start(),
&checker.ctx,
&checker.importer,
checker.locator,

View file

@ -1,5 +1,5 @@
use log::error;
use ruff_text_size::TextRange;
use ruff_text_size::{TextRange, TextSize};
use rustpython_parser as parser;
use rustpython_parser::ast::{StmtKind, Suite};
@ -120,7 +120,10 @@ fn add_required_import(
TextRange::default(),
);
if autofix.into() && settings.rules.should_fix(Rule::MissingRequiredImport) {
diagnostic.set_fix(Importer::new(python_ast, locator, stylist).add_import(required_import));
diagnostic.set_fix(
Importer::new(python_ast, locator, stylist)
.add_import(required_import, TextSize::default()),
);
}
Some(diagnostic)
}

View file

@ -37,6 +37,9 @@ mod tests {
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_4.py"); "PLR1722_4")]
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_5.py"); "PLR1722_5")]
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_6.py"); "PLR1722_6")]
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_7.py"); "PLR1722_7")]
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_8.py"); "PLR1722_8")]
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_9.py"); "PLR1722_9")]
#[test_case(Rule::ContinueInFinally, Path::new("continue_in_finally.py"); "PLE0116")]
#[test_case(Rule::GlobalStatement, Path::new("global_statement.py"); "PLW0603")]
#[test_case(Rule::GlobalVariableNotAssigned, Path::new("global_variable_not_assigned.py"); "PLW0602")]

View file

@ -49,6 +49,7 @@ pub fn sys_exit_alias(checker: &mut Checker, func: &Expr) {
let (import_edit, binding) = get_or_import_symbol(
"sys",
"exit",
func.start(),
&checker.ctx,
&checker.importer,
checker.locator,

View file

@ -0,0 +1,21 @@
---
source: crates/ruff/src/rules/pylint/mod.rs
---
sys_exit_alias_7.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
2 | def main():
3 | exit(0)
| ^^^^ PLR1722
|
= help: Replace `exit` with `sys.exit()`
Suggested fix
1 |+import sys
1 2 | def main():
2 |- exit(0)
3 |+ sys.exit(0)
3 4 |
4 5 |
5 6 | import functools

View file

@ -0,0 +1,19 @@
---
source: crates/ruff/src/rules/pylint/mod.rs
---
sys_exit_alias_8.py:5:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
5 | def main():
6 | exit(0)
| ^^^^ PLR1722
|
= help: Replace `exit` with `sys.exit()`
Suggested fix
1 |-from sys import argv
1 |+from sys import argv, exit
2 2 |
3 3 |
4 4 | def main():

View file

@ -0,0 +1,21 @@
---
source: crates/ruff/src/rules/pylint/mod.rs
---
sys_exit_alias_9.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
2 | def main():
3 | exit(0)
| ^^^^ PLR1722
|
= help: Replace `exit` with `sys.exit()`
Suggested fix
1 |+import sys
1 2 | def main():
2 |- exit(0)
3 |+ sys.exit(0)
3 4 |
4 5 |
5 6 | from sys import argv

View file

@ -62,6 +62,7 @@ pub fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list: &[Expr
let (import_edit, binding) = get_or_import_symbol(
"functools",
"cache",
expr.start(),
&checker.ctx,
&checker.importer,
checker.locator,